diff --git a/python/.gitignore b/python/.gitignore index 1fcb1529f..e0ab99769 100644 --- a/python/.gitignore +++ b/python/.gitignore @@ -1 +1,2 @@ out +profiles diff --git a/python/Makefile b/python/Makefile index 3f8bc2782..f7ca1f502 100644 --- a/python/Makefile +++ b/python/Makefile @@ -13,6 +13,15 @@ benchmark-fast: rm -f $(OUTPUT) poetry run python -m bench -o $(OUTPUT) --fast +PROFILE_NAME ?= output + +profile-background-thread: + mkdir -p profiles + poetry run python -m cProfile -o profiles/$(PROFILE_NAME).prof bench/create_run.py + +view-profile: + poetry run snakeviz profiles/${PROFILE_NAME}.prof + tests: env \ -u LANGCHAIN_PROJECT \ diff --git a/python/bench/create_run.py b/python/bench/create_run.py new file mode 100644 index 000000000..7114d887b --- /dev/null +++ b/python/bench/create_run.py @@ -0,0 +1,158 @@ +import logging +import statistics +import time +from queue import PriorityQueue +from typing import Dict +from unittest.mock import Mock +from uuid import uuid4 + +from langsmith._internal._background_thread import ( + _tracing_thread_drain_queue, + _tracing_thread_handle_batch, +) +from langsmith.client import Client + + +def create_large_json(length: int) -> Dict: + """Create a large JSON object for benchmarking purposes.""" + large_array = [ + { + "index": i, + "data": f"This is element number {i}", + "nested": {"id": i, "value": f"Nested value for element {i}"}, + } + for i in range(length) + ] + + return { + "name": "Huge JSON", + "description": "This is a very large JSON object for benchmarking purposes.", + "array": large_array, + "metadata": { + "created_at": "2024-10-22T19:00:00Z", + "author": "Python Program", + "version": 1.0, + }, + } + + +def create_run_data(run_id: str, json_size: int) -> Dict: + """Create a single run data object.""" + return { + "name": "Run Name", + "id": run_id, + "run_type": "chain", + "inputs": create_large_json(json_size), + "outputs": create_large_json(json_size), + "extra": {"extra_data": "value"}, + "trace_id": "trace_id", + "dotted_order": "1.1", + "tags": ["tag1", "tag2"], + "session_name": "Session Name", + } + + +def mock_session() -> Mock: + """Create a mock session object.""" + mock_session = Mock() + mock_response = Mock() + mock_response.status_code = 202 + mock_response.text = "Accepted" + mock_response.json.return_value = {"status": "success"} + mock_session.request.return_value = mock_response + return mock_session + + +def create_dummy_data(json_size, num_runs) -> list: + return [create_run_data(str(uuid4()), json_size) for i in range(num_runs)] + + +def create_runs(runs: list, client: Client) -> None: + for run in runs: + client.create_run(**run) + + +def process_queue(client: Client) -> None: + if client.tracing_queue is None: + raise ValueError("Tracing queue is None") + while next_batch := _tracing_thread_drain_queue( + client.tracing_queue, limit=100, block=False + ): + _tracing_thread_handle_batch( + client, client.tracing_queue, next_batch, use_multipart=True + ) + + +def benchmark_run_creation( + *, num_runs: int, json_size: int, samples: int, benchmark_thread: bool +) -> Dict: + """ + Benchmark run creation with specified parameters. + Returns timing statistics. 
+ """ + timings = [] + + if benchmark_thread: + client = Client(session=mock_session(), api_key="xxx", auto_batch_tracing=False) + client.tracing_queue = PriorityQueue() + else: + client = Client(session=mock_session(), api_key="xxx") + + if client.tracing_queue is None: + raise ValueError("Tracing queue is None") + + for _ in range(samples): + runs = create_dummy_data(json_size, num_runs) + + start = time.perf_counter() + + create_runs(runs, client) + + # wait for client.tracing_queue to be empty + if benchmark_thread: + # reset the timer + start = time.perf_counter() + process_queue(client) + else: + client.tracing_queue.join() + + elapsed = time.perf_counter() - start + + del runs + + timings.append(elapsed) + + return { + "mean": statistics.mean(timings), + "median": statistics.median(timings), + "stdev": statistics.stdev(timings) if len(timings) > 1 else 0, + "min": min(timings), + "max": max(timings), + } + + +def test_benchmark_runs( + *, json_size: int, num_runs: int, samples: int, benchmark_thread: bool +): + """ + Run benchmarks with different combinations of parameters and report results. + """ + results = benchmark_run_creation( + num_runs=num_runs, + json_size=json_size, + samples=samples, + benchmark_thread=benchmark_thread, + ) + + print(f"\nBenchmark Results for {num_runs} runs with JSON size {json_size}:") + print(f"Mean time: {results['mean']:.4f} seconds") + print(f"Median time: {results['median']:.4f} seconds") + print(f"Std Dev: {results['stdev']:.4f} seconds") + print(f"Min time: {results['min']:.4f} seconds") + print(f"Max time: {results['max']:.4f} seconds") + print(f"Throughput: {num_runs / results['mean']:.2f} runs/second") + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + test_benchmark_runs(json_size=5000, num_runs=1000, samples=1, benchmark_thread=True) diff --git a/python/langsmith/_internal/_background_thread.py b/python/langsmith/_internal/_background_thread.py index 3a468643f..b6aee1f4e 100644 --- a/python/langsmith/_internal/_background_thread.py +++ b/python/langsmith/_internal/_background_thread.py @@ -1,15 +1,16 @@ from __future__ import annotations +import functools import logging import sys import threading import weakref -from dataclasses import dataclass, field from queue import Empty, Queue from typing import ( TYPE_CHECKING, - Any, List, + Union, + cast, ) from langsmith import schemas as ls_schemas @@ -18,6 +19,11 @@ _AUTO_SCALE_UP_NTHREADS_LIMIT, _AUTO_SCALE_UP_QSIZE_TRIGGER, ) +from langsmith._internal._operations import ( + SerializedFeedbackOperation, + SerializedRunOperation, + combine_serialized_queue_operations, +) if TYPE_CHECKING: from langsmith.client import Client @@ -25,7 +31,7 @@ logger = logging.getLogger("langsmith.client") -@dataclass(order=True) +@functools.total_ordering class TracingQueueItem: """An item in the tracing queue. 
@@ -36,8 +42,29 @@ class TracingQueueItem: """ priority: str - action: str - item: Any = field(compare=False) + item: Union[SerializedRunOperation, SerializedFeedbackOperation] + + __slots__ = ("priority", "item") + + def __init__( + self, + priority: str, + item: Union[SerializedRunOperation, SerializedFeedbackOperation], + ) -> None: + self.priority = priority + self.item = item + + def __lt__(self, other: TracingQueueItem) -> bool: + return (self.priority, self.item.__class__) < ( + other.priority, + other.item.__class__, + ) + + def __eq__(self, other: object) -> bool: + return isinstance(other, TracingQueueItem) and ( + self.priority, + self.item.__class__, + ) == (other.priority, other.item.__class__) def _tracing_thread_drain_queue( @@ -67,16 +94,20 @@ def _tracing_thread_handle_batch( batch: List[TracingQueueItem], use_multipart: bool, ) -> None: - create = [it.item for it in batch if it.action == "create"] - update = [it.item for it in batch if it.action == "update"] - feedback = [it.item for it in batch if it.action == "feedback"] try: + ops = combine_serialized_queue_operations([item.item for item in batch]) if use_multipart: - client.multipart_ingest( - create=create, update=update, feedback=feedback, pre_sampled=True - ) + client._multipart_ingest_ops(ops) else: - client.batch_ingest_runs(create=create, update=update, pre_sampled=True) + if any(isinstance(op, SerializedFeedbackOperation) for op in ops): + logger.warn( + "Feedback operations are not supported in non-multipart mode" + ) + ops = [ + op for op in ops if not isinstance(op, SerializedFeedbackOperation) + ] + client._batch_ingest_run_ops(cast(List[SerializedRunOperation], ops)) + except Exception: logger.error("Error in tracing queue", exc_info=True) # exceptions are logged elsewhere, but we need to make sure the diff --git a/python/langsmith/_internal/_multipart.py b/python/langsmith/_internal/_multipart.py new file mode 100644 index 000000000..ca7c6e656 --- /dev/null +++ b/python/langsmith/_internal/_multipart.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from typing import Dict, Iterable, Tuple + +MultipartPart = Tuple[str, Tuple[None, bytes, str, Dict[str, str]]] + + +class MultipartPartsAndContext: + parts: list[MultipartPart] + context: str + + __slots__ = ("parts", "context") + + def __init__(self, parts: list[MultipartPart], context: str) -> None: + self.parts = parts + self.context = context + + +def join_multipart_parts_and_context( + parts_and_contexts: Iterable[MultipartPartsAndContext], +) -> MultipartPartsAndContext: + acc_parts: list[MultipartPart] = [] + acc_context: list[str] = [] + for parts_and_context in parts_and_contexts: + acc_parts.extend(parts_and_context.parts) + acc_context.append(parts_and_context.context) + return MultipartPartsAndContext(acc_parts, "; ".join(acc_context)) diff --git a/python/langsmith/_internal/_operations.py b/python/langsmith/_internal/_operations.py new file mode 100644 index 000000000..1ba99a6db --- /dev/null +++ b/python/langsmith/_internal/_operations.py @@ -0,0 +1,262 @@ +from __future__ import annotations + +import itertools +import uuid +from typing import Literal, Optional, Union, cast + +import orjson + +from langsmith import schemas as ls_schemas +from langsmith._internal._multipart import MultipartPart, MultipartPartsAndContext +from langsmith._internal._serde import dumps_json as _dumps_json + + +class SerializedRunOperation: + operation: Literal["post", "patch"] + id: uuid.UUID + trace_id: uuid.UUID + + # this is the whole object, minus the 
other fields which + # are popped (inputs/outputs/events/attachments) + _none: bytes + + inputs: Optional[bytes] + outputs: Optional[bytes] + events: Optional[bytes] + attachments: Optional[ls_schemas.Attachments] + + __slots__ = ( + "operation", + "id", + "trace_id", + "_none", + "inputs", + "outputs", + "events", + "attachments", + ) + + def __init__( + self, + operation: Literal["post", "patch"], + id: uuid.UUID, + trace_id: uuid.UUID, + _none: bytes, + inputs: Optional[bytes] = None, + outputs: Optional[bytes] = None, + events: Optional[bytes] = None, + attachments: Optional[ls_schemas.Attachments] = None, + ) -> None: + self.operation = operation + self.id = id + self.trace_id = trace_id + self._none = _none + self.inputs = inputs + self.outputs = outputs + self.events = events + self.attachments = attachments + + def __eq__(self, other: object) -> bool: + return isinstance(other, SerializedRunOperation) and ( + self.operation, + self.id, + self.trace_id, + self._none, + self.inputs, + self.outputs, + self.events, + self.attachments, + ) == ( + other.operation, + other.id, + other.trace_id, + other._none, + other.inputs, + other.outputs, + other.events, + other.attachments, + ) + + +class SerializedFeedbackOperation: + id: uuid.UUID + trace_id: uuid.UUID + feedback: bytes + + __slots__ = ("id", "trace_id", "feedback") + + def __init__(self, id: uuid.UUID, trace_id: uuid.UUID, feedback: bytes) -> None: + self.id = id + self.trace_id = trace_id + self.feedback = feedback + + def __eq__(self, other: object) -> bool: + return isinstance(other, SerializedFeedbackOperation) and ( + self.id, + self.trace_id, + self.feedback, + ) == (other.id, other.trace_id, other.feedback) + + +def serialize_feedback_dict( + feedback: Union[ls_schemas.FeedbackCreate, dict], +) -> SerializedFeedbackOperation: + if hasattr(feedback, "dict") and callable(getattr(feedback, "dict")): + feedback_create: dict = feedback.dict() # type: ignore + else: + feedback_create = cast(dict, feedback) + if "id" not in feedback_create: + feedback_create["id"] = uuid.uuid4() + elif isinstance(feedback_create["id"], str): + feedback_create["id"] = uuid.UUID(feedback_create["id"]) + if "trace_id" not in feedback_create: + feedback_create["trace_id"] = uuid.uuid4() + elif isinstance(feedback_create["trace_id"], str): + feedback_create["trace_id"] = uuid.UUID(feedback_create["trace_id"]) + + return SerializedFeedbackOperation( + id=feedback_create["id"], + trace_id=feedback_create["trace_id"], + feedback=_dumps_json(feedback_create), + ) + + +def serialize_run_dict( + operation: Literal["post", "patch"], payload: dict +) -> SerializedRunOperation: + inputs = payload.pop("inputs", None) + outputs = payload.pop("outputs", None) + events = payload.pop("events", None) + attachments = payload.pop("attachments", None) + return SerializedRunOperation( + operation=operation, + id=payload["id"], + trace_id=payload["trace_id"], + _none=_dumps_json(payload), + inputs=_dumps_json(inputs) if inputs is not None else None, + outputs=_dumps_json(outputs) if outputs is not None else None, + events=_dumps_json(events) if events is not None else None, + attachments=attachments if attachments is not None else None, + ) + + +def combine_serialized_queue_operations( + ops: list[Union[SerializedRunOperation, SerializedFeedbackOperation]], +) -> list[Union[SerializedRunOperation, SerializedFeedbackOperation]]: + create_ops_by_id = { + op.id: op + for op in ops + if isinstance(op, SerializedRunOperation) and op.operation == "post" + } + passthrough_ops: 
list[ + Union[SerializedRunOperation, SerializedFeedbackOperation] + ] = [] + for op in ops: + if isinstance(op, SerializedRunOperation): + if op.operation == "post": + continue + + # must be patch + + create_op = create_ops_by_id.get(op.id) + if create_op is None: + passthrough_ops.append(op) + continue + + if op._none is not None and op._none != create_op._none: + # TODO optimize this more - this would currently be slowest + # for large payloads + create_op_dict = orjson.loads(create_op._none) + op_dict = { + k: v for k, v in orjson.loads(op._none).items() if v is not None + } + create_op_dict.update(op_dict) + create_op._none = orjson.dumps(create_op_dict) + + if op.inputs is not None: + create_op.inputs = op.inputs + if op.outputs is not None: + create_op.outputs = op.outputs + if op.events is not None: + create_op.events = op.events + if op.attachments is not None: + if create_op.attachments is None: + create_op.attachments = {} + create_op.attachments.update(op.attachments) + else: + passthrough_ops.append(op) + return list(itertools.chain(create_ops_by_id.values(), passthrough_ops)) + + +def serialized_feedback_operation_to_multipart_parts_and_context( + op: SerializedFeedbackOperation, +) -> MultipartPartsAndContext: + return MultipartPartsAndContext( + [ + ( + f"feedback.{op.id}", + ( + None, + op.feedback, + "application/json", + {"Content-Length": str(len(op.feedback))}, + ), + ) + ], + f"trace={op.trace_id},id={op.id}", + ) + + +def serialized_run_operation_to_multipart_parts_and_context( + op: SerializedRunOperation, +) -> MultipartPartsAndContext: + acc_parts: list[MultipartPart] = [] + + # this is main object, minus inputs/outputs/events/attachments + acc_parts.append( + ( + f"{op.operation}.{op.id}", + ( + None, + op._none, + "application/json", + {"Content-Length": str(len(op._none))}, + ), + ) + ) + for key, value in ( + ("inputs", op.inputs), + ("outputs", op.outputs), + ("events", op.events), + ): + if value is None: + continue + valb = value + acc_parts.append( + ( + f"{op.operation}.{op.id}.{key}", + ( + None, + valb, + "application/json", + {"Content-Length": str(len(valb))}, + ), + ), + ) + if op.attachments: + for n, (content_type, valb) in op.attachments.items(): + acc_parts.append( + ( + f"attachment.{op.id}.{n}", + ( + None, + valb, + content_type, + {"Content-Length": str(len(valb))}, + ), + ) + ) + return MultipartPartsAndContext( + acc_parts, + f"trace={op.trace_id},id={op.id}", + ) diff --git a/python/langsmith/_internal/_serde.py b/python/langsmith/_internal/_serde.py index 69940bce0..55057920b 100644 --- a/python/langsmith/_internal/_serde.py +++ b/python/langsmith/_internal/_serde.py @@ -12,8 +12,6 @@ import uuid from typing import ( Any, - Callable, - Optional, ) import orjson @@ -124,13 +122,25 @@ def _elide_surrogates(s: bytes) -> bytes: return result -def _dumps_json_single( - obj: Any, default: Optional[Callable[[Any], Any]] = None -) -> bytes: +def dumps_json(obj: Any) -> bytes: + """Serialize an object to a JSON formatted string. + + Parameters + ---------- + obj : Any + The object to serialize. + default : Callable[[Any], Any] or None, default=None + The default function to use for serialization. + + Returns: + ------- + str + The JSON formatted string. 
+ """ try: return orjson.dumps( obj, - default=default or _simple_default, + default=_serialize_json, option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SERIALIZE_DATACLASS | orjson.OPT_SERIALIZE_UUID @@ -151,21 +161,3 @@ def _dumps_json_single( except orjson.JSONDecodeError: result = _elide_surrogates(result) return result - - -def dumps_json(obj: Any, depth: int = 0) -> bytes: - """Serialize an object to a JSON formatted string. - - Parameters - ---------- - obj : Any - The object to serialize. - default : Callable[[Any], Any] or None, default=None - The default function to use for serialization. - - Returns: - ------- - str - The JSON formatted string. - """ - return _dumps_json_single(obj, _serialize_json) diff --git a/python/langsmith/client.py b/python/langsmith/client.py index ec3af9ee4..a343bbfb2 100644 --- a/python/langsmith/client.py +++ b/python/langsmith/client.py @@ -21,6 +21,7 @@ import importlib import importlib.metadata import io +import itertools import json import logging import os @@ -78,6 +79,19 @@ _BLOCKSIZE_BYTES, _SIZE_LIMIT_BYTES, ) +from langsmith._internal._multipart import ( + MultipartPartsAndContext, + join_multipart_parts_and_context, +) +from langsmith._internal._operations import ( + SerializedFeedbackOperation, + SerializedRunOperation, + combine_serialized_queue_operations, + serialize_feedback_dict, + serialize_run_dict, + serialized_feedback_operation_to_multipart_parts_and_context, + serialized_run_operation_to_multipart_parts_and_context, +) from langsmith._internal._serde import dumps_json as _dumps_json try: @@ -101,7 +115,6 @@ class ZoneInfo: # type: ignore[no-redef] WARNED_ATTACHMENTS = False EMPTY_SEQ: tuple[Dict, ...] = () BOUNDARY = uuid.uuid4().hex -MultipartParts = List[Tuple[str, Tuple[None, bytes, str, Dict[str, str]]]] URLLIB3_SUPPORTS_BLOCKSIZE = "key_blocksize" in signature(PoolKey).parameters @@ -1059,7 +1072,6 @@ def _run_transform( run: Union[ls_schemas.Run, dict, ls_schemas.RunLikeDict], update: bool = False, copy: bool = False, - attachments_collector: Optional[Dict[str, ls_schemas.Attachments]] = None, ) -> dict: """Transform the given run object into a dictionary representation. @@ -1067,9 +1079,6 @@ def _run_transform( run (Union[ls_schemas.Run, dict]): The run object to transform. update (bool, optional): Whether the payload is for an "update" event. copy (bool, optional): Whether to deepcopy run inputs/outputs. - attachments_collector (Optional[dict[str, ls_schemas.Attachments]]): - A dictionary to collect attachments. If not passed, attachments - will be dropped. Returns: dict: The transformed run object as a dictionary. @@ -1107,49 +1116,8 @@ def _run_transform( # Drop graph run_create["serialized"].pop("graph", None) - # Collect or drop attachments - if attachments := run_create.pop("attachments", None): - if attachments_collector is not None: - attachments_collector[run_create["id"]] = attachments - elif not WARNED_ATTACHMENTS: - WARNED_ATTACHMENTS = True - logger.warning( - "You're trying to submit a run with attachments, but your current" - " LangSmith integration doesn't support it. Please contact the " - " LangChain team at support at langchain" - " dot dev for assistance on how to upgrade." - ) - return run_create - def _feedback_transform( - self, - feedback: Union[ls_schemas.Feedback, dict], - ) -> dict: - """Transform the given feedback object into a dictionary representation. - - Args: - feedback (Union[ls_schemas.Feedback, dict]): The feedback object to transform. 
- update (bool, optional): Whether the payload is for an "update" event. - copy (bool, optional): Whether to deepcopy feedback inputs/outputs. - attachments_collector (Optional[dict[str, ls_schemas.Attachments]]): - A dictionary to collect attachments. If not passed, attachments - will be dropped. - - Returns: - dict: The transformed feedback object as a dictionary. - """ - if hasattr(feedback, "dict") and callable(getattr(feedback, "dict")): - feedback_create: dict = feedback.dict() # type: ignore - else: - feedback_create = cast(dict, feedback) - if "id" not in feedback_create: - feedback_create["id"] = uuid.uuid4() - elif isinstance(feedback_create["id"], str): - feedback_create["id"] = uuid.UUID(feedback_create["id"]) - - return feedback_create - @staticmethod def _insert_runtime_env(runs: Sequence[dict]) -> None: runtime_env = ls_env.get_runtime_environment() @@ -1240,20 +1208,26 @@ def create_run( } if not self._filter_for_sampling([run_create]): return - run_create = self._run_transform(run_create, copy=True) + if revision_id is not None: run_create["extra"]["metadata"]["revision_id"] = revision_id + run_create = self._run_transform( + run_create, + copy=False, + ) + self._insert_runtime_env([run_create]) if ( self.tracing_queue is not None # batch ingest requires trace_id and dotted_order to be set and run_create.get("trace_id") is not None and run_create.get("dotted_order") is not None ): - return self.tracing_queue.put( - TracingQueueItem(run_create["dotted_order"], "create", run_create) + serialized_op = serialize_run_dict("post", run_create) + self.tracing_queue.put( + TracingQueueItem(run_create["dotted_order"], serialized_op) ) - self._insert_runtime_env([run_create]) - self._create_run(run_create) + else: + self._create_run(run_create) def _create_run(self, run_create: dict): for api_url, api_key in self._write_api_urls.items(): @@ -1288,6 +1262,75 @@ def _hide_run_outputs(self, outputs: dict): return outputs return self._hide_outputs(outputs) + def _batch_ingest_run_ops( + self, + ops: List[SerializedRunOperation], + ) -> None: + ids_and_partial_body: dict[ + Literal["post", "patch"], list[tuple[str, bytes]] + ] = { + "post": [], + "patch": [], + } + + # form the partial body and ids + for op in ops: + if isinstance(op, SerializedRunOperation): + curr_dict = orjson.loads(op._none) + if op.inputs: + curr_dict["inputs"] = orjson.Fragment(op.inputs) + if op.outputs: + curr_dict["outputs"] = orjson.Fragment(op.outputs) + if op.events: + curr_dict["events"] = orjson.Fragment(op.events) + if op.attachments: + logger.warning( + "Attachments are not supported when use_multipart_endpoint " + "is False" + ) + ids_and_partial_body[op.operation].append( + (f"trace={op.trace_id},id={op.id}", orjson.dumps(curr_dict)) + ) + elif isinstance(op, SerializedFeedbackOperation): + logger.warning( + "Feedback operations are not supported in non-multipart mode" + ) + else: + logger.error("Unknown item type in tracing queue: %s", type(op)) + + # send the requests in batches + info = self.info + size_limit_bytes = (info.batch_ingest_config or {}).get( + "size_limit_bytes" + ) or _SIZE_LIMIT_BYTES + + body_chunks: DefaultDict[str, list] = collections.defaultdict(list) + context_ids: DefaultDict[str, list] = collections.defaultdict(list) + body_size = 0 + for key in cast(List[Literal["post", "patch"]], ["post", "patch"]): + body_deque = collections.deque(ids_and_partial_body[key]) + while body_deque: + if ( + body_size > 0 + and body_size + len(body_deque[0][1]) > size_limit_bytes + ): + 
self._post_batch_ingest_runs( + orjson.dumps(body_chunks), + _context=f"\n{key}: {'; '.join(context_ids[key])}", + ) + body_size = 0 + body_chunks.clear() + context_ids.clear() + curr_id, curr_body = body_deque.popleft() + body_size += len(curr_body) + body_chunks[key].append(orjson.Fragment(curr_body)) + context_ids[key].append(curr_id) + if body_size: + context = "; ".join(f"{k}: {'; '.join(v)}" for k, v in context_ids.items()) + self._post_batch_ingest_runs( + orjson.dumps(body_chunks), _context="\n" + context + ) + def batch_ingest_runs( self, create: Optional[ @@ -1325,22 +1368,13 @@ def batch_ingest_runs( if not create and not update: return # transform and convert to dicts - create_dicts = [self._run_transform(run) for run in create or EMPTY_SEQ] + create_dicts = [ + self._run_transform(run, copy=False) for run in create or EMPTY_SEQ + ] update_dicts = [ - self._run_transform(run, update=True) for run in update or EMPTY_SEQ + self._run_transform(run, update=True, copy=False) + for run in update or EMPTY_SEQ ] - # combine post and patch dicts where possible - if update_dicts and create_dicts: - create_by_id = {run["id"]: run for run in create_dicts} - standalone_updates: list[dict] = [] - for run in update_dicts: - if run["id"] in create_by_id: - create_by_id[run["id"]].update( - {k: v for k, v in run.items() if v is not None} - ) - else: - standalone_updates.append(run) - update_dicts = standalone_updates for run in create_dicts: if not run.get("trace_id") or not run.get("dotted_order"): raise ls_utils.LangSmithUserError( @@ -1352,64 +1386,29 @@ def batch_ingest_runs( "Batch ingest requires trace_id and dotted_order to be set." ) # filter out runs that are not sampled - if pre_sampled: - raw_body = { - "post": create_dicts, - "patch": update_dicts, - } - else: - raw_body = { - "post": self._filter_for_sampling(create_dicts), - "patch": self._filter_for_sampling(update_dicts, patch=True), - } - if not raw_body["post"] and not raw_body["patch"]: - return + if not pre_sampled: + create_dicts = self._filter_for_sampling(create_dicts) + update_dicts = self._filter_for_sampling(update_dicts, patch=True) - self._insert_runtime_env(raw_body["post"] + raw_body["patch"]) - info = self.info + if not create_dicts and not update_dicts: + return - size_limit_bytes = (info.batch_ingest_config or {}).get( - "size_limit_bytes" - ) or _SIZE_LIMIT_BYTES - # Get orjson fragments to avoid going over the max request size - partial_body = { - "post": [_dumps_json(run) for run in raw_body["post"]], - "patch": [_dumps_json(run) for run in raw_body["patch"]], - } - ids = { - "post": [ - f"trace={run.get('trace_id')},id={run.get('id')}" - for run in raw_body["post"] - ], - "patch": [ - f"trace={run.get('trace_id')},id={run.get('id')}" - for run in raw_body["patch"] - ], - } + self._insert_runtime_env(create_dicts + update_dicts) - body_chunks: DefaultDict[str, list] = collections.defaultdict(list) - context_ids: DefaultDict[str, list] = collections.defaultdict(list) - body_size = 0 - for key in ["post", "patch"]: - body = collections.deque(partial_body[key]) - ids_ = collections.deque(ids[key]) - while body: - if body_size > 0 and body_size + len(body[0]) > size_limit_bytes: - self._post_batch_ingest_runs( - orjson.dumps(body_chunks), - _context=f"\n{key}: {'; '.join(context_ids[key])}", + # convert to serialized ops + serialized_ops = cast( + List[SerializedRunOperation], + combine_serialized_queue_operations( + list( + itertools.chain( + (serialize_run_dict("post", run) for run in create_dicts), + 
(serialize_run_dict("patch", run) for run in update_dicts), ) - body_size = 0 - body_chunks.clear() - context_ids.clear() - body_size += len(body[0]) - body_chunks[key].append(orjson.Fragment(body.popleft())) - context_ids[key].append(ids_.popleft()) - if body_size: - context = "; ".join(f"{k}: {'; '.join(v)}" for k, v in context_ids.items()) - self._post_batch_ingest_runs( - orjson.dumps(body_chunks), _context="\n" + context - ) + ) + ), + ) + + self._batch_ingest_run_ops(serialized_ops) def _post_batch_ingest_runs(self, body: bytes, *, _context: str): for api_url, api_key in self._write_api_urls.items(): @@ -1436,6 +1435,25 @@ def _post_batch_ingest_runs(self, body: bytes, *, _context: str): except Exception: logger.warning(f"Failed to batch ingest runs: {repr(e)}") + def _multipart_ingest_ops( + self, ops: list[Union[SerializedRunOperation, SerializedFeedbackOperation]] + ) -> None: + parts: list[MultipartPartsAndContext] = [] + for op in ops: + if isinstance(op, SerializedRunOperation): + parts.append( + serialized_run_operation_to_multipart_parts_and_context(op) + ) + elif isinstance(op, SerializedFeedbackOperation): + parts.append( + serialized_feedback_operation_to_multipart_parts_and_context(op) + ) + else: + logger.error("Unknown operation type in tracing queue: %s", type(op)) + acc_multipart = join_multipart_parts_and_context(parts) + if acc_multipart: + self._send_multipart_req(acc_multipart) + def multipart_ingest( self, create: Optional[ @@ -1444,7 +1462,6 @@ def multipart_ingest( update: Optional[ Sequence[Union[ls_schemas.Run, ls_schemas.RunLikeDict, Dict]] ] = None, - feedback: Optional[Sequence[Union[ls_schemas.Feedback, Dict]]] = None, *, pre_sampled: bool = False, ) -> None: @@ -1471,19 +1488,13 @@ def multipart_ingest( - The run objects MUST contain the dotted_order and trace_id fields to be accepted by the API. 
""" - if not (create or update or feedback): + if not (create or update): return # transform and convert to dicts - all_attachments: Dict[str, ls_schemas.Attachments] = {} - create_dicts = [ - self._run_transform(run, attachments_collector=all_attachments) - for run in create or EMPTY_SEQ - ] + create_dicts = [self._run_transform(run) for run in create or EMPTY_SEQ] update_dicts = [ - self._run_transform(run, update=True, attachments_collector=all_attachments) - for run in update or EMPTY_SEQ + self._run_transform(run, update=True) for run in update or EMPTY_SEQ ] - feedback_dicts = [self._feedback_transform(f) for f in feedback or EMPTY_SEQ] # require trace_id and dotted_order if create_dicts: for run in create_dicts: @@ -1521,75 +1532,28 @@ def multipart_ingest( if not pre_sampled: create_dicts = self._filter_for_sampling(create_dicts) update_dicts = self._filter_for_sampling(update_dicts, patch=True) - if not create_dicts and not update_dicts and not feedback_dicts: + if not create_dicts and not update_dicts: return # insert runtime environment self._insert_runtime_env(create_dicts) self._insert_runtime_env(update_dicts) - # send the runs in multipart requests - acc_context: List[str] = [] - acc_parts: MultipartParts = [] - for event, payloads in ( - ("post", create_dicts), - ("patch", update_dicts), - ("feedback", feedback_dicts), - ): - for payload in payloads: - # collect fields to be sent as separate parts - fields = [ - ("inputs", payload.pop("inputs", None)), - ("outputs", payload.pop("outputs", None)), - ("events", payload.pop("events", None)), - ("feedback", payload.pop("feedback", None)), - ] - # encode the main run payload - payloadb = _dumps_json(payload) - acc_parts.append( - ( - f"{event}.{payload['id']}", - ( - None, - payloadb, - "application/json", - {"Content-Length": str(len(payloadb))}, - ), - ) - ) - # encode the fields we collected - for key, value in fields: - if value is None: - continue - valb = _dumps_json(value) - acc_parts.append( - ( - f"{event}.{payload['id']}.{key}", - ( - None, - valb, - "application/json", - {"Content-Length": str(len(valb))}, - ), - ), - ) - # encode the attachments - if attachments := all_attachments.pop(payload["id"], None): - for n, (ct, ba) in attachments.items(): - acc_parts.append( - ( - f"attachment.{payload['id']}.{n}", - (None, ba, ct, {"Content-Length": str(len(ba))}), - ) - ) - # compute context - acc_context.append( - f"trace={payload.get('trace_id')},id={payload.get('id')}" + + # format as serialized operations + serialized_ops = combine_serialized_queue_operations( + list( + itertools.chain( + (serialize_run_dict("post", run) for run in create_dicts), + (serialize_run_dict("patch", run) for run in update_dicts), ) - # send the request - self._send_multipart_req(acc_parts, _context="; ".join(acc_context)) + ) + ) + + # sent the runs in multipart requests + self._multipart_ingest_ops(serialized_ops) - def _send_multipart_req( - self, parts: MultipartParts, *, _context: str, attempts: int = 3 - ): + def _send_multipart_req(self, acc: MultipartPartsAndContext, *, attempts: int = 3): + parts = acc.parts + _context = acc.context for api_url, api_key in self._write_api_urls.items(): for idx in range(1, attempts + 1): try: @@ -1680,6 +1644,12 @@ def update_run( "session_id": kwargs.pop("session_id", None), "session_name": kwargs.pop("session_name", None), } + use_multipart = ( + self.tracing_queue is not None + # batch ingest requires trace_id and dotted_order to be set + and data["trace_id"] is not None + and data["dotted_order"] 
is not None + ) if not self._filter_for_sampling([data], patch=True): return if end_time is not None: @@ -1691,20 +1661,19 @@ def update_run( if inputs is not None: data["inputs"] = self._hide_run_inputs(inputs) if outputs is not None: - outputs = ls_utils.deepish_copy(outputs) + if not use_multipart: + outputs = ls_utils.deepish_copy(outputs) data["outputs"] = self._hide_run_outputs(outputs) if events is not None: data["events"] = events - if ( - self.tracing_queue is not None - # batch ingest requires trace_id and dotted_order to be set - and data["trace_id"] is not None - and data["dotted_order"] is not None - ): - return self.tracing_queue.put( - TracingQueueItem(data["dotted_order"], "update", data) + if use_multipart and self.tracing_queue is not None: + # not collecting attachments currently, use empty dict + serialized_op = serialize_run_dict(operation="patch", payload=data) + self.tracing_queue.put( + TracingQueueItem(data["dotted_order"], serialized_op) ) - return self._update_run(data) + else: + self._update_run(data) def _update_run(self, run_update: dict) -> None: for api_url, api_key in self._write_api_urls.items(): @@ -4334,7 +4303,6 @@ def create_feedback( feedback_group_id=_ensure_uuid(feedback_group_id, accept_null=True), ) - feedback_block = _dumps_json(feedback.dict(exclude_none=True)) use_multipart = (self.info.batch_ingest_config or {}).get( "use_multipart_endpoint", False ) @@ -4346,10 +4314,12 @@ def create_feedback( and self.tracing_queue is not None and feedback.trace_id is not None ): + serialized_op = serialize_feedback_dict(feedback) self.tracing_queue.put( - TracingQueueItem(str(feedback.id), "feedback", feedback) + TracingQueueItem(str(feedback.id), serialized_op) ) else: + feedback_block = _dumps_json(feedback.dict(exclude_none=True)) self.request_with_retries( "POST", "/feedback", diff --git a/python/pyproject.toml b/python/pyproject.toml index 5bf9b11aa..f062863d1 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langsmith" -version = "0.1.138" +version = "0.1.139rc1" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." 
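Returning to the `create_feedback` change a few hunks up: feedback is now pre-serialized into a `SerializedFeedbackOperation` before it is queued, and only the background thread turns it into multipart parts. A rough sketch of that round trip — the IDs and feedback key are hypothetical, and it assumes this branch of `langsmith` is installed:

```python
import uuid

from langsmith._internal._operations import (
    serialize_feedback_dict,
    serialized_feedback_operation_to_multipart_parts_and_context,
)

op = serialize_feedback_dict(
    {
        "id": uuid.uuid4(),
        "trace_id": uuid.uuid4(),
        "run_id": uuid.uuid4(),
        "key": "correctness",
        "score": 1,
    }
)
# The queue now carries pre-serialized bytes, not the pydantic model.
print(type(op).__name__, len(op.feedback))

parts_and_context = serialized_feedback_operation_to_multipart_parts_and_context(op)
part_name, (_, body, content_type, headers) = parts_and_context.parts[0]
print(part_name)                  # feedback.<id>
print(content_type, headers)      # application/json {'Content-Length': ...}
print(parts_and_context.context)  # trace=<trace_id>,id=<id>
```

Non-multipart batch ingestion drops these operations with a warning, consistent with the `_tracing_thread_handle_batch` change above.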
authors = ["LangChain "] license = "MIT" diff --git a/python/tests/integration_tests/test_client.py b/python/tests/integration_tests/test_client.py index d5a39ab42..0cf762859 100644 --- a/python/tests/integration_tests/test_client.py +++ b/python/tests/integration_tests/test_client.py @@ -682,20 +682,7 @@ def test_batch_ingest_runs( }, ] if use_multipart_endpoint: - feedback = [ - { - "run_id": run["id"], - "trace_id": run["trace_id"], - "key": "test_key", - "score": 0.9, - "value": "test_value", - "comment": "test_comment", - } - for run in runs_to_create - ] - langchain_client.multipart_ingest( - create=runs_to_create, update=runs_to_update, feedback=feedback - ) + langchain_client.multipart_ingest(create=runs_to_create, update=runs_to_update) else: langchain_client.batch_ingest_runs(create=runs_to_create, update=runs_to_update) runs = [] @@ -735,34 +722,6 @@ def test_batch_ingest_runs( assert run3.inputs == {"input1": 1, "input2": 2} assert run3.error == "error" - if use_multipart_endpoint: - feedbacks = list( - langchain_client.list_feedback(run_ids=[run.id for run in runs]) - ) - assert len(feedbacks) == 3 - for feedback in feedbacks: - assert feedback.key == "test_key" - assert feedback.score == 0.9 - assert feedback.value == "test_value" - assert feedback.comment == "test_comment" - - -""" -Multipart partitions: -- num created: [0], [1], >1 -- num updated: [0], [1], >1 -- num created + num updated: [0], [1], >1 -- individual id: created only, updated only, both -- [updated is root trace], [updated is run] - -Error cases: -- dual created -- dual updated -- created and dual updated [? maybe not an error] -- dual created and single updated -- retry doesn't fail -""" - def test_multipart_ingest_empty( langchain_client: Client, caplog: pytest.LogCaptureFixture diff --git a/python/tests/unit_tests/test_client.py b/python/tests/unit_tests/test_client.py index dd212373e..9ff1a8eef 100644 --- a/python/tests/unit_tests/test_client.py +++ b/python/tests/unit_tests/test_client.py @@ -288,9 +288,6 @@ def test_create_run_unicode() -> None: def test_create_run_mutate( use_multipart_endpoint: bool, monkeypatch: pytest.MonkeyPatch ) -> None: - if use_multipart_endpoint: - monkeypatch.setenv("LANGSMITH_FF_MULTIPART", "true") - # TODO remove this when removing FF inputs = {"messages": ["hi"], "mygen": (i for i in range(10))} session = mock.Mock() session.request = mock.Mock() @@ -354,7 +351,6 @@ def test_create_run_mutate( parser = MultipartParser(data, boundary) parts.extend(parser.parts()) - assert len(parts) == 3 assert [p.name for p in parts] == [ f"post.{id_}", f"post.{id_}.inputs", @@ -1069,18 +1065,7 @@ def test_batch_ingest_run_splits_large_batches( ] if use_multipart_endpoint: - feedback = [ - { - "run_id": run_id, - "trace_id": run_id, - "key": "test_key", - "score": 0.9, - "value": "test_value", - "comment": "test_comment", - } - for run_id in run_ids - ] - client.multipart_ingest(create=posts, update=patches, feedback=feedback) + client.multipart_ingest(create=posts, update=patches) # multipart endpoint should only send one request expected_num_requests = 1 # count the number of POST requests diff --git a/python/tests/unit_tests/test_operations.py b/python/tests/unit_tests/test_operations.py new file mode 100644 index 000000000..a6b5cdeb3 --- /dev/null +++ b/python/tests/unit_tests/test_operations.py @@ -0,0 +1,112 @@ +import orjson + +from langsmith._internal._operations import ( + SerializedFeedbackOperation, + SerializedRunOperation, + combine_serialized_queue_operations, +) + + +def 
test_combine_serialized_queue_operations(): + # Arrange + serialized_run_operations = [ + SerializedRunOperation( + operation="post", + id="id1", + trace_id="trace_id1", + _none=orjson.dumps({"a": 1}), + inputs="inputs1", + outputs="outputs1", + events="events1", + attachments=None, + ), + SerializedRunOperation( + operation="patch", + id="id1", + trace_id="trace_id1", + _none=orjson.dumps({"b": "2"}), + inputs="inputs1-patched", + outputs="outputs1-patched", + events="events1", + attachments=None, + ), + SerializedFeedbackOperation( + id="id2", + trace_id="trace_id2", + feedback="feedback2", + ), + SerializedRunOperation( + operation="post", + id="id3", + trace_id="trace_id3", + _none="none3", + inputs="inputs3", + outputs="outputs3", + events="events3", + attachments=None, + ), + SerializedRunOperation( + operation="patch", + id="id4", + trace_id="trace_id4", + _none="none4", + inputs="inputs4-patched", + outputs="outputs4-patched", + events="events4", + attachments=None, + ), + SerializedRunOperation( + operation="post", + id="id5", + trace_id="trace_id5", + _none="none5", + inputs="inputs5", + outputs=None, + events="events5", + attachments=None, + ), + SerializedRunOperation( + operation="patch", + id="id5", + trace_id="trace_id5", + _none=None, + inputs=None, + outputs="outputs5-patched", + events=None, + attachments=None, + ), + ] + + # Act + result = combine_serialized_queue_operations(serialized_run_operations) + + # Assert + assert result == [ + # merged 1+2 + SerializedRunOperation( + operation="post", + id="id1", + trace_id="trace_id1", + _none=orjson.dumps({"a": 1, "b": "2"}), + inputs="inputs1-patched", + outputs="outputs1-patched", + events="events1", + attachments=None, + ), + # 4 passthrough + serialized_run_operations[3], + # merged 6+7 + SerializedRunOperation( + operation="post", + id="id5", + trace_id="trace_id5", + _none="none5", + inputs="inputs5", + outputs="outputs5-patched", + events="events5", + attachments=None, + ), + # 3,5 are passthrough in that order + serialized_run_operations[2], + serialized_run_operations[4], + ]
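The new unit test above pins down the merge semantics in isolation; the sketch below (hypothetical run data, assuming this branch of `langsmith` is installed) wires the same pieces together the way `_multipart_ingest_ops` uses them — a post and a patch for one run collapse into a single operation whose multipart parts share one `trace=...,id=...` context:

```python
import uuid

from langsmith._internal._multipart import join_multipart_parts_and_context
from langsmith._internal._operations import (
    combine_serialized_queue_operations,
    serialize_run_dict,
    serialized_run_operation_to_multipart_parts_and_context,
)

run_id = uuid.uuid4()
post = serialize_run_dict(
    "post",
    {
        "id": run_id,
        "trace_id": run_id,
        "dotted_order": "1.1",
        "name": "demo",
        "inputs": {"question": "hi"},
    },
)
patch = serialize_run_dict(
    "patch",
    {"id": run_id, "trace_id": run_id, "dotted_order": "1.1", "outputs": {"answer": "hello"}},
)

# The patch is folded into the post: payload dicts are merged, and the
# pre-serialized inputs/outputs/events bytes are carried over field by field.
ops = combine_serialized_queue_operations([post, patch])
assert len(ops) == 1 and ops[0].operation == "post"
assert ops[0].inputs is not None and ops[0].outputs is not None

acc = join_multipart_parts_and_context(
    serialized_run_operation_to_multipart_parts_and_context(op) for op in ops
)
print([name for name, _ in acc.parts])  # ['post.<id>', 'post.<id>.inputs', 'post.<id>.outputs']
print(acc.context)                      # trace=<id>,id=<id>
```

This accumulated `MultipartPartsAndContext` is what `_send_multipart_req` ultimately receives, which is why the unit test only needs to cover the combine step.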