From 0958a96d17df21efe5b2317c8bb5743a0c96dafa Mon Sep 17 00:00:00 2001
From: AkashiCoin <55268546+AkashiCoin@users.noreply.github.com>
Date: Wed, 22 Nov 2023 22:06:03 +0800
Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=E6=94=AF=E6=8C=81=E5=9B=BE=E7=89=87?=
=?UTF-8?q?=E5=88=86=E6=9E=90?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
nonebot_plugin_chatgpt_plus/__init__.py | 39 ++-
nonebot_plugin_chatgpt_plus/chatgpt.py | 415 +++++++++++++-----------
2 files changed, 264 insertions(+), 190 deletions(-)
diff --git a/nonebot_plugin_chatgpt_plus/__init__.py b/nonebot_plugin_chatgpt_plus/__init__.py
index e9e43d3..7b2b21d 100644
--- a/nonebot_plugin_chatgpt_plus/__init__.py
+++ b/nonebot_plugin_chatgpt_plus/__init__.py
@@ -98,7 +98,32 @@ def check_purview(event: MessageEvent) -> bool:
parameterless=[cooldow_checker(config.chatgpt_cd_time), single_run_locker()]
)
async def ai_chat(bot: Bot, event: MessageEvent, state: T_State) -> None:
+ img_url: str = ""
+ img_info: dict = {}
+ if event.reply:
+ img_url = event.reply.message
+ for seg in event.message:
+ if seg.type == "image":
+ img_url = seg.data["url"].strip()
+ if isinstance(img_url, Message):
+ for seg in img_url:
+ if seg.type == "image":
+ img_url = seg.data["url"].strip()
+ if isinstance(img_url, MessageSegment):
+ img_url = img_url.data["url"]
lockers[event.user_id] = True
+ if img_url:
+ try:
+ img_info = await chat_bot.upload_image_url(url=img_url)
+ if not img_info:
+ await matcher.finish("图片上传失败", reply_message=True)
+ logger.debug(f"ChatGPT image upload success: {img_info}")
+ except Exception as e:
+ error = f"{type(e).__name__}: {e}"
+ logger.opt(exception=e).error(f"ChatGPT request failed: {error}")
+ await matcher.finish(f"图片上传失败\n错误信息: {error}", reply_message=True)
+ finally:
+ lockers[event.user_id] = False
message = _command_arg(state) or event.get_message()
text = message.extract_plain_text().strip()
if start := _command_start(state):
@@ -146,24 +171,24 @@ async def ai_chat(bot: Bot, event: MessageEvent, state: T_State) -> None:
msg_id = msg_id.get("message_id")
msg = await chat_bot(
**cvst, played_name=played_name, model=model
- ).get_chat_response(text)
+ ).get_chat_response(text, image_info=img_info)
if (
msg == "token失效,请重新设置token"
and chat_bot.session_token != config.chatgpt_session_token
):
chat_bot.session_token = config.chatgpt_session_token
- msg = await chat_bot(**cvst, played_name=played_name, model=model).get_chat_response(
- text
- )
+ msg = await chat_bot(
+ **cvst, played_name=played_name, model=model
+ ).get_chat_response(text, image_info=img_info)
elif msg == "会话不存在":
if config.chatgpt_auto_refresh:
has_title = False
cvst["conversation_id"].append(None)
cvst["parent_id"].append(chat_bot.id)
await matcher.send("会话不存在,已自动刷新对话,等待响应...", reply_message=True)
- msg = await chat_bot(**cvst, played_name=played_name, model=model).get_chat_response(
- text
- )
+ msg = await chat_bot(
+ **cvst, played_name=played_name, model=model
+ ).get_chat_response(text, image_info=img_info)
else:
msg += ",请刷新会话"
except Exception as e:
diff --git a/nonebot_plugin_chatgpt_plus/chatgpt.py b/nonebot_plugin_chatgpt_plus/chatgpt.py
index fe6a546..331f9ba 100644
--- a/nonebot_plugin_chatgpt_plus/chatgpt.py
+++ b/nonebot_plugin_chatgpt_plus/chatgpt.py
@@ -1,4 +1,5 @@
import asyncio
+from io import BytesIO
import uuid
import httpx
@@ -7,6 +8,7 @@
from urllib.parse import urljoin
from nonebot.log import logger
from nonebot.utils import escape_tag
+from PIL import Image
from .utils import convert_seconds
@@ -62,6 +64,11 @@ def __init__(
)
if self.api_url.startswith("https://chat.openai.com"):
raise ValueError("无法使用官方API,请使用第三方API")
+ self.client = httpx.AsyncClient(
+ proxies=self.proxies,
+ timeout=self.timeout,
+ base_url=self.api_url,
+ )
def __call__(
self,
@@ -114,13 +121,13 @@ def get_played_info(self, name: str) -> Dict[str, Any]:
else "",
],
},
- "metadata": {
- "timestamp_": "absolute"
- },
+ "metadata": {"timestamp_": "absolute"},
"weight": 100,
}
- def get_payload(self, prompt: str, is_continue: bool = False) -> Dict[str, Any]:
+ def get_payload(
+ self, prompt: str, is_continue: bool = False, image_info: dict = None
+ ) -> Dict[str, Any]:
payload = {
"action": "continue",
"conversation_id": self.conversation_id,
@@ -136,138 +143,144 @@ def get_payload(self, prompt: str, is_continue: bool = False) -> Dict[str, Any]:
"id": self.id,
"author": {"role": "user"},
"role": "user",
- "content": {"content_type": "text", "parts": [prompt]},
- "metadata": {
- "timestamp_": "absolute"
- },
+ "content": {"content_type": "multimodal_text", "parts": [prompt]},
+ "metadata": {"timestamp_": "absolute"},
}
]
+ if image_info:
+ messages[0]["content"]["parts"].insert(0, image_info)
if self.played_name:
messages.insert(0, self.get_played_info(self.played_name))
payload["messages"] = messages
payload["action"] = "next"
+ logger.debug(f"payload: {payload}")
return payload
- async def get_chat_response(self, prompt: str, is_continue: bool = False) -> str:
+ async def get_chat_response(
+ self, prompt: str, is_continue: bool = False, image_info: dict = None
+ ) -> str:
if not self.authorization:
await self.refresh_session()
if not self.authorization:
return "Token获取失败,请检查配置或API是否可用"
- async with httpx.AsyncClient(proxies=self.proxies) as client:
- async with client.stream(
- "POST",
- urljoin(self.api_url, "backend-api/conversation"),
- headers=self.headers,
- json=self.get_payload(prompt, is_continue=is_continue),
- timeout=self.timeout,
- ) as response:
- if response.status_code == 429:
- msg = ""
+ async with self.client.stream(
+ "POST",
+ "backend-api/conversation",
+ headers=self.headers,
+ json=self.get_payload(
+ prompt, is_continue=is_continue, image_info=image_info
+ ),
+ ) as response:
+ if response.status_code == 429:
+ msg = ""
+ _buffer = bytearray()
+ async for chunk in response.aiter_bytes():
+ _buffer.extend(chunk)
+ resp: dict = json.loads(_buffer.decode())
+ if detail := resp.get("detail"):
+ if isinstance(detail, str):
+ msg += "\n" + detail
+ if is_continue and detail.startswith(
+ "Only one message at a time."
+ ):
+ await asyncio.sleep(3)
+ logger.info("ChatGPT自动续写中...")
+ return await self.get_chat_response(
+ prompt="", is_continue=True
+ )
+ elif seconds := detail.get("clears_in"):
+ msg = f"\n请在 {convert_seconds(seconds)} 后重试"
+ if not is_continue:
+ return "请求过多,请放慢速度" + msg
+ if response.status_code == 401:
+ return "token失效,请重新设置token"
+ elif response.status_code == 403:
+ return "API错误,请联系开发者修复"
+ elif response.status_code == 404:
+ return "会话不存在"
+ elif response.status_code >= 500:
+ return f"API内部错误,错误代码: {response.status_code}"
+ elif response.is_error:
+ if is_continue:
+ response = await self.get_conversasion_message_response(
+ self.conversation_id, self.parent_id
+ )
+ else:
_buffer = bytearray()
async for chunk in response.aiter_bytes():
_buffer.extend(chunk)
- resp: dict = json.loads(_buffer.decode())
- if detail := resp.get("detail"):
- if isinstance(detail, str):
- msg += "\n" + detail
- if is_continue and detail.startswith("Only one message at a time."):
- await asyncio.sleep(3)
- logger.info("ChatGPT自动续写中...")
- return await self.get_chat_response(prompt="", is_continue=True)
- elif seconds := detail.get("clears_in"):
- msg = f"\n请在 {convert_seconds(seconds)} 后重试"
- if not is_continue:
- return "请求过多,请放慢速度" + msg
- if response.status_code == 401:
- return "token失效,请重新设置token"
- elif response.status_code == 403:
- return "API错误,请联系开发者修复"
- elif response.status_code == 404:
- return "会话不存在"
- elif response.status_code >= 500:
- return f"API内部错误,错误代码: {response.status_code}"
- elif response.is_error:
- if is_continue:
- response = await self.get_conversasion_message_response(
- self.conversation_id, self.parent_id
- )
- else:
- _buffer = bytearray()
- async for chunk in response.aiter_bytes():
- _buffer.extend(chunk)
- resp_text = _buffer.decode()
- logger.opt(colors=True).error(
- f"非预期的响应内容: HTTP{response.status_code} {resp_text}"
- )
- return f"ChatGPT 服务器返回了非预期的内容: HTTP{response.status_code}\n{resp_text[:256]}"
- else:
- data_list = []
- async for line in response.aiter_lines():
- if line.startswith("data:"):
- data = line[6:]
- if data.startswith("{"):
- try:
- data_list.append(json.loads(data))
- except Exception as e:
- logger.warning(f"ChatGPT数据解析未知错误:{e}: {data}")
- if not data_list:
- return "ChatGPT 服务器未返回任何内容"
- idx = -1
- while data_list[idx].get("error") or data_list[idx].get("is_completion"):
- idx -= 1
- response = data_list[idx]
- self.parent_id = response["message"]["id"]
- self.conversation_id = response["conversation_id"]
- not_complete = ""
- if not response["message"].get("end_turn", True):
- if self.auto_continue:
- logger.info("ChatGPT自动续写中...")
- await asyncio.sleep(3)
- return await self.get_chat_response("", True)
- else:
- not_complete = "\nis_complete: False"
- else:
- if response["message"].get("end_turn"):
- response = await self.get_conversasion_message_response(
- self.conversation_id, self.parent_id
- )
- if isinstance(response, str):
- return response
- msg = "".join([text for text in response["message"]["content"]["parts"] if isinstance(text, str)])
- images = [image["asset_pointer"] for image in response["message"]["content"]["parts"] if not isinstance(image, str)]
- logger.info(response)
- logger.info(msg)
- logger.info(images)
- if self.metadata:
- msg += "\n---"
- msg += (
- f"\nmodel_slug: {response['message']['metadata']['model_slug']}"
+ resp_text = _buffer.decode()
+ logger.opt(colors=True).error(
+ f"非预期的响应内容: HTTP{response.status_code} {resp_text}"
)
- msg += not_complete
- if is_continue:
- msg += "\nauto_continue: True"
- if images:
- return {
- "message": msg,
- "images": images
- }
+ return f"ChatGPT 服务器返回了非预期的内容: HTTP{response.status_code}\n{resp_text[:256]}"
+ else:
+ data_list = []
+ async for line in response.aiter_lines():
+ if line.startswith("data:"):
+ data = line[6:]
+ if data.startswith("{"):
+ try:
+ data_list.append(json.loads(data))
+ except Exception as e:
+ logger.warning(f"ChatGPT数据解析未知错误:{e}: {data}")
+ if not data_list:
+ return "ChatGPT 服务器未返回任何内容"
+ idx = -1
+ while data_list[idx].get("error") or data_list[idx].get(
+ "is_completion"
+ ):
+ idx -= 1
+ response = data_list[idx]
+ self.parent_id = response["message"]["id"]
+ self.conversation_id = response["conversation_id"]
+ not_complete = ""
+ if not response["message"].get("end_turn", True):
+ if self.auto_continue:
+ logger.info("ChatGPT自动续写中...")
+ await asyncio.sleep(3)
+ return await self.get_chat_response("", True)
else:
- return msg
+ not_complete = "\nis_complete: False"
+ else:
+ if response["message"].get("end_turn"):
+ response = await self.get_conversasion_message_response(
+ self.conversation_id, self.parent_id
+ )
+ if isinstance(response, str):
+ return response
+ msg = "".join(
+ [
+ text
+ for text in response["message"]["content"]["parts"]
+ if isinstance(text, str)
+ ]
+ )
+ images = [
+ image["asset_pointer"]
+ for image in response["message"]["content"]["parts"]
+ if not isinstance(image, str)
+ ]
+ logger.debug(response)
+ logger.debug(msg)
+ logger.debug(images)
+ if self.metadata:
+ msg += "\n---"
+ msg += f"\nmodel_slug: {response['message']['metadata']['model_slug']}"
+ msg += not_complete
+ if is_continue:
+ msg += "\nauto_continue: True"
+ if images:
+ return {"message": msg, "images": images}
+ else:
+ return msg
async def edit_title(self, title: str) -> bool:
- async with httpx.AsyncClient(
+ response = await self.client.patch(
+ f"backend-api/conversation/{self.conversation_id}",
headers=self.headers,
- proxies=self.proxies,
- timeout=self.timeout,
- ) as client:
- response = await client.patch(
- urljoin(
- self.api_url, "backend-api/conversation/" + self.conversation_id
- ),
- json={
- "title": title if title.startswith("group") else f"private_{title}"
- },
- )
+ json={"title": title if title.startswith("group") else f"private_{title}"},
+ )
try:
resp = response.json()
if resp.get("success"):
@@ -281,18 +294,11 @@ async def edit_title(self, title: str) -> bool:
return f"编辑标题失败,{e}"
async def gen_title(self) -> str:
- async with httpx.AsyncClient(
+ response = await self.client.post(
+ "backend-api/conversation/gen_title/" + self.conversation_id,
headers=self.headers,
- proxies=self.proxies,
- timeout=self.timeout,
- ) as client:
- response = await client.post(
- urljoin(
- self.api_url,
- "backend-api/conversation/gen_title/" + self.conversation_id,
- ),
- json={"message_id": self.parent_id},
- )
+ json={"message_id": self.parent_id},
+ )
try:
resp = response.json()
if resp.get("title"):
@@ -306,35 +312,79 @@ async def gen_title(self) -> str:
return f"生成标题失败,{e}"
async def get_conversasion(self, conversation_id: str):
- async with httpx.AsyncClient(
+ response = await self.client.get(
+ f"backend-api/conversation/{conversation_id}", headers=self.headers
+ )
+ return response.json()
+
+ async def upload_image_url(self, url: str):
+ logger.info(f"获取图片: {url}")
+ file_resp = await self.client.get(url)
+ if file_resp.status_code != 200:
+ logger.error(f"获取图片失败: {file_resp.text}")
+ return False
+ file = file_resp.content
+ response = await self.client.post(
+ "backend-api/files",
headers=self.headers,
- proxies=self.proxies,
- timeout=self.timeout,
- ) as client:
- response = await client.get(
- urljoin(self.api_url, f"backend-api/conversation/{conversation_id}")
+ json={
+ "file_name": "img.png",
+ "file_size": len(file),
+ "use_case": "multimodal",
+ },
+ )
+ if response.status_code == 200:
+ resp_json = response.json()
+ upload_url = resp_json["upload_url"]
+ file_id = resp_json["file_id"]
+ else:
+ logger.error(f"获取上传图片链接失败: {response.text}")
+ return False
+ response = await self.client.put(
+ upload_url,
+ data=file,
+ headers={
+ "X-Ms-Blob-Type": "BlockBlob",
+ "X-Ms-Version": "2020-04-08",
+ },
+ )
+ if response.status_code == 201:
+ image = Image.open(BytesIO(file))
+ img_info = {
+ "asset_pointer": f"file-service://{file_id}",
+ "size_bytes": len(file),
+ "width": image.width,
+ "height": image.height,
+ }
+ response = await self.client.post(
+ f"backend-api/files/{file_id}/uploaded", json={}, headers=self.headers
)
- return response.json()
+ if response.status_code == 200:
+ return img_info
+ else:
+ logger.error(
+ f"完成上传图片失败: HTTP{response.status_code}{response.text}"
+ )
+ return False
+ else:
+ logger.error(f"上传图片失败: HTTP{response.status_code}{response.text}")
+ return False
async def get_image_url_with_id(self, image_id: str):
- async with httpx.AsyncClient(
+ response = await self.client.get(
+ f"backend-api/files/{image_id}/download",
headers=self.headers,
- proxies=self.proxies,
- timeout=self.timeout,
- ) as client:
- response = await client.get(
- urljoin(self.api_url, f"backend-api/files/{image_id}/download")
+ )
+ try:
+ if response.status_code == 200:
+ resp_json = response.json()
+ return resp_json["download_url"]
+ else:
+ return False
+ except Exception as e:
+ logger.opt(colors=True, exception=e).error(
+ f"获取图片失败: HTTP{response.status_code} {response.text}"
)
- try:
- if response.status_code == 200:
- resp_json = response.json()
- return resp_json["download_url"]
- else:
- return False
- except Exception as e:
- logger.opt(colors=True, exception=e).error(
- f"获取图片失败: HTTP{response.status_code} {response.text}"
- )
async def get_conversasion_message_response(
self, conversation_id: str, message_id: str
@@ -346,14 +396,22 @@ async def get_conversasion_message_response(
if messages := conversation.get("mapping"):
resp = messages[message_id]
message = messages[resp["parent"]]
- while message["message"]["author"]["role"] == "assistant" or message["message"]["author"]["role"] == "tool":
- logger.info(message)
+ while (
+ message["message"]["author"]["role"] == "assistant"
+ or message["message"]["author"]["role"] == "tool"
+ ):
+ logger.debug(message)
content_type = message["message"]["content"]["content_type"]
if message["message"]["author"]["role"] == "tool":
if content_type == "multimodal_text":
- resp["message"]["content"]["parts"].extend(message["message"]["content"]["parts"])
+ resp["message"]["content"]["parts"].extend(
+ message["message"]["content"]["parts"]
+ )
elif content_type == "text":
- resp["message"]["content"]["parts"].extend(message["message"]["content"]["parts"])
+ resp["message"]["content"]["parts"] = (
+ message["message"]["content"]["parts"]
+ + resp["message"]["content"]["parts"]
+ )
message = messages[message["parent"]]
resp["conversation_id"] = conversation_id
return resp
@@ -365,18 +423,13 @@ async def refresh_session(self) -> None:
if self.auto_auth:
await self.login()
else:
- cookies = {
- SESSION_TOKEN_KEY: self.session_token,
- }
- async with httpx.AsyncClient(
- cookies=cookies,
- proxies=self.proxies,
- timeout=self.timeout,
- ) as client:
- response = await client.get(
- urljoin(self.api_url, "api/auth/session"),
- headers={"User-Agent": self.user_agent},
- )
+ response = await self.client.get(
+ urljoin(self.api_url, "api/auth/session"),
+ headers={"User-Agent": self.user_agent},
+ cookies={
+ SESSION_TOKEN_KEY: self.session_token,
+ },
+ )
try:
if response.status_code == 200:
self.session_token = (
@@ -392,20 +445,16 @@ async def refresh_session(self) -> None:
)
async def login(self) -> None:
- async with httpx.AsyncClient(
- proxies=self.proxies,
- timeout=self.timeout,
- ) as client:
- response = await client.post(
- "https://chat.loli.vet/api/auth/login",
- headers={"User-Agent": self.user_agent},
- files={"username": self.account, "password": self.password},
- )
- if response.status_code == 200:
- session_token = response.cookies.get(SESSION_TOKEN_KEY)
- self.session_token = session_token
- self.auto_auth = False
- logger.opt(colors=True).info("ChatGPT 登录成功!")
- await self.refresh_session()
- else:
- logger.error(f"ChatGPT 登陆错误! {response.text}")
+ response = await self.client.post(
+ "https://chat.loli.vet/api/auth/login",
+ headers={"User-Agent": self.user_agent},
+ files={"username": self.account, "password": self.password},
+ )
+ if response.status_code == 200:
+ session_token = response.cookies.get(SESSION_TOKEN_KEY)
+ self.session_token = session_token
+ self.auto_auth = False
+ logger.opt(colors=True).info("ChatGPT 登录成功!")
+ await self.refresh_session()
+ else:
+ logger.error(f"ChatGPT 登陆错误! {response.text}")