-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* refine code * fix bugs * fix bugs * update auth * ignore .idea * refine readme * update requirements.txt * update requirements * update secrets * use latest streamlit * update secrets * refine text --------- Co-authored-by: Qin Liu <[email protected]>
- Loading branch information
Showing
62 changed files
with
2,568 additions
and
1,885 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,2 @@ | ||
[theme] | ||
primaryColor="#523EFD" | ||
backgroundColor="#FFFFFF" | ||
secondaryBackgroundColor="#D4CEFF" | ||
textColor="#262730" | ||
font="sans serif" | ||
base="dark" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,133 +1,86 @@ | ||
import pandas as pd | ||
from os import environ | ||
import os | ||
import time | ||
|
||
import streamlit as st | ||
|
||
from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \ | ||
ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \ | ||
ChatDataSQLAskCallBackHandler | ||
from backend.constants.streamlit_keys import DATA_INITIALIZE_NOT_STATED, DATA_INITIALIZE_COMPLETED, \ | ||
DATA_INITIALIZE_STARTED | ||
from backend.constants.variables import DATA_INITIALIZE_STATUS, JUMP_QUERY_ASK, CHAINS_RETRIEVERS_MAPPING, \ | ||
TABLE_EMBEDDINGS_MAPPING, RETRIEVER_TOOLS, USER_NAME, GLOBAL_CONFIG, update_global_config | ||
from backend.construct.build_all import build_chains_and_retrievers, load_embedding_models, update_retriever_tools | ||
from backend.types.global_config import GlobalConfig | ||
from logger import logger | ||
from ui.chat_page import chat_page | ||
from ui.home import render_home | ||
from ui.retrievers import render_retrievers | ||
|
||
from chat import chat_page | ||
from login import login, back_to_main | ||
from lib.helper import build_tools, build_all, sel_map, display | ||
|
||
# warnings.filterwarnings("ignore", category=UserWarning) | ||
|
||
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE'] | ||
def prepare_environment(): | ||
os.environ['TOKENIZERS_PARALLELISM'] = 'true' | ||
os.environ["LANGCHAIN_TRACING_V2"] = "false" | ||
# os.environ["LANGCHAIN_API_KEY"] = "" | ||
os.environ["OPENAI_API_BASE"] = st.secrets['OPENAI_API_BASE'] | ||
os.environ["OPENAI_API_KEY"] = st.secrets['OPENAI_API_KEY'] | ||
os.environ["AUTH0_CLIENT_ID"] = st.secrets['AUTH0_CLIENT_ID'] | ||
os.environ["AUTH0_DOMAIN"] = st.secrets['AUTH0_DOMAIN'] | ||
|
||
update_global_config(GlobalConfig( | ||
openai_api_base=st.secrets['OPENAI_API_BASE'], | ||
openai_api_key=st.secrets['OPENAI_API_KEY'], | ||
auth0_client_id=st.secrets['AUTH0_CLIENT_ID'], | ||
auth0_domain=st.secrets['AUTH0_DOMAIN'], | ||
myscale_user=st.secrets['MYSCALE_USER'], | ||
myscale_password=st.secrets['MYSCALE_PASSWORD'], | ||
myscale_host=st.secrets['MYSCALE_HOST'], | ||
myscale_port=st.secrets['MYSCALE_PORT'], | ||
query_model="gpt-3.5-turbo-0125", | ||
chat_model="gpt-3.5-turbo-0125", | ||
untrusted_api=st.secrets['UNSTRUCTURED_API'], | ||
myscale_enable_https=st.secrets.get('MYSCALE_ENABLE_HTTPS', True), | ||
)) | ||
|
||
st.set_page_config(page_title="ChatData", | ||
page_icon="https://myscale.com/favicon.ico") | ||
st.markdown( | ||
f""" | ||
<style> | ||
.st-e4 {{ | ||
max-width: 500px | ||
}} | ||
</style>""", | ||
unsafe_allow_html=True, | ||
) | ||
st.header("ChatData") | ||
|
||
if 'sel_map_obj' not in st.session_state or 'embeddings' not in st.session_state: | ||
st.session_state["sel_map_obj"], st.session_state["embeddings"] = build_all() | ||
st.session_state["tools"] = build_tools() | ||
# when refresh browser, all session keys will be cleaned. | ||
def initialize_session_state(): | ||
if DATA_INITIALIZE_STATUS not in st.session_state: | ||
st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_NOT_STATED | ||
logger.info(f"Initialize session state key: {DATA_INITIALIZE_STATUS}") | ||
if JUMP_QUERY_ASK not in st.session_state: | ||
st.session_state[JUMP_QUERY_ASK] = False | ||
logger.info(f"Initialize session state key: {JUMP_QUERY_ASK}") | ||
|
||
if login(): | ||
if "user_name" in st.session_state: | ||
chat_page() | ||
elif "jump_query_ask" in st.session_state and st.session_state.jump_query_ask: | ||
|
||
sel = st.selectbox('Choose the knowledge base you want to ask with:', | ||
options=['ArXiv Papers', 'Wikipedia']) | ||
sel_map[sel]['hint']() | ||
tab_sql, tab_self_query = st.tabs( | ||
['Vector SQL', 'Self-Query Retrievers']) | ||
with tab_sql: | ||
sel_map[sel]['hint_sql']() | ||
st.text_input("Ask a question:", key='query_sql') | ||
cols = st.columns([1, 1, 1, 4]) | ||
cols[0].button("Query", key='search_sql') | ||
cols[1].button("Ask", key='ask_sql') | ||
cols[2].button("Back", key='back_sql', on_click=back_to_main) | ||
plc_hldr = st.empty() | ||
if st.session_state.search_sql: | ||
plc_hldr = st.empty() | ||
print(st.session_state.query_sql) | ||
with plc_hldr.expander('Query Log', expanded=True): | ||
callback = ChatDataSQLSearchCallBackHandler() | ||
try: | ||
docs = st.session_state.sel_map_obj[sel]["sql_retriever"].get_relevant_documents( | ||
st.session_state.query_sql, callbacks=[callback]) | ||
callback.progress_bar.progress(value=1.0, text="Done!") | ||
docs = pd.DataFrame( | ||
[{**d.metadata, 'abstract': d.page_content} for d in docs]) | ||
display(docs) | ||
except Exception as e: | ||
st.write('Oops 😵 Something bad happened...') | ||
raise e | ||
def initialize_chat_data(): | ||
if st.session_state[DATA_INITIALIZE_STATUS] != DATA_INITIALIZE_COMPLETED: | ||
start_time = time.time() | ||
st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_STARTED | ||
st.session_state[TABLE_EMBEDDINGS_MAPPING] = load_embedding_models() | ||
st.session_state[CHAINS_RETRIEVERS_MAPPING] = build_chains_and_retrievers() | ||
st.session_state[RETRIEVER_TOOLS] = update_retriever_tools() | ||
# mark data initialization finished. | ||
st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_COMPLETED | ||
end_time = time.time() | ||
logger.info(f"ChatData initialized finished in {round(end_time - start_time, 3)} seconds, " | ||
f"session state keys: {list(st.session_state.keys())}") | ||
|
||
if st.session_state.ask_sql: | ||
plc_hldr = st.empty() | ||
print(st.session_state.query_sql) | ||
with plc_hldr.expander('Chat Log', expanded=True): | ||
callback = ChatDataSQLAskCallBackHandler() | ||
try: | ||
ret = st.session_state.sel_map_obj[sel]["sql_chain"]( | ||
st.session_state.query_sql, callbacks=[callback]) | ||
callback.progress_bar.progress(value=1.0, text="Done!") | ||
st.markdown( | ||
f"### Answer from LLM\n{ret['answer']}\n### References") | ||
docs = ret['sources'] | ||
docs = pd.DataFrame( | ||
[{**d.metadata, 'abstract': d.page_content} for d in docs]) | ||
display( | ||
docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id') | ||
except Exception as e: | ||
st.write('Oops 😵 Something bad happened...') | ||
raise e | ||
|
||
with tab_self_query: | ||
st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡') | ||
st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"]) | ||
st.text_input("Ask a question:", key='query_self') | ||
cols = st.columns([1, 1, 1, 4]) | ||
cols[0].button("Query", key='search_self') | ||
cols[1].button("Ask", key='ask_self') | ||
cols[2].button("Back", key='back_self', on_click=back_to_main) | ||
plc_hldr = st.empty() | ||
if st.session_state.search_self: | ||
plc_hldr = st.empty() | ||
print(st.session_state.query_self) | ||
with plc_hldr.expander('Query Log', expanded=True): | ||
call_back = None | ||
callback = ChatDataSelfSearchCallBackHandler() | ||
try: | ||
docs = st.session_state.sel_map_obj[sel]["retriever"].get_relevant_documents( | ||
st.session_state.query_self, callbacks=[callback]) | ||
print(docs) | ||
callback.progress_bar.progress(value=1.0, text="Done!") | ||
docs = pd.DataFrame( | ||
[{**d.metadata, 'abstract': d.page_content} for d in docs]) | ||
display(docs, sel_map[sel]["must_have_cols"]) | ||
except Exception as e: | ||
st.write('Oops 😵 Something bad happened...') | ||
raise e | ||
st.set_page_config( | ||
page_title="ChatData", | ||
page_icon="https://myscale.com/favicon.ico", | ||
initial_sidebar_state="expanded", | ||
layout="wide", | ||
) | ||
|
||
prepare_environment() | ||
initialize_session_state() | ||
initialize_chat_data() | ||
|
||
if st.session_state.ask_self: | ||
plc_hldr = st.empty() | ||
print(st.session_state.query_self) | ||
with plc_hldr.expander('Chat Log', expanded=True): | ||
call_back = None | ||
callback = ChatDataSelfAskCallBackHandler() | ||
try: | ||
ret = st.session_state.sel_map_obj[sel]["chain"]( | ||
st.session_state.query_self, callbacks=[callback]) | ||
callback.progress_bar.progress(value=1.0, text="Done!") | ||
st.markdown( | ||
f"### Answer from LLM\n{ret['answer']}\n### References") | ||
docs = ret['sources'] | ||
docs = pd.DataFrame( | ||
[{**d.metadata, 'abstract': d.page_content} for d in docs]) | ||
display( | ||
docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id') | ||
except Exception as e: | ||
st.write('Oops 😵 Something bad happened...') | ||
raise e | ||
if USER_NAME in st.session_state: | ||
chat_page() | ||
else: | ||
if st.session_state[JUMP_QUERY_ASK]: | ||
render_retrievers() | ||
else: | ||
render_home() |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import json | ||
import textwrap | ||
from typing import Dict, Any, List | ||
|
||
from langchain.callbacks.streamlit.streamlit_callback_handler import ( | ||
LLMThought, | ||
StreamlitCallbackHandler, | ||
) | ||
|
||
|
||
class LLMThoughtWithKnowledgeBase(LLMThought): | ||
def on_tool_end( | ||
self, | ||
output: str, | ||
color=None, | ||
observation_prefix=None, | ||
llm_prefix=None, | ||
**kwargs: Any, | ||
) -> None: | ||
try: | ||
self._container.markdown( | ||
"\n\n".join( | ||
["### Retrieved Documents:"] | ||
+ [ | ||
f"**{i+1}**: {textwrap.shorten(r['page_content'], width=80)}" | ||
for i, r in enumerate(json.loads(output)) | ||
] | ||
) | ||
) | ||
except Exception as e: | ||
super().on_tool_end(output, color, observation_prefix, llm_prefix, **kwargs) | ||
|
||
|
||
class ChatDataAgentCallBackHandler(StreamlitCallbackHandler): | ||
def on_llm_start( | ||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any | ||
) -> None: | ||
if self._current_thought is None: | ||
self._current_thought = LLMThoughtWithKnowledgeBase( | ||
parent_container=self._parent_container, | ||
expanded=self._expand_new_thoughts, | ||
collapse_on_complete=self._collapse_completed_thoughts, | ||
labeler=self._thought_labeler, | ||
) | ||
|
||
self._current_thought.on_llm_start(serialized, prompts) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from typing import Any, Dict, List | ||
|
||
import streamlit as st | ||
from langchain_core.outputs import LLMResult | ||
from streamlit.external.langchain import StreamlitCallbackHandler | ||
|
||
|
||
class ChatDataSelfQueryCallBack(StreamlitCallbackHandler): | ||
def __init__(self): | ||
super().__init__(st.container()) | ||
self._current_thought = None | ||
self.progress_bar = st.progress(value=0.0, text="Executing ChatData SelfQuery CallBack...") | ||
|
||
def on_llm_start( | ||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any | ||
) -> None: | ||
self.progress_bar.progress(value=0.35, text="Communicate with LLM...") | ||
pass | ||
|
||
def on_chain_end(self, outputs, **kwargs) -> None: | ||
if len(kwargs['tags']) == 0: | ||
self.progress_bar.progress(value=0.75, text="Searching in DB...") | ||
|
||
def on_chain_start(self, serialized, inputs, **kwargs) -> None: | ||
|
||
pass | ||
|
||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: | ||
st.markdown("### Generate filter by LLM \n" | ||
"> Here we get `query_constructor` results \n\n") | ||
|
||
self.progress_bar.progress(value=0.5, text="Generate filter by LLM...") | ||
for item in response.generations: | ||
st.markdown(f"{item[0].text}") | ||
|
||
pass |
Oops, something went wrong.