From 3c90dcdc9e4db0bcb9b68f2cbbb546186d559e7c Mon Sep 17 00:00:00 2001
From: yym68686
Date: Mon, 25 Dec 2023 23:32:32 +0800
Subject: [PATCH] 1. Refactored the plugin function file and cleaned up unused
 functions. 2. Updated g4f to version 0.1.9.6 and fixed a bug with g4f
 availability. 3. Updated the README file.

---
 README.md                         |   2 +-
 bot.py                            |   8 +-
 config.py                         |   2 +-
 requirements.txt                  |   2 +-
 test/test_gpt4free.py             |   1 +
 test/test_langchain_search_old.py | 235 ++++++++++++++++++++
 test/test_url.py                  |   7 +-
 utils/chatgpt2api.py              |   2 +-
 utils/gpt4free.py                 |  30 ++-
 utils/{agent.py => plugins.py}    | 345 ++++++------------------------
 10 files changed, 345 insertions(+), 289 deletions(-)
 create mode 100644 test/test_langchain_search_old.py
 rename utils/{agent.py => plugins.py} (64%)

diff --git a/README.md b/README.md
index 4a499160..ad26dc30 100644
--- a/README.md
+++ b/README.md
@@ -55,7 +55,7 @@ To develop plugins, please follow the steps outlined below:
 - Initially, you need to add the environment variable for the plugin in the `config.PLUGINS` dictionary located in the `config.py` file. The value can be customized to be either enabled or disabled by default. It is advisable to use uppercase letters for the entire environment variable.
 - Subsequently, append the function's name and description in the `utils/function_call.py` file.
 - Then, enhance the `ask_stream` function in the `utils/chatgpt2api.py` file with the function's processing logic. You can refer to the existing examples within the `ask_stream` method for guidance on how to write it.
-- Following that, write the function, as mentioned in the `utils/function_call.py` file, in the `utils/agent.py` file.
+- Following that, write the function, as mentioned in the `utils/function_call.py` file, in the `utils/plugins.py` file.
 - Next, in the `bot.py` file, augment the `update_first_buttons_message` function with buttons, enabling users to freely toggle plugins using the `info` command.
 - Lastly, don't forget to add the plugin's description in the plugins section of the README.
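To make the steps above concrete, here is a minimal sketch of a plugin skeleton. The WEATHER key and get_weather function are hypothetical illustrations only, not part of this patch; the environment-variable parsing idiom mirrors the one config.py already uses:

import os

# Step 1 (config.py): register the plugin's toggle in config.PLUGINS.
# "WEATHER" is a hypothetical example key, not one of the project's plugins.
PLUGINS = {
    "WEATHER": (os.environ.get('WEATHER', "False") == "False") == False,
}

# Step 4 (utils/plugins.py): implement the function whose name and
# description are declared in utils/function_call.py.
def get_weather(city: str) -> str:
    # Hypothetical stub; a real plugin would call a weather service here.
    return f"Weather in {city}: sunny."

if __name__ == "__main__":
    if PLUGINS["WEATHER"]:
        print(get_weather("Shanghai"))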
diff --git a/bot.py b/bot.py index 14860d03..78f82734 100644 --- a/bot.py +++ b/bot.py @@ -8,7 +8,7 @@ from utils.chatgpt2api import Chatbot as GPT from utils.chatgpt2api import claudebot from telegram.constants import ChatAction -from utils.agent import docQA, get_doc_from_local, Document_extract, pdfQA, get_encode_image +from utils.plugins import Document_extract, get_encode_image from telegram import BotCommand, InlineKeyboardButton, InlineKeyboardMarkup, InlineQueryResultArticle, InputTextMessageContent from telegram.ext import CommandHandler, MessageHandler, ApplicationBuilder, filters, CallbackQueryHandler, Application, AIORateLimiter, InlineQueryHandler from config import WEB_HOOK, PORT, BOT_TOKEN @@ -74,6 +74,9 @@ async def command_bot(update, context, language=None, prompt=translator_prompt, if message: if "claude" in config.GPT_ENGINE and config.ClaudeAPI: robot = config.claudeBot + if not config.API or config.PLUGINS["USE_G4F"]: + import utils.gpt4free as gpt4free + robot = gpt4free if image_url: robot = config.GPT4visionbot title = "`🤖️ gpt-4-vision-preview`\n\n" @@ -124,9 +127,6 @@ async def getChatGPT(update, context, title, robot, message, chatid, messageid): ) messageid = message.message_id get_answer = robot.ask_stream - if not config.API or (config.PLUGINS["USE_G4F"] and not config.PLUGINS["SEARCH_USE_GPT"]): - import utils.gpt4free as gpt4free - get_answer = gpt4free.get_response try: for data in get_answer(text, convo_id=str(chatid), pass_history=config.PASS_HISTORY): diff --git a/config.py b/config.py index 8b93fdc2..4152c4fe 100644 --- a/config.py +++ b/config.py @@ -57,7 +57,7 @@ PLUGINS = { "SEARCH_USE_GPT": (os.environ.get('SEARCH_USE_GPT', "True") == "False") == False, - "USE_G4F": False, + "USE_G4F": (os.environ.get('USE_G4F', "False") == "False") == False, "DATE": True, "URL": True, "VERSION": True, diff --git a/requirements.txt b/requirements.txt index 37d85751..a1132c77 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ duckduckgo-search==4.1.0 langchain==0.0.271 oauth2client==3.0.0 pdfminer.six -g4f==0.1.8.8 +g4f==0.1.9.6 # plugin pytz \ No newline at end of file diff --git a/test/test_gpt4free.py b/test/test_gpt4free.py index 61181a49..7f769f8f 100644 --- a/test/test_gpt4free.py +++ b/test/test_gpt4free.py @@ -15,6 +15,7 @@ def get_response(message, model="gpt-3.5-turbo"): if __name__ == "__main__": console = Console() message = r""" +李雪主是谁? 
""" answer = "" for result in get_response(message, "gpt-4"): diff --git a/test/test_langchain_search_old.py b/test/test_langchain_search_old.py new file mode 100644 index 00000000..d6a67740 --- /dev/null +++ b/test/test_langchain_search_old.py @@ -0,0 +1,235 @@ +import os +import re + +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import config + +from langchain.chat_models import ChatOpenAI + + +from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain + +from langchain.prompts.chat import ( + ChatPromptTemplate, + SystemMessagePromptTemplate, + HumanMessagePromptTemplate, +) +from langchain.embeddings.openai import OpenAIEmbeddings +from langchain.vectorstores import Chroma +from langchain.text_splitter import CharacterTextSplitter + +from langchain.document_loaders import UnstructuredPDFLoader + +def getmd5(string): + import hashlib + md5_hash = hashlib.md5() + md5_hash.update(string.encode('utf-8')) + md5_hex = md5_hash.hexdigest() + return md5_hex + +from utils.sitemap import SitemapLoader +async def get_doc_from_sitemap(url): + # https://www.langchain.asia/modules/indexes/document_loaders/examples/sitemap#%E8%BF%87%E6%BB%A4%E7%AB%99%E7%82%B9%E5%9C%B0%E5%9B%BE-url- + sitemap_loader = SitemapLoader(web_path=url) + docs = await sitemap_loader.load() + return docs + +async def get_doc_from_local(docpath, doctype="md"): + from langchain.document_loaders import DirectoryLoader + # 加载文件夹中的所有txt类型的文件 + loader = DirectoryLoader(docpath, glob='**/*.' + doctype) + # 将数据转成 document 对象,每个文件会作为一个 document + documents = loader.load() + return documents + +system_template="""Use the following pieces of context to answer the users question. +If you don't know the answer, just say "Hmm..., I'm not sure.", don't try to make up an answer. +ALWAYS return a "Sources" part in your answer. +The "Sources" part should be a reference to the source of the document from which you got your answer. + +Example of your response should be: + +``` +The answer is foo + +Sources: +1. abc +2. xyz +``` +Begin! 
+----------------
+{summaries}
+"""
+messages = [
+    SystemMessagePromptTemplate.from_template(system_template),
+    HumanMessagePromptTemplate.from_template("{question}")
+]
+prompt = ChatPromptTemplate.from_messages(messages)
+
+def get_chain(store, llm):
+    chain_type_kwargs = {"prompt": prompt}
+    chain = RetrievalQAWithSourcesChain.from_chain_type(
+        llm,
+        chain_type="stuff",
+        retriever=store.as_retriever(),
+        chain_type_kwargs=chain_type_kwargs,
+        reduce_k_below_max_tokens=True
+    )
+    return chain
+
+async def docQA(docpath, query_message, persist_db_path="db", model = "gpt-3.5-turbo"):
+    chatllm = ChatOpenAI(temperature=0.5, openai_api_base=config.bot_api_url.v1_url, model_name=model, openai_api_key=config.API)
+    embeddings = OpenAIEmbeddings(openai_api_base=config.bot_api_url.v1_url, openai_api_key=config.API)
+
+    sitemap = "sitemap.xml"
+    match = re.match(r'^(https?|ftp)://[^\s/$.?#].[^\s]*$', docpath)
+    if match:
+        doc_method = get_doc_from_sitemap
+        docpath = os.path.join(docpath, sitemap)
+    else:
+        doc_method = get_doc_from_local
+
+    persist_db_path = getmd5(docpath)
+    if not os.path.exists(persist_db_path):
+        documents = await doc_method(docpath)
+        # Initialize the text splitter
+        text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=50)
+        # Persist the data
+        split_docs = text_splitter.split_documents(documents)
+        vector_store = Chroma.from_documents(split_docs, embeddings, persist_directory=persist_db_path)
+        vector_store.persist()
+    else:
+        # Load the persisted data
+        vector_store = Chroma(persist_directory=persist_db_path, embedding_function=embeddings)
+
+    # Create the QA object
+    qa = get_chain(vector_store, chatllm)
+    # qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(), return_source_documents=True)
+    # Run the query
+    result = qa({"question": query_message})
+    return result
+
+
+def persist_emdedding_pdf(docurl, persist_db_path):
+    embeddings = OpenAIEmbeddings(openai_api_base=config.bot_api_url.v1_url, openai_api_key=os.environ.get('API', None))
+    filename = get_doc_from_url(docurl)
+    docpath = os.getcwd() + "/" + filename
+    loader = UnstructuredPDFLoader(docpath)
+    documents = loader.load()
+    # Initialize the text splitter
+    text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=25)
+    # Split the loaded documents
+    split_docs = text_splitter.split_documents(documents)
+    vector_store = Chroma.from_documents(split_docs, embeddings, persist_directory=persist_db_path)
+    vector_store.persist()
+    os.remove(docpath)
+    return vector_store
+
+async def pdfQA(docurl, docpath, query_message, model="gpt-3.5-turbo"):
+    chatllm = ChatOpenAI(temperature=0.5, openai_api_base=config.bot_api_url.v1_url, model_name=model, openai_api_key=os.environ.get('API', None))
+    embeddings = OpenAIEmbeddings(openai_api_base=config.bot_api_url.v1_url, openai_api_key=os.environ.get('API', None))
+    persist_db_path = getmd5(docpath)
+    if not os.path.exists(persist_db_path):
+        vector_store = persist_emdedding_pdf(docurl, persist_db_path)
+    else:
+        vector_store = Chroma(persist_directory=persist_db_path, embedding_function=embeddings)
+    qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(), return_source_documents=True)
+    result = qa({"query": query_message})
+    return result['result']
+
+
+def pdf_search(docurl, query_message, model="gpt-3.5-turbo"):
+    chatllm = ChatOpenAI(temperature=0.5, openai_api_base=config.bot_api_url.v1_url, model_name=model, openai_api_key=os.environ.get('API', None))
+    embeddings = OpenAIEmbeddings(openai_api_base=config.bot_api_url.v1_url, openai_api_key=os.environ.get('API', None))
+    filename = get_doc_from_url(docurl)
+    docpath = os.getcwd() + "/" + filename
+    loader = UnstructuredPDFLoader(docpath)
+    try:
+        documents = loader.load()
+    except:
+        print("pdf load error! docpath:", docpath)
+        return ""
+    os.remove(docpath)
+    # Initialize the text splitter
+    text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=25)
+    # Split the loaded documents
+    split_docs = text_splitter.split_documents(documents)
+    vector_store = Chroma.from_documents(split_docs, embeddings)
+    # Create the QA object
+    qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(),return_source_documents=True)
+    # Run the query
+    result = qa({"query": query_message})
+    return result['result']
+
+def summary_each_url(threads, chainllm, prompt):
+    summary_prompt = PromptTemplate(
+        input_variables=["web_summary", "question", "language"],
+        template=(
+            "You need to response the following question: {question}."
+            "Your task is answer the above question in {language} based on the Search results provided. Provide a detailed and in-depth response"
+            "If there is no relevant content in the search results, just answer None, do not make any explanations."
+            "Search results: {web_summary}."
+        ),
+    )
+    summary_threads = []
+
+    for t in threads:
+        tmp = t.join()
+        print(tmp)
+        chain = LLMChain(llm=chainllm, prompt=summary_prompt)
+        chain_thread = ThreadWithReturnValue(target=chain.run, args=({"web_summary": tmp, "question": prompt, "language": config.LANGUAGE},))
+        chain_thread.start()
+        summary_threads.append(chain_thread)
+
+    url_result = ""
+    for t in summary_threads:
+        tmp = t.join()
+        print("summary", tmp)
+        if tmp != "None":
+            url_result += "\n\n" + tmp
+    return url_result
+
+def get_search_results(prompt: str, context_max_tokens: int):
+
+    url_text_list = get_url_text_list(prompt)
+    useful_source_text = "\n\n".join(url_text_list)
+    # useful_source_text = summary_each_url(threads, chainllm, prompt)
+
+    useful_source_text, search_tokens_len = cut_message(useful_source_text, context_max_tokens)
+    print("search tokens len", search_tokens_len, "\n\n")
+
+    return useful_source_text
+
+from typing import Any
+from langchain.schema.output import LLMResult
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+class ChainStreamHandler(StreamingStdOutCallbackHandler):
+    def __init__(self):
+        self.tokens = []
+        # Remember to set this to True when generation finishes
+        self.finish = False
+        self.answer = ""
+
+    def on_llm_new_token(self, token: str, **kwargs):
+        # print(token)
+        self.tokens.append(token)
+        # yield ''.join(self.tokens)
+        # print(''.join(self.tokens))
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        self.finish = 1
+
+    def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
+        print(str(error))
+        self.tokens.append(str(error))
+
+    def generate_tokens(self):
+        while not self.finish or self.tokens:
+            if self.tokens:
+                data = self.tokens.pop(0)
+                self.answer += data
+                yield data
+            else:
+                pass
+        return self.answer
\ No newline at end of file
diff --git a/test/test_url.py b/test/test_url.py
index 7ea2fb11..e63c8b44 100644
--- a/test/test_url.py
+++ b/test/test_url.py
@@ -11,8 +11,11 @@ def extract_date(url):
         match = "1000/01/01"
     else:
         match = "1000/01/01"
-    return datetime.datetime.strptime(match, '%Y/%m/%d')
-
+    try:
+        return datetime.datetime.strptime(match, '%Y/%m/%d')
+    except:
+        match = "1000/01/01"
+        return datetime.datetime.strptime(match, '%Y/%m/%d')
 # Extract the dates and build a list of (date, URL) tuples
 date_url_pairs = [(extract_date(url), url) for url in urls]
 
diff --git a/utils/chatgpt2api.py b/utils/chatgpt2api.py
index dafa6a4d..377efa4e 100644
--- a/utils/chatgpt2api.py
+++ b/utils/chatgpt2api.py
@@ -13,7 +13,7 @@ from typing import Set
 
 import config
 
-from utils.agent import *
+from utils.plugins import *
 from utils.function_call import function_call_list
 
 def get_filtered_keys_from_object(obj: object, *keys: str) -> Set[str]:
diff --git a/utils/gpt4free.py b/utils/gpt4free.py
index 863eb4b1..c1f5b7a7 100644
--- a/utils/gpt4free.py
+++ b/utils/gpt4free.py
@@ -1,10 +1,32 @@
 import re
 import g4f
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 import config
 
-def get_response(message, **kwargs):
+GPT_ENGINE_map = {
+    "gpt-3.5-turbo": "gpt-3.5-turbo",
+    "gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k",
+    "gpt-3.5-turbo-0301": "gpt-3.5-turbo",
+    "gpt-3.5-turbo-0613": "gpt-3.5-turbo-0613",
+    "gpt-3.5-turbo-1106": "gpt-3.5-turbo",
+    "gpt-3.5-turbo-16k-0613": "gpt-3.5-turbo-0613",
+    "gpt-4": "gpt-4",
+    "gpt-4-0314": "gpt-4",
+    "gpt-4-32k": "gpt-4-32k",
+    "gpt-4-32k-0314": "gpt-4",
+    "gpt-4-0613": "gpt-4-0613",
+    "gpt-4-32k-0613": "gpt-4-32k-0613",
+    "gpt-4-1106-preview": "gpt-4-turbo",
+    "gpt-4-vision-preview": "gpt-4",
+    "claude-2-web": "gpt-4",
+    "claude-2": "gpt-4",
+}
+
+def ask_stream(message, **kwargs):
     response = g4f.ChatCompletion.create(
-        model=config.GPT_ENGINE,
+        model=GPT_ENGINE_map[config.GPT_ENGINE],
         messages=[{"role": "user", "content": message}],
         stream=True,
     )
@@ -22,8 +44,8 @@ def bing(response):
 
 if __name__ == "__main__":
     message = rf"""
-
+鲁迅和周树人为什么打架
 """
     answer = ""
-    for result in get_response(message, "gpt-4"):
+    for result in ask_stream(message, model="gpt-4"):
         print(result, end="")
\ No newline at end of file
diff --git a/utils/agent.py b/utils/plugins.py
similarity index 64%
rename from utils/agent.py
rename to utils/plugins.py
index 24b633de..296d76df 100644
--- a/utils/agent.py
+++ b/utils/plugins.py
@@ -15,197 +15,14 @@
 import threading
 import urllib.parse
-from typing import Any
 import time as record_time
 from bs4 import BeautifulSoup
-from langchain.llms import OpenAI
-from langchain.chains import LLMChain, RetrievalQA, RetrievalQAWithSourcesChain
-from langchain.agents import AgentType, load_tools, initialize_agent, tool
-from langchain.schema import HumanMessage
-from langchain.schema.output import LLMResult
-from langchain.callbacks.manager import CallbackManager
+
 from langchain.prompts import PromptTemplate
-from langchain.prompts.chat import (
-    ChatPromptTemplate,
-    SystemMessagePromptTemplate,
-    HumanMessagePromptTemplate,
-)
 from langchain.chat_models import ChatOpenAI
-from langchain.memory import ConversationBufferWindowMemory, ConversationTokenBufferMemory
-from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.vectorstores import Chroma
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from langchain.text_splitter import CharacterTextSplitter
-from langchain.tools import DuckDuckGoSearchRun, DuckDuckGoSearchResults, Tool
-from langchain.utilities import WikipediaAPIWrapper
-from utils.googlesearch import GoogleSearchAPIWrapper
-from langchain.document_loaders import UnstructuredPDFLoader
-
-def getmd5(string):
-    import hashlib
-    md5_hash = hashlib.md5()
-    md5_hash.update(string.encode('utf-8'))
-    md5_hex = md5_hash.hexdigest()
-    return md5_hex
-
-from utils.sitemap import SitemapLoader
-async def get_doc_from_sitemap(url):
-    # https://www.langchain.asia/modules/indexes/document_loaders/examples/sitemap#%E8%BF%87%E6%BB%A4%E7%AB%99%E7%82%B9%E5%9C%B0%E5%9B%BE-url-
-    sitemap_loader = SitemapLoader(web_path=url)
-    docs = await sitemap_loader.load()
-    return docs
-
-async def get_doc_from_local(docpath, doctype="md"):
-    from langchain.document_loaders import DirectoryLoader
-    # Load all txt-type files from the folder
-    loader = DirectoryLoader(docpath, glob='**/*.' + doctype)
-    # Convert the data into document objects; each file becomes one document
-    documents = loader.load()
-    return documents
-
-system_template="""Use the following pieces of context to answer the users question.
-If you don't know the answer, just say "Hmm..., I'm not sure.", don't try to make up an answer.
-ALWAYS return a "Sources" part in your answer.
-The "Sources" part should be a reference to the source of the document from which you got your answer.
-
-Example of your response should be:
-
-```
-The answer is foo
-
-Sources:
-1. abc
-2. xyz
-```
-Begin!
-----------------
-{summaries}
-"""
-messages = [
-    SystemMessagePromptTemplate.from_template(system_template),
-    HumanMessagePromptTemplate.from_template("{question}")
-]
-prompt = ChatPromptTemplate.from_messages(messages)
-
-def get_chain(store, llm):
-    chain_type_kwargs = {"prompt": prompt}
-    chain = RetrievalQAWithSourcesChain.from_chain_type(
-        llm,
-        chain_type="stuff",
-        retriever=store.as_retriever(),
-        chain_type_kwargs=chain_type_kwargs,
-        reduce_k_below_max_tokens=True
-    )
-    return chain
-
-async def docQA(docpath, query_message, persist_db_path="db", model = "gpt-3.5-turbo"):
-    chatllm = ChatOpenAI(temperature=0.5, openai_api_base=config.bot_api_url.v1_url, model_name=model, openai_api_key=config.API)
-    embeddings = OpenAIEmbeddings(openai_api_base=config.bot_api_url.v1_url, openai_api_key=config.API)
-
-    sitemap = "sitemap.xml"
-    match = re.match(r'^(https?|ftp)://[^\s/$.?#].[^\s]*$', docpath)
-    if match:
-        doc_method = get_doc_from_sitemap
-        docpath = os.path.join(docpath, sitemap)
-    else:
-        doc_method = get_doc_from_local
-
-    persist_db_path = getmd5(docpath)
-    if not os.path.exists(persist_db_path):
-        documents = await doc_method(docpath)
-        # Initialize the text splitter
-        text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=50)
-        # Persist the data
-        split_docs = text_splitter.split_documents(documents)
-        vector_store = Chroma.from_documents(split_docs, embeddings, persist_directory=persist_db_path)
-        vector_store.persist()
-    else:
-        # Load the persisted data
-        vector_store = Chroma(persist_directory=persist_db_path, embedding_function=embeddings)
-
-    # Create the QA object
-    qa = get_chain(vector_store, chatllm)
-    # qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(), return_source_documents=True)
-    # Run the query
-    result = qa({"question": query_message})
-    return result
-
-def get_doc_from_url(url):
-    filename = urllib.parse.unquote(url.split("/")[-1])
-    response = requests.get(url, stream=True)
-    with open(filename, 'wb') as f:
-        for chunk in response.iter_content(chunk_size=1024):
-            f.write(chunk)
-    return filename
-
-def persist_emdedding_pdf(docurl, persist_db_path):
-    embeddings = OpenAIEmbeddings(openai_api_base=config.bot_api_url.v1_url, openai_api_key=os.environ.get('API', None))
-    filename = get_doc_from_url(docurl)
-    docpath = os.getcwd() + "/" + filename
-    loader = UnstructuredPDFLoader(docpath)
-    documents = loader.load()
-    # Initialize the text splitter
-    text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=25)
-    # Split the loaded documents
-    split_docs = text_splitter.split_documents(documents)
-    vector_store = Chroma.from_documents(split_docs, embeddings, persist_directory=persist_db_path)
-    vector_store.persist()
-    os.remove(docpath)
-    return vector_store
-
-async def pdfQA(docurl, docpath, query_message, model="gpt-3.5-turbo"):
-    chatllm = ChatOpenAI(temperature=0.5, openai_api_base=config.bot_api_url.v1_url, model_name=model, openai_api_key=os.environ.get('API', None))
-    embeddings = OpenAIEmbeddings(openai_api_base=config.bot_api_url.v1_url, openai_api_key=os.environ.get('API', None))
-    persist_db_path = getmd5(docpath)
-    if not os.path.exists(persist_db_path):
-        vector_store = persist_emdedding_pdf(docurl, persist_db_path)
-    else:
-        vector_store = Chroma(persist_directory=persist_db_path, embedding_function=embeddings)
-    qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(), return_source_documents=True)
-    result = qa({"query": query_message})
-    return result['result']
-
-def pdf_search(docurl, query_message, model="gpt-3.5-turbo"):
-    chatllm = ChatOpenAI(temperature=0.5, openai_api_base=config.bot_api_url.v1_url, model_name=model, openai_api_key=os.environ.get('API', None))
-    embeddings = OpenAIEmbeddings(openai_api_base=config.bot_api_url.v1_url, openai_api_key=os.environ.get('API', None))
-    filename = get_doc_from_url(docurl)
-    docpath = os.getcwd() + "/" + filename
-    loader = UnstructuredPDFLoader(docpath)
-    try:
-        documents = loader.load()
-    except:
-        print("pdf load error! docpath:", docpath)
-        return ""
-    os.remove(docpath)
-    # Initialize the text splitter
-    text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=25)
-    # Split the loaded documents
-    split_docs = text_splitter.split_documents(documents)
-    vector_store = Chroma.from_documents(split_docs, embeddings)
-    # Create the QA object
-    qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(),return_source_documents=True)
-    # Run the query
-    result = qa({"query": query_message})
-    return result['result']
-
-def Document_extract(docurl):
-    filename = get_doc_from_url(docurl)
-    docpath = os.getcwd() + "/" + filename
-    if filename[-3:] == "pdf":
-        from pdfminer.high_level import extract_text
-        text = extract_text(docpath)
-    if filename[-3:] == "txt":
-        with open(docpath, 'r') as f:
-            text = f.read()
-    prompt = (
-        "Here is the document, inside XML tags:"
-        ""
-        "{}"
-        ""
-    ).format(text)
-    os.remove(docpath)
-    return prompt
+from langchain.tools import DuckDuckGoSearchResults
+from langchain.chains import LLMChain
 
 from typing import Optional, List
 from langchain.llms.base import LLM
@@ -227,36 +44,6 @@ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         if min_stop > -1:
             out = out[:min_stop]
         return out
-
-class ChainStreamHandler(StreamingStdOutCallbackHandler):
-    def __init__(self):
-        self.tokens = []
-        # Remember to set this to True when generation finishes
-        self.finish = False
-        self.answer = ""
-
-    def on_llm_new_token(self, token: str, **kwargs):
-        # print(token)
-        self.tokens.append(token)
-        # yield ''.join(self.tokens)
-        # print(''.join(self.tokens))
-
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        self.finish = 1
-
-    def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
-        print(str(error))
-        self.tokens.append(str(error))
-
-    def generate_tokens(self):
-        while not self.finish or self.tokens:
-            if self.tokens:
-                data = self.tokens.pop(0)
-                self.answer += data
-                yield data
-            else:
-                pass
-        return self.answer
 
 class ThreadWithReturnValue(threading.Thread):
     def run(self):
@@ -320,6 +107,7 @@ def getddgsearchurl(result, numresults=4):
     # print("ddg urls", urls)
     return urls
 
+from utils.googlesearch import GoogleSearchAPIWrapper
 def getgooglesearchurl(result, numresults=3):
     google_search = GoogleSearchAPIWrapper()
     urls = []
@@ -462,33 +250,15 @@ def concat_url(threads):
         url_result.append(tmp)
     return url_result
 
-def summary_each_url(threads, chainllm):
-    summary_prompt = PromptTemplate(
-        input_variables=["web_summary", "question", "language"],
-        template=(
-            "You need to response the following question: {question}."
-            "Your task is answer the above question in {language} based on the Search results provided. Provide a detailed and in-depth response"
-            "If there is no relevant content in the search results, just answer None, do not make any explanations."
-            "Search results: {web_summary}."
-        ),
-    )
-    summary_threads = []
-
-    for t in threads:
-        tmp = t.join()
-        print(tmp)
-        chain = LLMChain(llm=chainllm, prompt=summary_prompt)
-        chain_thread = ThreadWithReturnValue(target=chain.run, args=({"web_summary": tmp, "question": prompt, "language": config.LANGUAGE},))
-        chain_thread.start()
-        summary_threads.append(chain_thread)
-
-    url_result = ""
-    for t in summary_threads:
-        tmp = t.join()
-        print("summary", tmp)
-        if tmp != "None":
-            url_result += "\n\n" + tmp
-    return url_result
+def cut_message(message: str, max_tokens: int):
+    tiktoken.get_encoding("cl100k_base")
+    encoding = tiktoken.encoding_for_model(config.GPT_ENGINE)
+    encode_text = encoding.encode(message)
+    if len(encode_text) > max_tokens:
+        encode_text = encode_text[:max_tokens]
+        message = encoding.decode(encode_text)
+        encode_text = encoding.encode(message)
+    return message, len(encode_text)
 
 def get_url_text_list(prompt):
     start_time = record_time.time()
@@ -519,47 +289,18 @@ def get_url_text_list(prompt):
 
     return url_text_list
 
-def get_text_token_len(text):
-    tiktoken.get_encoding("cl100k_base")
-    encoding = tiktoken.encoding_for_model(config.GPT_ENGINE)
-    encode_text = encoding.encode(text)
-    return len(encode_text)
-
-def cut_message(message: str, max_tokens: int):
-    tiktoken.get_encoding("cl100k_base")
-    encoding = tiktoken.encoding_for_model(config.GPT_ENGINE)
-    encode_text = encoding.encode(message)
-    if len(encode_text) > max_tokens:
-        encode_text = encode_text[:max_tokens]
-        message = encoding.decode(encode_text)
-        encode_text = encoding.encode(message)
-    return message, len(encode_text)
-
+# Plugins: web search
 def get_search_results(prompt: str, context_max_tokens: int):
 
     url_text_list = get_url_text_list(prompt)
     useful_source_text = "\n\n".join(url_text_list)
-    # useful_source_text = summary_each_url(threads, chainllm)
 
     useful_source_text, search_tokens_len = cut_message(useful_source_text, context_max_tokens)
     print("search tokens len", search_tokens_len, "\n\n")
 
     return useful_source_text
 
-def check_json(json_data):
-    while True:
-        try:
-            json.loads(json_data)
-            break
-        except json.decoder.JSONDecodeError as e:
-            print("JSON error:", e)
-            print("JSON body", repr(json_data))
-            if "Invalid control character" in str(e):
-                json_data = json_data.replace("\n", "\\n")
-            if "Unterminated string starting" in str(e):
-                json_data += '"}'
-    return json_data
-
+# Plugins: get date and time
 def get_date_time_weekday():
     import datetime
     import pytz
@@ -569,7 +310,7 @@ def get_date_time_weekday():
     weekday_str = ['星期一', '星期二', '星期三', '星期四', '星期五', '星期六', '星期日'][weekday]
     return "今天是:" + str(now.date()) + ",现在的时间是:" + str(now.time())[:-7] + "," + weekday_str
 
-# Utility functions
+# Plugins: utility functions
 def get_version_info():
     import subprocess
     current_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -577,10 +318,21 @@ def get_version_info():
     output = result.stdout.decode()
     return output
 
+
+
+# Shared functions
 def encode_image(image_path):
     with open(image_path, "rb") as image_file:
         return base64.b64encode(image_file.read()).decode('utf-8')
 
+def get_doc_from_url(url):
+    filename = urllib.parse.unquote(url.split("/")[-1])
+    response = requests.get(url, stream=True)
+    with open(filename, 'wb') as f:
+        for chunk in response.iter_content(chunk_size=1024):
+            f.write(chunk)
+    return filename
+
 def get_encode_image(image_url):
     filename = get_doc_from_url(image_url)
     image_path = os.getcwd() + "/" + filename
@@ -589,6 +341,44 @@ def get_encode_image(image_url):
     os.remove(image_path)
     return prompt
 
+def get_text_token_len(text):
+    tiktoken.get_encoding("cl100k_base")
+    encoding = tiktoken.encoding_for_model(config.GPT_ENGINE)
+    encode_text = encoding.encode(text)
+    return len(encode_text)
+
+def Document_extract(docurl):
+    filename = get_doc_from_url(docurl)
+    docpath = os.getcwd() + "/" + filename
+    if filename[-3:] == "pdf":
+        from pdfminer.high_level import extract_text
+        text = extract_text(docpath)
+    if filename[-3:] == "txt":
+        with open(docpath, 'r') as f:
+            text = f.read()
+    prompt = (
+        "Here is the document, inside XML tags:"
+        ""
+        "{}"
+        ""
+    ).format(text)
+    os.remove(docpath)
+    return prompt
+
+def check_json(json_data):
+    while True:
+        try:
+            json.loads(json_data)
+            break
+        except json.decoder.JSONDecodeError as e:
+            print("JSON error:", e)
+            print("JSON body", repr(json_data))
+            if "Invalid control character" in str(e):
+                json_data = json_data.replace("\n", "\\n")
+            if "Unterminated string starting" in str(e):
+                json_data += '"}'
+    return json_data
+
 if __name__ == "__main__":
     os.system("clear")
     print(get_date_time_weekday())
@@ -600,6 +390,11 @@ def get_encode_image(image_url):
 
     # # Search
     # for i in search_web_and_summary("今天的微博热搜有哪些?"):
+    # for i in search_web_and_summary("给出清华铊中毒案时间线,并作出你的评论。"):
+    # for i in search_web_and_summary("红警hbk08是谁"):
+    # for i in search_web_and_summary("国务院 2024 放假安排"):
+    # for i in search_web_and_summary("中国最新公布的游戏政策,对游戏行业和其他相关行业有什么样的影响?"):
+    # for i in search_web_and_summary("今天上海的天气怎么样?"):
     # for i in search_web_and_summary("阿里云24核96G的云主机价格是多少"):
    # for i in search_web_and_summary("话说葬送的芙莉莲动漫是半年番还是季番?完结没?"):
     # for i in search_web_and_summary("周海媚事件进展"):
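Taken together with the bot.py hunk above, the g4f fallback path works as follows: when no API key is configured, or the USE_G4F plugin is enabled, bot.py assigns robot = gpt4free and streams from its ask_stream, while GPT_ENGINE_map downgrades engine names that g4f does not serve (for example, claude-2 falls back to gpt-4). A minimal sketch of that call path; the message text and convo_id are placeholders:

import config
import utils.gpt4free as gpt4free

# bot.py selects the g4f backend when config.API is unset or
# config.PLUGINS["USE_G4F"] is enabled.
robot = gpt4free

answer = ""
# ask_stream resolves the model via GPT_ENGINE_map[config.GPT_ENGINE] and
# absorbs the extra keyword arguments (convo_id, pass_history) that bot.py
# passes to every backend.
for data in robot.ask_stream("Hello!", convo_id="12345", pass_history=config.PASS_HISTORY):
    answer += data
print(answer)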