Skip to content

Commit

Permalink
Fix broken functionalities (#17)
Browse files Browse the repository at this point in the history
* refine code

* fix bugs

* fix bugs

* update auth

* ignore .idea

* refine readme

* update requirements.txt

* update requirements

* update secrets

* use latest streamlit

* update secrets

* refine text

---------

Co-authored-by: Qin Liu <[email protected]>
  • Loading branch information
MochiXu and lqhl authored Jun 24, 2024
1 parent 6a65ab2 commit 4ee8f70
Show file tree
Hide file tree
Showing 62 changed files with 2,568 additions and 1,885 deletions.
5 changes: 3 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ dist/
downloads/
eggs/
.eggs/
lib/
.idea/

lib64/
parts/
sdist/
Expand Down Expand Up @@ -163,5 +164,5 @@ cython_debug/
# dataset files
data/
.streamlit/
*.ipynb
#*.ipynb
.DS_Store
21 changes: 3 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ In conclusion, with ChatData, you can effortlessly navigate through vast amounts

➡️ Dive in and experience ChatData on [Hugging Face](https://huggingface.co/spaces/myscale/ChatData)🤗

![ChatData Homepage](assets/chatdata-homepage.png)
![ChatData Homepage](assets/home.png)

### Data schema

Expand Down Expand Up @@ -117,15 +117,6 @@ And for overall table schema, please refer to [table creation section in docs/se

If you want to use this database with `langchain.chains.sql_database.base.SQLDatabaseChain` or `langchain.retrievers.SQLDatabaseRetriever`, please follow guides on [data preparation section](docs/vector-sql.md#prepare-the-database) and [chain creation section](docs/vector-sql.md#create-the-sqldatabasechain) in docs/vector-sql.md

### How to run ChatData

<a name="how-to-run"></a>

```bash
python3 -m pip install requirements.txt
python3 -m streamlit run app.py
```

### Where can I get those arXiv data?

- [From parquet files on S3](docs/self-query.md#insert-data)
Expand Down Expand Up @@ -167,18 +158,12 @@ cd app/
2. Create a virtual environment

```bash
python3 -m venv .venv
source .venv/bin/activate
python3 -m venv venv
source venv/bin/activate
```

3. Install dependencies

> This app is currently using [MyScale's technical preview of LangChain](https://github.com/myscale/langchain/tree/preview).
>
>> It contains improved SQLDatabaseChain in [this PR](https://github.com/hwchase17/langchain/pull/7454)
>>
>> It contains [improved prompts](https://github.com/hwchase17/langchain/pull/6737#discussion_r1243527112) for comparators `LIKE` and `CONTAIN` in [MyScale self-query retriever](https://github.com/hwchase17/langchain/pull/6143).

```bash
python3 -m pip install -r requirements.txt
```
Expand Down
6 changes: 1 addition & 5 deletions app/.streamlit/config.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,2 @@
[theme]
primaryColor="#523EFD"
backgroundColor="#FFFFFF"
secondaryBackgroundColor="#D4CEFF"
textColor="#262730"
font="sans serif"
base="dark"
3 changes: 2 additions & 1 deletion app/.streamlit/secrets.example.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
MYSCALE_HOST = "msc-4a9e710a.us-east-1.aws.staging.myscale.cloud" # read-only database provided by MyScale
MYSCALE_HOST = "msc-950b9f1f.us-east-1.aws.myscale.com" # read-only database provided by MyScale
MYSCALE_PORT = 443
MYSCALE_USER = "chatdata"
MYSCALE_PASSWORD = "myscale_rocks"
MYSCALE_ENABLE_HTTPS = true
OPENAI_API_BASE = "https://api.openai.com/v1"
OPENAI_API_KEY = "<your-openai-key>"
UNSTRUCTURED_API = "<your-unstructured-io-api>" # optional if you don't upload documents
Expand Down
195 changes: 74 additions & 121 deletions app/app.py
Original file line number Diff line number Diff line change
@@ -1,133 +1,86 @@
import pandas as pd
from os import environ
import os
import time

import streamlit as st

from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
ChatDataSQLAskCallBackHandler
from backend.constants.streamlit_keys import DATA_INITIALIZE_NOT_STATED, DATA_INITIALIZE_COMPLETED, \
DATA_INITIALIZE_STARTED
from backend.constants.variables import DATA_INITIALIZE_STATUS, JUMP_QUERY_ASK, CHAINS_RETRIEVERS_MAPPING, \
TABLE_EMBEDDINGS_MAPPING, RETRIEVER_TOOLS, USER_NAME, GLOBAL_CONFIG, update_global_config
from backend.construct.build_all import build_chains_and_retrievers, load_embedding_models, update_retriever_tools
from backend.types.global_config import GlobalConfig
from logger import logger
from ui.chat_page import chat_page
from ui.home import render_home
from ui.retrievers import render_retrievers

from chat import chat_page
from login import login, back_to_main
from lib.helper import build_tools, build_all, sel_map, display

# warnings.filterwarnings("ignore", category=UserWarning)

environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
def prepare_environment():
    """Export runtime settings to ``os.environ`` and publish the global config.

    The OpenAI and Auth0 values are copied from ``st.secrets`` into environment
    variables of the same name, then a ``GlobalConfig`` is built from the
    secrets and handed to ``update_global_config``. Every secret is read with
    ``st.secrets[...]`` (no fallback) except ``MYSCALE_ENABLE_HTTPS``, which
    falls back to ``True`` when absent.
    """
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    os.environ["LANGCHAIN_TRACING_V2"] = "false"
    # os.environ["LANGCHAIN_API_KEY"] = ""

    # Mirror these secrets into environment variables under the same name.
    mirrored_secrets = (
        "OPENAI_API_BASE",
        "OPENAI_API_KEY",
        "AUTH0_CLIENT_ID",
        "AUTH0_DOMAIN",
    )
    for secret_name in mirrored_secrets:
        os.environ[secret_name] = st.secrets[secret_name]

    secrets = st.secrets
    config = GlobalConfig(
        openai_api_base=secrets['OPENAI_API_BASE'],
        openai_api_key=secrets['OPENAI_API_KEY'],
        auth0_client_id=secrets['AUTH0_CLIENT_ID'],
        auth0_domain=secrets['AUTH0_DOMAIN'],
        myscale_user=secrets['MYSCALE_USER'],
        myscale_password=secrets['MYSCALE_PASSWORD'],
        myscale_host=secrets['MYSCALE_HOST'],
        myscale_port=secrets['MYSCALE_PORT'],
        query_model="gpt-3.5-turbo-0125",
        chat_model="gpt-3.5-turbo-0125",
        untrusted_api=secrets['UNSTRUCTURED_API'],
        myscale_enable_https=secrets.get('MYSCALE_ENABLE_HTTPS', True),
    )
    update_global_config(config)

st.set_page_config(page_title="ChatData",
page_icon="https://myscale.com/favicon.ico")
st.markdown(
f"""
<style>
.st-e4 {{
max-width: 500px
}}
</style>""",
unsafe_allow_html=True,
)
st.header("ChatData")

if 'sel_map_obj' not in st.session_state or 'embeddings' not in st.session_state:
st.session_state["sel_map_obj"], st.session_state["embeddings"] = build_all()
st.session_state["tools"] = build_tools()
# When the browser is refreshed, all session keys are cleared.
def initialize_session_state():
    """Seed the session-state keys this script relies on, if they are missing.

    ``DATA_INITIALIZE_STATUS`` starts as ``DATA_INITIALIZE_NOT_STATED`` and
    ``JUMP_QUERY_ASK`` starts as ``False``. Keys already present are left
    untouched; each key that gets created is logged.
    """
    initial_values = (
        (DATA_INITIALIZE_STATUS, DATA_INITIALIZE_NOT_STATED),
        (JUMP_QUERY_ASK, False),
    )
    for state_key, initial_value in initial_values:
        if state_key in st.session_state:
            continue
        st.session_state[state_key] = initial_value
        logger.info(f"Initialize session state key: {state_key}")

if login():
if "user_name" in st.session_state:
chat_page()
elif "jump_query_ask" in st.session_state and st.session_state.jump_query_ask:

sel = st.selectbox('Choose the knowledge base you want to ask with:',
options=['ArXiv Papers', 'Wikipedia'])
sel_map[sel]['hint']()
tab_sql, tab_self_query = st.tabs(
['Vector SQL', 'Self-Query Retrievers'])
with tab_sql:
sel_map[sel]['hint_sql']()
st.text_input("Ask a question:", key='query_sql')
cols = st.columns([1, 1, 1, 4])
cols[0].button("Query", key='search_sql')
cols[1].button("Ask", key='ask_sql')
cols[2].button("Back", key='back_sql', on_click=back_to_main)
plc_hldr = st.empty()
if st.session_state.search_sql:
plc_hldr = st.empty()
print(st.session_state.query_sql)
with plc_hldr.expander('Query Log', expanded=True):
callback = ChatDataSQLSearchCallBackHandler()
try:
docs = st.session_state.sel_map_obj[sel]["sql_retriever"].get_relevant_documents(
st.session_state.query_sql, callbacks=[callback])
callback.progress_bar.progress(value=1.0, text="Done!")
docs = pd.DataFrame(
[{**d.metadata, 'abstract': d.page_content} for d in docs])
display(docs)
except Exception as e:
st.write('Oops 😵 Something bad happened...')
raise e
def initialize_chat_data():
    """Load embeddings, chains/retrievers and retriever tools into the session.

    Does nothing when ``DATA_INITIALIZE_STATUS`` already equals
    ``DATA_INITIALIZE_COMPLETED``. Otherwise the status is set to
    ``DATA_INITIALIZE_STARTED``, the three resources are built and stored in
    ``st.session_state``, the status is set to ``DATA_INITIALIZE_COMPLETED``,
    and the elapsed time is logged.
    """
    state = st.session_state
    needs_initialization = state[DATA_INITIALIZE_STATUS] != DATA_INITIALIZE_COMPLETED
    if not needs_initialization:
        return

    began_at = time.time()
    state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_STARTED
    state[TABLE_EMBEDDINGS_MAPPING] = load_embedding_models()
    state[CHAINS_RETRIEVERS_MAPPING] = build_chains_and_retrievers()
    state[RETRIEVER_TOOLS] = update_retriever_tools()
    # Mark data initialization as finished.
    state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_COMPLETED

    elapsed_seconds = round(time.time() - began_at, 3)
    logger.info(f"ChatData initialized finished in {elapsed_seconds} seconds, "
                f"session state keys: {list(state.keys())}")

if st.session_state.ask_sql:
plc_hldr = st.empty()
print(st.session_state.query_sql)
with plc_hldr.expander('Chat Log', expanded=True):
callback = ChatDataSQLAskCallBackHandler()
try:
ret = st.session_state.sel_map_obj[sel]["sql_chain"](
st.session_state.query_sql, callbacks=[callback])
callback.progress_bar.progress(value=1.0, text="Done!")
st.markdown(
f"### Answer from LLM\n{ret['answer']}\n### References")
docs = ret['sources']
docs = pd.DataFrame(
[{**d.metadata, 'abstract': d.page_content} for d in docs])
display(
docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
except Exception as e:
st.write('Oops 😵 Something bad happened...')
raise e

with tab_self_query:
st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
st.text_input("Ask a question:", key='query_self')
cols = st.columns([1, 1, 1, 4])
cols[0].button("Query", key='search_self')
cols[1].button("Ask", key='ask_self')
cols[2].button("Back", key='back_self', on_click=back_to_main)
plc_hldr = st.empty()
if st.session_state.search_self:
plc_hldr = st.empty()
print(st.session_state.query_self)
with plc_hldr.expander('Query Log', expanded=True):
call_back = None
callback = ChatDataSelfSearchCallBackHandler()
try:
docs = st.session_state.sel_map_obj[sel]["retriever"].get_relevant_documents(
st.session_state.query_self, callbacks=[callback])
print(docs)
callback.progress_bar.progress(value=1.0, text="Done!")
docs = pd.DataFrame(
[{**d.metadata, 'abstract': d.page_content} for d in docs])
display(docs, sel_map[sel]["must_have_cols"])
except Exception as e:
st.write('Oops 😵 Something bad happened...')
raise e
# Page-level settings: tab title and icon, sidebar expanded on load, wide layout.
st.set_page_config(
    page_title="ChatData",
    page_icon="https://myscale.com/favicon.ico",
    initial_sidebar_state="expanded",
    layout="wide",
)

# Startup sequence: export environment/config, seed the session-state keys,
# then build the chat resources (skipped once marked as completed).
prepare_environment()
initialize_session_state()
initialize_chat_data()

if st.session_state.ask_self:
plc_hldr = st.empty()
print(st.session_state.query_self)
with plc_hldr.expander('Chat Log', expanded=True):
call_back = None
callback = ChatDataSelfAskCallBackHandler()
try:
ret = st.session_state.sel_map_obj[sel]["chain"](
st.session_state.query_self, callbacks=[callback])
callback.progress_bar.progress(value=1.0, text="Done!")
st.markdown(
f"### Answer from LLM\n{ret['answer']}\n### References")
docs = ret['sources']
docs = pd.DataFrame(
[{**d.metadata, 'abstract': d.page_content} for d in docs])
display(
docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
except Exception as e:
st.write('Oops 😵 Something bad happened...')
raise e
# Route to a page: the chat page once USER_NAME is in the session, otherwise
# the retriever page when JUMP_QUERY_ASK is set, otherwise the home page.
if USER_NAME in st.session_state:
    chat_page()
elif st.session_state[JUMP_QUERY_ASK]:
    render_retrievers()
else:
    render_home()
Empty file added app/backend/__init__.py
Empty file.
Empty file.
46 changes: 46 additions & 0 deletions app/backend/callbacks/arxiv_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import json
import textwrap
from typing import Dict, Any, List

from langchain.callbacks.streamlit.streamlit_callback_handler import (
LLMThought,
StreamlitCallbackHandler,
)


class LLMThoughtWithKnowledgeBase(LLMThought):
    """``LLMThought`` variant that renders tool output as a list of documents."""

    def on_tool_end(
        self,
        output: str,
        color=None,
        observation_prefix=None,
        llm_prefix=None,
        **kwargs: Any,
    ) -> None:
        """Render the tool output as a numbered "Retrieved Documents" list.

        ``output`` is parsed as JSON and iterated; each item's
        ``page_content`` is shortened to 80 characters and written to the
        container as markdown. If anything in that process raises, the call
        is delegated to the parent implementation with the same arguments.
        """
        try:
            sections = ["### Retrieved Documents:"]
            for position, record in enumerate(json.loads(output), start=1):
                summary = textwrap.shorten(record['page_content'], width=80)
                sections.append(f"**{position}**: {summary}")
            self._container.markdown("\n\n".join(sections))
        except Exception:
            # Best-effort rendering: fall back to the default display.
            super().on_tool_end(output, color, observation_prefix, llm_prefix, **kwargs)


class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
    """Streamlit callback handler whose thoughts are ``LLMThoughtWithKnowledgeBase``."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Create the current thought if there is none, then forward the event."""
        thought = self._current_thought
        if thought is None:
            thought = LLMThoughtWithKnowledgeBase(
                parent_container=self._parent_container,
                expanded=self._expand_new_thoughts,
                collapse_on_complete=self._collapse_completed_thoughts,
                labeler=self._thought_labeler,
            )
            self._current_thought = thought

        thought.on_llm_start(serialized, prompts)
36 changes: 36 additions & 0 deletions app/backend/callbacks/llm_thought_with_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Any, Dict, List

import streamlit as st
from langchain_core.outputs import LLMResult
from streamlit.external.langchain import StreamlitCallbackHandler


class ChatDataSelfQueryCallBack(StreamlitCallbackHandler):
    """Callback handler that reports self-query progress on a Streamlit progress bar."""

    def __init__(self):
        """Attach to a fresh container and create the progress bar at 0%."""
        super().__init__(st.container())
        self._current_thought = None
        self.progress_bar = st.progress(value=0.0, text="Executing ChatData SelfQuery CallBack...")

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Move the progress bar to 35% when an LLM call starts."""
        self.progress_bar.progress(value=0.35, text="Communicate with LLM...")

    def on_chain_end(self, outputs, **kwargs) -> None:
        """Move the progress bar to 75% when a chain with no tags ends."""
        if len(kwargs['tags']) != 0:
            return
        self.progress_bar.progress(value=0.75, text="Searching in DB...")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        """Override with an empty body: nothing is rendered on chain start."""

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Write a heading, move the bar to 50%, and print each generation's first text."""
        st.markdown("### Generate filter by LLM \n"
                    "> Here we get `query_constructor` results \n\n")

        self.progress_bar.progress(value=0.5, text="Generate filter by LLM...")
        for candidates in response.generations:
            st.markdown(f"{candidates[0].text}")
Loading

0 comments on commit 4ee8f70

Please sign in to comment.