langchain integration upgrade #1026

Open. Wants to merge 9 commits into base: main.
Empty file added cookbook/__init__.py
40 changes: 23 additions & 17 deletions cookbook/langchain/trace_langchain_RAG_with_experiment.py
@@ -1,5 +1,6 @@
 import os
 from datetime import datetime
+from functools import lru_cache
 from operator import itemgetter

 from dotenv import load_dotenv
@@ -30,22 +31,27 @@
 pinecone = PineconeClient(api_key=os.getenv("PINECONE_API_KEY"), environment=os.getenv("PINECONE_ENVIRONMENT"))


-class DocumentRetriever:
-    def __init__(self, url: str):
-        api_loader = RecursiveUrlLoader(url)
-        raw_documents = api_loader.load()
+@lru_cache()
+def get_docs(url):
+    api_loader = RecursiveUrlLoader(url)
+    raw_documents = api_loader.load()

-        # Transformer
-        doc_transformer = Html2TextTransformer()
-        transformed = doc_transformer.transform_documents(raw_documents)
+    # Transformer
+    doc_transformer = Html2TextTransformer()
+    transformed = doc_transformer.transform_documents(raw_documents)

-        # Splitter
-        text_splitter = TokenTextSplitter(
-            model_name="gpt-3.5-turbo",
-            chunk_size=2000,
-            chunk_overlap=200,
-        )
-        documents = text_splitter.split_documents(transformed)
+    # Splitter
+    text_splitter = TokenTextSplitter(
+        model_name="gpt-3.5-turbo",
+        chunk_size=2000,
+        chunk_overlap=200,
+    )
+    return text_splitter.split_documents(transformed)
+
+
+class DocumentRetriever:
+    def __init__(self, url: str):
+        documents = get_docs(url)

         # Define vector store based
         embeddings = OpenAIEmbeddings()
@@ -59,7 +65,7 @@ def get_retriever(self):
 class DocumentationChain:
     def __init__(self, url):
         retriever = DocumentRetriever(url).get_retriever()
-        model = ChatOpenAI(model="gpt-3.5-turbo-16k", temperature=0)
+        model = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
         prompt = ChatPromptTemplate.from_messages(
             [
                 (
@@ -126,7 +132,7 @@ def get_chain(self):
 @trace(
     eval_funcs=[
         # these are factory functions that return the actual evaluation functions, so we need to call them
-        answer_matches_target_llm_grader_factory(),
+        answer_matches_target_llm_grader_factory(model="gpt-4o-mini"),
         answer_context_faithfulness_binary_factory(),
         answer_context_faithfulness_statement_level_factory(),
         context_query_relevancy_factory(context_fields=["context"]),
@@ -142,7 +148,7 @@ def main(question: str) -> str:
     # insert the context into the trace as an input so that it can be referenced in the evaluation functions
     # context needs to be retrieved after the chain is invoked
     trace_insert({"inputs": {"context": dc.get_context()}})
-    print(output)
+    # print(output)
     return output
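
Note on the `@lru_cache` refactor above: hoisting document loading out of `DocumentRetriever.__init__` into a module-level `get_docs` means the crawl/transform/split pipeline runs at most once per unique URL, even when an experiment constructs the retriever repeatedly. A minimal sketch of the same pattern; the `load_and_split` name and its body are illustrative stand-ins, not code from this PR:

```python
from functools import lru_cache


@lru_cache()
def load_and_split(url: str) -> tuple:
    # The expensive work (network crawl + splitting) runs once per unique URL;
    # later calls with the same argument return the cached result.
    print(f"loading {url}")  # visible only on the first call
    return tuple(f"{url}#chunk-{i}" for i in range(3))  # stand-in for split documents


docs_a = load_and_split("https://example.com/docs")
docs_b = load_and_split("https://example.com/docs")  # cache hit: no second "loading" print
assert docs_a is docs_b
```

One caveat: `lru_cache` hands every caller the same cached object, so when the cached value is a mutable list (as `split_documents` returns), anything that mutates the returned documents mutates them for all later calls too.
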
17 changes: 9 additions & 8 deletions cookbook/langchain/trace_langchain_rag_agents.py
@@ -1,20 +1,19 @@
 import os

 from dotenv import load_dotenv
-from langchain.agents.agent_toolkits import create_conversational_retrieval_agent, create_retriever_tool
-from langchain.chat_models import ChatOpenAI
-from langchain.document_loaders import TextLoader
-from langchain.embeddings import OpenAIEmbeddings
-from langchain.text_splitter import CharacterTextSplitter
-from langchain.vectorstores import FAISS
+from langchain.agents.agent_toolkits import create_conversational_retrieval_agent
+from langchain_community.document_loaders import TextLoader
+from langchain_community.vectorstores import FAISS
+from langchain_core.tools import create_retriever_tool
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+from langchain_text_splitters import CharacterTextSplitter

 from parea import Parea
 from parea.utils.trace_integrations.langchain import PareaAILangchainTracer

 load_dotenv()

 p = Parea(api_key=os.getenv("PAREA_API_KEY"))

-
 loader = TextLoader("../assets/data/state_of_the_union.txt")

@@ -38,7 +37,9 @@


 def main():
-    result = agent_executor({"input": "what did the president say about kentaji brown jackson in the most recent state of the union?"}, callbacks=[PareaAILangchainTracer()])
+    result = agent_executor.invoke(
+        {"input": "what did the president say about kentaji brown jackson in the most recent state of the union?"}, config={"callbacks": [PareaAILangchainTracer()]}
+    )
     print(result)
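
The import changes above track the LangChain 0.1+ package split: document loaders and vector stores moved to `langchain_community`, model and embedding classes to `langchain_openai`, tools to `langchain_core`, and splitters to `langchain_text_splitters`. The call-site change is the companion migration: the deprecated `agent_executor(...)` style becomes `.invoke()`, with callbacks moving into the `config` dict. A small self-contained sketch of that calling convention; `RunnableLambda` stands in for the retrieval agent built in this file:

```python
from langchain_core.callbacks import StdOutCallbackHandler
from langchain_core.runnables import RunnableLambda

# Trivial stand-in for the conversational retrieval agent.
chain = RunnableLambda(lambda d: {"output": d["input"].upper()})
handler = StdOutCallbackHandler()

# Deprecated pre-0.1 style: chain({"input": ...}, callbacks=[handler])
# Current style: callbacks travel inside the RunnableConfig dict.
result = chain.invoke({"input": "hello"}, config={"callbacks": [handler]})
print(result)  # {'output': 'HELLO'}
```
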
10 changes: 9 additions & 1 deletion cookbook/langchain/trace_langchain_simple.py
@@ -33,6 +33,14 @@ async def amain():
     )


+def stream_main():
+    for chunk in llm.stream(
+        "what color is the sky?",
+        config={"callbacks": [handler]},
+    ):
+        print(chunk.content, end=" ", flush=True)
+
+
 if __name__ == "__main__":
-    print(main())
+    # print(main())
     print(asyncio.run(amain()))
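
The new `stream_main` exercises the tracer on the streaming path: `llm.stream(...)` yields `AIMessageChunk` objects as tokens arrive, and each one fires the handler's `on_llm_new_token`, which the upgraded tracer uses to flag a run as streaming. A variant that collects the chunks instead of printing them, assuming the same `llm` and `handler` defined earlier in this file:

```python
def stream_and_collect(prompt: str) -> str:
    # Each yielded chunk is an AIMessageChunk; .content carries the new token text.
    parts = []
    for chunk in llm.stream(prompt, config={"callbacks": [handler]}):
        parts.append(chunk.content)
    return "".join(parts)
```
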
25 changes: 9 additions & 16 deletions parea/parea_logger.py
@@ -1,18 +1,15 @@
-from typing import Any, Dict, Optional
+from typing import Optional

 import json
 import logging
-import os

 from attrs import asdict, define, field
 from cattrs import structure

 from parea.api_client import HTTPClient
-from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID
 from parea.helpers import serialize_metadata_values
 from parea.schemas.log import TraceIntegrations
 from parea.schemas.models import CreateGetProjectResponseSchema, TraceLog, UpdateLog
-from parea.utils.trace_integrations.langchain_utils import _dumps_json
 from parea.utils.universal_encoder import json_dumps

 logger = logging.getLogger()
@@ -34,7 +31,9 @@ def set_project_uuid(self, project_uuid: str, project_name: str) -> None:
         self._project_uuid = project_uuid
         self._project_name = project_name

-    def _get_project_uuid(self) -> str:
-        if not self._project_uuid:
-            self._project_uuid = self._create_or_get_project(self._project_name or "default").uuid
+    def get_project_uuid(self) -> Optional[str]:
+        try:
+            if not self._project_uuid:
+                self._project_uuid = self._create_or_get_project(self._project_name or "default").uuid
@@ -61,7 +60,7 @@ def update_log(self, data: UpdateLog) -> None:

     def record_log(self, data: TraceLog) -> None:
         data = serialize_metadata_values(data)
-        data.project_uuid = self._get_project_uuid()
+        data.project_uuid = self.get_project_uuid()
         self._client.request(
             "POST",
             LOG_ENDPOINT,
@@ -83,26 +82,20 @@ def default_log(self, data: TraceLog) -> None:
         data.target = json_dumps(data.target)
         self.record_log(data)

-    def record_vendor_log(self, data: Dict[str, Any], vendor: TraceIntegrations) -> None:
-        data["project_uuid"] = self._get_project_uuid()
-        if experiment_uuid := os.getenv(PAREA_OS_ENV_EXPERIMENT_UUID, None):
-            data["experiment_uuid"] = experiment_uuid
+    def record_vendor_log(self, data: bytes, vendor: TraceIntegrations) -> None:
         self._client.add_integration("langchain")
         self._client.request(
             "POST",
             VENDOR_LOG_ENDPOINT.format(vendor=vendor.value),
-            data=json.loads(_dumps_json(data)),  # uuid is not serializable
+            data=json.loads(data),
         )

-    async def arecord_vendor_log(self, data: Dict[str, Any], vendor: TraceIntegrations) -> None:
-        data["project_uuid"] = self._get_project_uuid()
-        if experiment_uuid := os.getenv(PAREA_OS_ENV_EXPERIMENT_UUID, None):
-            data["experiment_uuid"] = experiment_uuid
+    async def arecord_vendor_log(self, data: bytes, vendor: TraceIntegrations) -> None:
         self._client.add_integration("langchain")
         await self._client.request_async(
             "POST",
             VENDOR_LOG_ENDPOINT.format(vendor=vendor.value),
-            data=json.loads(_dumps_json(data)),  # uuid is not serializable
+            data=json.loads(data),  # uuid is not serializable
         )
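
With this change `record_vendor_log` and `arecord_vendor_log` accept pre-serialized `bytes` instead of a dict, so JSON encoding (and the project/experiment bookkeeping that used to live here) moves to the caller, presumably the new `PareaLangchainClient`. A sketch of what such a caller might look like under that assumption; `send_langchain_run` and `run_payload` are hypothetical, and `_dumps_json` is the UUID-safe encoder this module previously imported:

```python
from parea.parea_logger import parea_logger
from parea.schemas.log import TraceIntegrations
from parea.utils.trace_integrations.langchain_utils import _dumps_json


def send_langchain_run(run_payload: dict) -> None:
    # Serialize up front (UUIDs etc. are not natively JSON-serializable);
    # the logger itself now only does json.loads(data) before POSTing.
    payload: bytes = _dumps_json(run_payload)
    parea_logger.record_vendor_log(payload, TraceIntegrations.LANGCHAIN)
```
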
96 changes: 45 additions & 51 deletions parea/utils/trace_integrations/langchain.py
@@ -1,30 +1,26 @@
-from typing import Any, Dict, List, Optional, Union
+from __future__ import annotations
+
+from typing import Any, Dict, List, Optional

 import logging
 from uuid import UUID

-from langchain_core.tracers import BaseTracer
-from langchain_core.tracers.schemas import ChainRun, LLMRun, Run, ToolRun
+from langchain_core.tracers import LangChainTracer
+from langchain_core.tracers.schemas import Run

+from parea import get_current_trace_id, get_root_trace_id
 from parea.helpers import is_logging_disabled
-from parea.parea_logger import parea_logger
 from parea.schemas import UpdateTraceScenario
 from parea.schemas.log import TraceIntegrations
-from parea.utils.trace_utils import fill_trace_data, get_current_trace_id, get_root_trace_id
+from parea.utils.trace_integrations.parea_langchain_client import PareaLangchainClient
+from parea.utils.trace_utils import fill_trace_data, trace_data

 logger = logging.getLogger()


-class PareaAILangchainTracer(BaseTracer):
+class PareaAILangchainTracer(LangChainTracer):
     """Base callback handler that can be used to handle callbacks from langchain."""

     parent_trace_id: UUID
-    _parea_root_trace_id: str = None
-    _parea_parent_trace_id: str = None
-    _session_id: Optional[str] = None
-    _tags: List[str] = []
-    _metadata: Dict[str, Any] = {}
-    _end_user_identifier: Optional[str] = None
-    _deployment_id: Optional[str] = None
-    _log_sample_rate: Optional[float] = 1.0

     def __init__(
         self,
@@ -35,51 +31,49 @@ def __init__(
         deployment_id: Optional[str] = None,
         log_sample_rate: Optional[float] = 1.0,
         **kwargs: Any,
-    ):
+    ) -> None:
         """Initialize the Parea tracer."""
         super().__init__(**kwargs)
-        self._session_id = session_id
-        self._end_user_identifier = end_user_identifier
-        self._deployment_id = deployment_id
-        self._log_sample_rate = log_sample_rate
-        if tags:
-            self._tags = tags
-        if metadata:
-            self._metadata = metadata
+        self.client = PareaLangchainClient(session_id, tags, metadata, end_user_identifier, deployment_id, log_sample_rate)
+        self.is_streaming = False

-    def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None:
+    def _persist_run(self, run: Run) -> None:
         if is_logging_disabled():
             return
         try:
-            self.parent_trace_id = run.id
-            # using .dict() since langchain Run class currently set to Pydantic v1
-            data = run.dict()
-            data["_parea_root_trace_id"] = self._parea_root_trace_id or None
-            data["_session_id"] = self._session_id
-            data["_tags"] = self._tags
-            data["_metadata"] = self._metadata
-            data["_end_user_identifier"] = self._end_user_identifier
-            data["_deployment_id"] = self._deployment_id
-            data["_log_sample_rate"] = self._log_sample_rate
-            # check if run has an attribute execution order
-            if (hasattr(run, "execution_order") and run.execution_order == 1) or run.parent_run_id is None:
-                data["_parea_parent_trace_id"] = self._parea_parent_trace_id or None
-            parea_logger.record_vendor_log(data, TraceIntegrations.LANGCHAIN)
+            self._set_parea_root_and_parent_trace_id(run)
+            if self.is_streaming:
+                self.client.stream_log(run)
+            else:
+                self.client.log()
         except Exception as e:
-            logger.exception(f"Error occurred while logging langchain run: {e}", stack_info=True)
+            logger.exception(f"Error persisting langchain run: {e}")

-    def get_parent_trace_id(self) -> UUID:
-        return self.parent_trace_id
+    def _on_run_create(self, run: Run) -> None:
+        try:
+            self.client.create_run_trace(run)
+        except Exception as e:
+            logger.exception(f"Error creating langchain run: {e}")
+
+    def _on_run_update(self, run: Run) -> None:
+        try:
+            self.client.update_run_trace(run)
+        except Exception as e:
+            logger.exception(f"Error updating langchain run: {e}")
+
+    def on_llm_new_token(self, *args: Any, **kwargs: Any):
+        super().on_llm_new_token(*args, **kwargs)
+        self.is_streaming = True

     def _set_parea_root_and_parent_trace_id(self, run) -> None:
+        self.parent_trace_id = run.id
         if (hasattr(run, "execution_order") and run.execution_order == 1) or run.parent_run_id is None:
             # need to check if any traces already exist
-            self._parea_root_trace_id = get_root_trace_id()
+            parea_root_trace_id = get_root_trace_id()
             if parent_trace_id := get_current_trace_id():
-                self._parea_parent_trace_id = parent_trace_id
+                _experiment_uuid = trace_data.get()[parent_trace_id].experiment_uuid
                 fill_trace_data(str(run.id), {"parent_trace_id": parent_trace_id}, UpdateTraceScenario.LANGCHAIN_CHILD)
+                langchain_to_parea_root_data = {run.id: {"parent_trace_id": parent_trace_id, "root_trace_id": parea_root_trace_id, "experiment_uuid": _experiment_uuid}}
+                self.client.set_parea_root_and_parent_trace_id(langchain_to_parea_root_data)

-    def _on_llm_end(self, run: Run) -> None:
-        self._persist_run(run)
-
-    def _on_chain_end(self, run: Run) -> None:
-        self._persist_run(run)
+    def get_parent_trace_id(self) -> UUID:
+        return self.parent_trace_id
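
Structurally, the tracer now subclasses `LangChainTracer` and spreads its work across the base tracer's lifecycle hooks: `_on_run_create` and `_on_run_update` mirror each run into the `PareaLangchainClient` as it starts and finishes, while `_persist_run` fires once for the completed trace. A stripped-down sketch of that hook layout, with prints standing in for the client calls; this is illustrative, not the Parea implementation:

```python
from langchain_core.tracers import LangChainTracer
from langchain_core.tracers.schemas import Run


class DebugTracer(LangChainTracer):
    """Illustrative only: logs run lifecycle events instead of shipping them."""

    def _on_run_create(self, run: Run) -> None:
        # Fired when the base tracer starts tracking a run (chain, LLM, tool, ...).
        print(f"created {run.run_type} run {run.id}")

    def _on_run_update(self, run: Run) -> None:
        # Fired when a run finishes and its outputs and timings are filled in.
        print(f"updated run {run.id}")

    def _persist_run(self, run: Run) -> None:
        # Fired for the root run once the whole trace is complete.
        print(f"trace done, root run {run.id}")
```
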
16 changes: 7 additions & 9 deletions parea/utils/trace_integrations/langchain_utils.py
@@ -156,16 +156,12 @@ def _dumps_json_single(obj: Any, default: Optional[Callable[[Any], Any]] = None)

 def _dumps_json(obj: Any, depth: int = 0, serialize_py: bool = True) -> bytes:
     """Serialize an object to a JSON formatted string.
-    Parameters
-    ----------
-    obj : Any
-        The object to serialize.
-    default : Callable[[Any], Any] or None, default=None
-        The default function to use for serialization.
-    Returns
-    -------
-    str
-        The JSON formatted string.
+    Args:
+        obj: The object to serialize.
+        depth: The current depth of the serialization.
+        serialize_py: Whether to serialize Python objects.
+    Returns:
+        The serialized JSON formatted string.
     """
     return _dumps_json_single(obj, functools.partial(_serialize_json, depth=depth, serialize_py=serialize_py))
@@ -198,8 +194,10 @@ def _middle_copy(val: T, memo: Dict[int, Any], max_depth: int = 4, _depth: int =

 def deepish_copy(val: T) -> T:
     """Deep copy a value with a compromise for uncopyable objects.
+
     Args:
         val: The value to be deep copied.
+
     Returns:
         The deep copied value.
     """
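
`deepish_copy` and its helper `_middle_copy` exist because LangChain run trees can contain values that `copy.deepcopy` chokes on; the "compromise" is to copy what is copyable (up to `max_depth`) and fall back to references otherwise. A hedged usage sketch; the exact fallback behavior is assumed from the docstring, not verified here:

```python
import threading

from parea.utils.trace_integrations.langchain_utils import deepish_copy

# copy.deepcopy(payload) would raise TypeError on the lock; deepish_copy compromises.
payload = {"numbers": [1, 2, 3], "lock": threading.Lock()}
snapshot = deepish_copy(payload)
snapshot["numbers"].append(4)  # assumed: shallow-nested containers are copied,
print(payload["numbers"])      # so the original list should still print [1, 2, 3]
```
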