Skip to content

Commit

Permalink
Send full input to on_chain_start for Runnables
Browse files Browse the repository at this point in the history
  • Loading branch information
hughcrt committed Mar 4, 2024
1 parent b051bba commit 33d7531
Showing 1 changed file with 36 additions and 38 deletions.
74 changes: 36 additions & 38 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,8 +603,7 @@ def astream_log(
exclude_types: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> AsyncIterator[RunLogPatch]:
...
) -> AsyncIterator[RunLogPatch]: ...

@overload
def astream_log(
Expand All @@ -621,8 +620,7 @@ def astream_log(
exclude_types: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> AsyncIterator[RunLog]:
...
) -> AsyncIterator[RunLog]: ...

async def astream_log(
self,
Expand Down Expand Up @@ -1493,9 +1491,23 @@ def _transform_stream_with_config(

config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)

for ichunk in input_for_tracing:
if final_input_supported:
if final_input is None:
final_input = ichunk
else:
try:
final_input = final_input + ichunk # type: ignore
except TypeError:
final_input = ichunk
final_input_supported = False
else:
final_input = ichunk

run_manager = callback_manager.on_chain_start(
dumpd(self),
{"input": ""},
final_input,
run_type=run_type,
name=config.get("run_name") or self.get_name(),
)
Expand Down Expand Up @@ -1525,18 +1537,6 @@ def _transform_stream_with_config(
final_output = chunk
except StopIteration:
pass
for ichunk in input_for_tracing:
if final_input_supported:
if final_input is None:
final_input = ichunk
else:
try:
final_input = final_input + ichunk # type: ignore
except TypeError:
final_input = ichunk
final_input_supported = False
else:
final_input = ichunk
except BaseException as e:
run_manager.on_chain_error(e, inputs=final_input)
raise
Expand Down Expand Up @@ -1580,9 +1580,23 @@ async def _atransform_stream_with_config(

config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)

async for ichunk in input_for_tracing:
if final_input_supported:
if final_input is None:
final_input = ichunk
else:
try:
final_input = final_input + ichunk # type: ignore[operator]
except TypeError:
final_input = ichunk
final_input_supported = False
else:
final_input = ichunk

run_manager = await callback_manager.on_chain_start(
dumpd(self),
{"input": ""},
final_input,
run_type=run_type,
name=config.get("run_name") or self.get_name(),
)
Expand Down Expand Up @@ -1628,18 +1642,6 @@ async def _atransform_stream_with_config(
final_output = chunk
except StopAsyncIteration:
pass
async for ichunk in input_for_tracing:
if final_input_supported:
if final_input is None:
final_input = ichunk
else:
try:
final_input = final_input + ichunk # type: ignore[operator]
except TypeError:
final_input = ichunk
final_input_supported = False
else:
final_input = ichunk
except BaseException as e:
await run_manager.on_chain_error(e, inputs=final_input)
raise
Expand Down Expand Up @@ -4381,29 +4383,25 @@ def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
@overload
def chain(
func: Callable[[Input], Coroutine[Any, Any, Output]],
) -> Runnable[Input, Output]:
...
) -> Runnable[Input, Output]: ...


@overload
def chain(
func: Callable[[Input], Iterator[Output]],
) -> Runnable[Input, Output]:
...
) -> Runnable[Input, Output]: ...


@overload
def chain(
func: Callable[[Input], AsyncIterator[Output]],
) -> Runnable[Input, Output]:
...
) -> Runnable[Input, Output]: ...


@overload
def chain(
func: Callable[[Input], Output],
) -> Runnable[Input, Output]:
...
) -> Runnable[Input, Output]: ...


def chain(
Expand Down

0 comments on commit 33d7531

Please sign in to comment.