From 0958a96d17df21efe5b2317c8bb5743a0c96dafa Mon Sep 17 00:00:00 2001 From: AkashiCoin <55268546+AkashiCoin@users.noreply.github.com> Date: Wed, 22 Nov 2023 22:06:03 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=E6=94=AF=E6=8C=81=E5=9B=BE=E7=89=87?= =?UTF-8?q?=E5=88=86=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot_plugin_chatgpt_plus/__init__.py | 39 ++- nonebot_plugin_chatgpt_plus/chatgpt.py | 415 +++++++++++++----------- 2 files changed, 264 insertions(+), 190 deletions(-) diff --git a/nonebot_plugin_chatgpt_plus/__init__.py b/nonebot_plugin_chatgpt_plus/__init__.py index e9e43d3..7b2b21d 100644 --- a/nonebot_plugin_chatgpt_plus/__init__.py +++ b/nonebot_plugin_chatgpt_plus/__init__.py @@ -98,7 +98,32 @@ def check_purview(event: MessageEvent) -> bool: parameterless=[cooldow_checker(config.chatgpt_cd_time), single_run_locker()] ) async def ai_chat(bot: Bot, event: MessageEvent, state: T_State) -> None: + img_url: str = "" + img_info: dict = {} + if event.reply: + img_url = event.reply.message + for seg in event.message: + if seg.type == "image": + img_url = seg.data["url"].strip() + if isinstance(img_url, Message): + for seg in img_url: + if seg.type == "image": + img_url = seg.data["url"].strip() + if isinstance(img_url, MessageSegment): + img_url = img_url.data["url"] lockers[event.user_id] = True + if img_url: + try: + img_info = await chat_bot.upload_image_url(url=img_url) + if not img_info: + await matcher.finish("图片上传失败", reply_message=True) + logger.debug(f"ChatGPT image upload success: {img_info}") + except Exception as e: + error = f"{type(e).__name__}: {e}" + logger.opt(exception=e).error(f"ChatGPT request failed: {error}") + await matcher.finish(f"图片上传失败\n错误信息: {error}", reply_message=True) + finally: + lockers[event.user_id] = False message = _command_arg(state) or event.get_message() text = message.extract_plain_text().strip() if start := _command_start(state): @@ -146,24 +171,24 @@ async def ai_chat(bot: Bot, event: MessageEvent, state: T_State) -> None: msg_id = msg_id.get("message_id") msg = await chat_bot( **cvst, played_name=played_name, model=model - ).get_chat_response(text) + ).get_chat_response(text, image_info=img_info) if ( msg == "token失效,请重新设置token" and chat_bot.session_token != config.chatgpt_session_token ): chat_bot.session_token = config.chatgpt_session_token - msg = await chat_bot(**cvst, played_name=played_name, model=model).get_chat_response( - text - ) + msg = await chat_bot( + **cvst, played_name=played_name, model=model + ).get_chat_response(text, image_info=img_info) elif msg == "会话不存在": if config.chatgpt_auto_refresh: has_title = False cvst["conversation_id"].append(None) cvst["parent_id"].append(chat_bot.id) await matcher.send("会话不存在,已自动刷新对话,等待响应...", reply_message=True) - msg = await chat_bot(**cvst, played_name=played_name, model=model).get_chat_response( - text - ) + msg = await chat_bot( + **cvst, played_name=played_name, model=model + ).get_chat_response(text, image_info=img_info) else: msg += ",请刷新会话" except Exception as e: diff --git a/nonebot_plugin_chatgpt_plus/chatgpt.py b/nonebot_plugin_chatgpt_plus/chatgpt.py index fe6a546..331f9ba 100644 --- a/nonebot_plugin_chatgpt_plus/chatgpt.py +++ b/nonebot_plugin_chatgpt_plus/chatgpt.py @@ -1,4 +1,5 @@ import asyncio +from io import BytesIO import uuid import httpx @@ -7,6 +8,7 @@ from urllib.parse import urljoin from nonebot.log import logger from nonebot.utils import escape_tag +from PIL import Image from .utils import convert_seconds @@ -62,6 +64,11 @@ def __init__( ) if self.api_url.startswith("https://chat.openai.com"): raise ValueError("无法使用官方API,请使用第三方API") + self.client = httpx.AsyncClient( + proxies=self.proxies, + timeout=self.timeout, + base_url=self.api_url, + ) def __call__( self, @@ -114,13 +121,13 @@ def get_played_info(self, name: str) -> Dict[str, Any]: else "", ], }, - "metadata": { - "timestamp_": "absolute" - }, + "metadata": {"timestamp_": "absolute"}, "weight": 100, } - def get_payload(self, prompt: str, is_continue: bool = False) -> Dict[str, Any]: + def get_payload( + self, prompt: str, is_continue: bool = False, image_info: dict = None + ) -> Dict[str, Any]: payload = { "action": "continue", "conversation_id": self.conversation_id, @@ -136,138 +143,144 @@ def get_payload(self, prompt: str, is_continue: bool = False) -> Dict[str, Any]: "id": self.id, "author": {"role": "user"}, "role": "user", - "content": {"content_type": "text", "parts": [prompt]}, - "metadata": { - "timestamp_": "absolute" - }, + "content": {"content_type": "multimodal_text", "parts": [prompt]}, + "metadata": {"timestamp_": "absolute"}, } ] + if image_info: + messages[0]["content"]["parts"].insert(0, image_info) if self.played_name: messages.insert(0, self.get_played_info(self.played_name)) payload["messages"] = messages payload["action"] = "next" + logger.debug(f"payload: {payload}") return payload - async def get_chat_response(self, prompt: str, is_continue: bool = False) -> str: + async def get_chat_response( + self, prompt: str, is_continue: bool = False, image_info: dict = None + ) -> str: if not self.authorization: await self.refresh_session() if not self.authorization: return "Token获取失败,请检查配置或API是否可用" - async with httpx.AsyncClient(proxies=self.proxies) as client: - async with client.stream( - "POST", - urljoin(self.api_url, "backend-api/conversation"), - headers=self.headers, - json=self.get_payload(prompt, is_continue=is_continue), - timeout=self.timeout, - ) as response: - if response.status_code == 429: - msg = "" + async with self.client.stream( + "POST", + "backend-api/conversation", + headers=self.headers, + json=self.get_payload( + prompt, is_continue=is_continue, image_info=image_info + ), + ) as response: + if response.status_code == 429: + msg = "" + _buffer = bytearray() + async for chunk in response.aiter_bytes(): + _buffer.extend(chunk) + resp: dict = json.loads(_buffer.decode()) + if detail := resp.get("detail"): + if isinstance(detail, str): + msg += "\n" + detail + if is_continue and detail.startswith( + "Only one message at a time." + ): + await asyncio.sleep(3) + logger.info("ChatGPT自动续写中...") + return await self.get_chat_response( + prompt="", is_continue=True + ) + elif seconds := detail.get("clears_in"): + msg = f"\n请在 {convert_seconds(seconds)} 后重试" + if not is_continue: + return "请求过多,请放慢速度" + msg + if response.status_code == 401: + return "token失效,请重新设置token" + elif response.status_code == 403: + return "API错误,请联系开发者修复" + elif response.status_code == 404: + return "会话不存在" + elif response.status_code >= 500: + return f"API内部错误,错误代码: {response.status_code}" + elif response.is_error: + if is_continue: + response = await self.get_conversasion_message_response( + self.conversation_id, self.parent_id + ) + else: _buffer = bytearray() async for chunk in response.aiter_bytes(): _buffer.extend(chunk) - resp: dict = json.loads(_buffer.decode()) - if detail := resp.get("detail"): - if isinstance(detail, str): - msg += "\n" + detail - if is_continue and detail.startswith("Only one message at a time."): - await asyncio.sleep(3) - logger.info("ChatGPT自动续写中...") - return await self.get_chat_response(prompt="", is_continue=True) - elif seconds := detail.get("clears_in"): - msg = f"\n请在 {convert_seconds(seconds)} 后重试" - if not is_continue: - return "请求过多,请放慢速度" + msg - if response.status_code == 401: - return "token失效,请重新设置token" - elif response.status_code == 403: - return "API错误,请联系开发者修复" - elif response.status_code == 404: - return "会话不存在" - elif response.status_code >= 500: - return f"API内部错误,错误代码: {response.status_code}" - elif response.is_error: - if is_continue: - response = await self.get_conversasion_message_response( - self.conversation_id, self.parent_id - ) - else: - _buffer = bytearray() - async for chunk in response.aiter_bytes(): - _buffer.extend(chunk) - resp_text = _buffer.decode() - logger.opt(colors=True).error( - f"非预期的响应内容: HTTP{response.status_code} {resp_text}" - ) - return f"ChatGPT 服务器返回了非预期的内容: HTTP{response.status_code}\n{resp_text[:256]}" - else: - data_list = [] - async for line in response.aiter_lines(): - if line.startswith("data:"): - data = line[6:] - if data.startswith("{"): - try: - data_list.append(json.loads(data)) - except Exception as e: - logger.warning(f"ChatGPT数据解析未知错误:{e}: {data}") - if not data_list: - return "ChatGPT 服务器未返回任何内容" - idx = -1 - while data_list[idx].get("error") or data_list[idx].get("is_completion"): - idx -= 1 - response = data_list[idx] - self.parent_id = response["message"]["id"] - self.conversation_id = response["conversation_id"] - not_complete = "" - if not response["message"].get("end_turn", True): - if self.auto_continue: - logger.info("ChatGPT自动续写中...") - await asyncio.sleep(3) - return await self.get_chat_response("", True) - else: - not_complete = "\nis_complete: False" - else: - if response["message"].get("end_turn"): - response = await self.get_conversasion_message_response( - self.conversation_id, self.parent_id - ) - if isinstance(response, str): - return response - msg = "".join([text for text in response["message"]["content"]["parts"] if isinstance(text, str)]) - images = [image["asset_pointer"] for image in response["message"]["content"]["parts"] if not isinstance(image, str)] - logger.info(response) - logger.info(msg) - logger.info(images) - if self.metadata: - msg += "\n---" - msg += ( - f"\nmodel_slug: {response['message']['metadata']['model_slug']}" + resp_text = _buffer.decode() + logger.opt(colors=True).error( + f"非预期的响应内容: HTTP{response.status_code} {resp_text}" ) - msg += not_complete - if is_continue: - msg += "\nauto_continue: True" - if images: - return { - "message": msg, - "images": images - } + return f"ChatGPT 服务器返回了非预期的内容: HTTP{response.status_code}\n{resp_text[:256]}" + else: + data_list = [] + async for line in response.aiter_lines(): + if line.startswith("data:"): + data = line[6:] + if data.startswith("{"): + try: + data_list.append(json.loads(data)) + except Exception as e: + logger.warning(f"ChatGPT数据解析未知错误:{e}: {data}") + if not data_list: + return "ChatGPT 服务器未返回任何内容" + idx = -1 + while data_list[idx].get("error") or data_list[idx].get( + "is_completion" + ): + idx -= 1 + response = data_list[idx] + self.parent_id = response["message"]["id"] + self.conversation_id = response["conversation_id"] + not_complete = "" + if not response["message"].get("end_turn", True): + if self.auto_continue: + logger.info("ChatGPT自动续写中...") + await asyncio.sleep(3) + return await self.get_chat_response("", True) else: - return msg + not_complete = "\nis_complete: False" + else: + if response["message"].get("end_turn"): + response = await self.get_conversasion_message_response( + self.conversation_id, self.parent_id + ) + if isinstance(response, str): + return response + msg = "".join( + [ + text + for text in response["message"]["content"]["parts"] + if isinstance(text, str) + ] + ) + images = [ + image["asset_pointer"] + for image in response["message"]["content"]["parts"] + if not isinstance(image, str) + ] + logger.debug(response) + logger.debug(msg) + logger.debug(images) + if self.metadata: + msg += "\n---" + msg += f"\nmodel_slug: {response['message']['metadata']['model_slug']}" + msg += not_complete + if is_continue: + msg += "\nauto_continue: True" + if images: + return {"message": msg, "images": images} + else: + return msg async def edit_title(self, title: str) -> bool: - async with httpx.AsyncClient( + response = await self.client.patch( + f"backend-api/conversation/{self.conversation_id}", headers=self.headers, - proxies=self.proxies, - timeout=self.timeout, - ) as client: - response = await client.patch( - urljoin( - self.api_url, "backend-api/conversation/" + self.conversation_id - ), - json={ - "title": title if title.startswith("group") else f"private_{title}" - }, - ) + json={"title": title if title.startswith("group") else f"private_{title}"}, + ) try: resp = response.json() if resp.get("success"): @@ -281,18 +294,11 @@ async def edit_title(self, title: str) -> bool: return f"编辑标题失败,{e}" async def gen_title(self) -> str: - async with httpx.AsyncClient( + response = await self.client.post( + "backend-api/conversation/gen_title/" + self.conversation_id, headers=self.headers, - proxies=self.proxies, - timeout=self.timeout, - ) as client: - response = await client.post( - urljoin( - self.api_url, - "backend-api/conversation/gen_title/" + self.conversation_id, - ), - json={"message_id": self.parent_id}, - ) + json={"message_id": self.parent_id}, + ) try: resp = response.json() if resp.get("title"): @@ -306,35 +312,79 @@ async def gen_title(self) -> str: return f"生成标题失败,{e}" async def get_conversasion(self, conversation_id: str): - async with httpx.AsyncClient( + response = await self.client.get( + f"backend-api/conversation/{conversation_id}", headers=self.headers + ) + return response.json() + + async def upload_image_url(self, url: str): + logger.info(f"获取图片: {url}") + file_resp = await self.client.get(url) + if file_resp.status_code != 200: + logger.error(f"获取图片失败: {file_resp.text}") + return False + file = file_resp.content + response = await self.client.post( + "backend-api/files", headers=self.headers, - proxies=self.proxies, - timeout=self.timeout, - ) as client: - response = await client.get( - urljoin(self.api_url, f"backend-api/conversation/{conversation_id}") + json={ + "file_name": "img.png", + "file_size": len(file), + "use_case": "multimodal", + }, + ) + if response.status_code == 200: + resp_json = response.json() + upload_url = resp_json["upload_url"] + file_id = resp_json["file_id"] + else: + logger.error(f"获取上传图片链接失败: {response.text}") + return False + response = await self.client.put( + upload_url, + data=file, + headers={ + "X-Ms-Blob-Type": "BlockBlob", + "X-Ms-Version": "2020-04-08", + }, + ) + if response.status_code == 201: + image = Image.open(BytesIO(file)) + img_info = { + "asset_pointer": f"file-service://{file_id}", + "size_bytes": len(file), + "width": image.width, + "height": image.height, + } + response = await self.client.post( + f"backend-api/files/{file_id}/uploaded", json={}, headers=self.headers ) - return response.json() + if response.status_code == 200: + return img_info + else: + logger.error( + f"完成上传图片失败: HTTP{response.status_code}{response.text}" + ) + return False + else: + logger.error(f"上传图片失败: HTTP{response.status_code}{response.text}") + return False async def get_image_url_with_id(self, image_id: str): - async with httpx.AsyncClient( + response = await self.client.get( + f"backend-api/files/{image_id}/download", headers=self.headers, - proxies=self.proxies, - timeout=self.timeout, - ) as client: - response = await client.get( - urljoin(self.api_url, f"backend-api/files/{image_id}/download") + ) + try: + if response.status_code == 200: + resp_json = response.json() + return resp_json["download_url"] + else: + return False + except Exception as e: + logger.opt(colors=True, exception=e).error( + f"获取图片失败: HTTP{response.status_code} {response.text}" ) - try: - if response.status_code == 200: - resp_json = response.json() - return resp_json["download_url"] - else: - return False - except Exception as e: - logger.opt(colors=True, exception=e).error( - f"获取图片失败: HTTP{response.status_code} {response.text}" - ) async def get_conversasion_message_response( self, conversation_id: str, message_id: str @@ -346,14 +396,22 @@ async def get_conversasion_message_response( if messages := conversation.get("mapping"): resp = messages[message_id] message = messages[resp["parent"]] - while message["message"]["author"]["role"] == "assistant" or message["message"]["author"]["role"] == "tool": - logger.info(message) + while ( + message["message"]["author"]["role"] == "assistant" + or message["message"]["author"]["role"] == "tool" + ): + logger.debug(message) content_type = message["message"]["content"]["content_type"] if message["message"]["author"]["role"] == "tool": if content_type == "multimodal_text": - resp["message"]["content"]["parts"].extend(message["message"]["content"]["parts"]) + resp["message"]["content"]["parts"].extend( + message["message"]["content"]["parts"] + ) elif content_type == "text": - resp["message"]["content"]["parts"].extend(message["message"]["content"]["parts"]) + resp["message"]["content"]["parts"] = ( + message["message"]["content"]["parts"] + + resp["message"]["content"]["parts"] + ) message = messages[message["parent"]] resp["conversation_id"] = conversation_id return resp @@ -365,18 +423,13 @@ async def refresh_session(self) -> None: if self.auto_auth: await self.login() else: - cookies = { - SESSION_TOKEN_KEY: self.session_token, - } - async with httpx.AsyncClient( - cookies=cookies, - proxies=self.proxies, - timeout=self.timeout, - ) as client: - response = await client.get( - urljoin(self.api_url, "api/auth/session"), - headers={"User-Agent": self.user_agent}, - ) + response = await self.client.get( + urljoin(self.api_url, "api/auth/session"), + headers={"User-Agent": self.user_agent}, + cookies={ + SESSION_TOKEN_KEY: self.session_token, + }, + ) try: if response.status_code == 200: self.session_token = ( @@ -392,20 +445,16 @@ async def refresh_session(self) -> None: ) async def login(self) -> None: - async with httpx.AsyncClient( - proxies=self.proxies, - timeout=self.timeout, - ) as client: - response = await client.post( - "https://chat.loli.vet/api/auth/login", - headers={"User-Agent": self.user_agent}, - files={"username": self.account, "password": self.password}, - ) - if response.status_code == 200: - session_token = response.cookies.get(SESSION_TOKEN_KEY) - self.session_token = session_token - self.auto_auth = False - logger.opt(colors=True).info("ChatGPT 登录成功!") - await self.refresh_session() - else: - logger.error(f"ChatGPT 登陆错误! {response.text}") + response = await self.client.post( + "https://chat.loli.vet/api/auth/login", + headers={"User-Agent": self.user_agent}, + files={"username": self.account, "password": self.password}, + ) + if response.status_code == 200: + session_token = response.cookies.get(SESSION_TOKEN_KEY) + self.session_token = session_token + self.auto_auth = False + logger.opt(colors=True).info("ChatGPT 登录成功!") + await self.refresh_session() + else: + logger.error(f"ChatGPT 登陆错误! {response.text}")