
Commit

fmt
baskaryan committed Nov 12, 2024
1 parent 4a24f9f commit b3b841f
Showing 2 changed files with 80 additions and 20 deletions.
53 changes: 41 additions & 12 deletions python/langsmith/evaluation/evaluator.py
@@ -195,6 +195,8 @@ def __init__(
                 arguments, and returns a dict or `ComparisonEvaluationResult`.
         """
         func = _normalize_evaluator_func(func)
+        if afunc:
+            afunc = _normalize_evaluator_func(afunc)  # type: ignore[assignment]
 
         wraps(func)(self)
         from langsmith import run_helpers  # type: ignore
@@ -638,7 +640,10 @@ def comparison_evaluator(

 def _normalize_evaluator_func(
     func: Callable,
-) -> Callable[[Run, Optional[Example]], _RUNNABLE_OUTPUT]:
+) -> Union[
+    Callable[[Run, Optional[Example]], _RUNNABLE_OUTPUT],
+    Callable[[Run, Optional[Example]], Awaitable[_RUNNABLE_OUTPUT]],
+]:
     # for backwards compatibility, if args are untyped we assume they correspond to
     # Run and Example:
     if not (type_hints := get_type_hints(func)):
@@ -660,17 +665,41 @@ def _normalize_evaluator_func(
     if not (
         num_positional in (2, 3) or (num_positional <= 3 and has_positional_var)
     ):
-        msg = ""
+        msg = (
+            "Invalid evaluator function. Expected to take either 2 or 3 positional "
+            "arguments. Please see "
+            "https://docs.smith.langchain.com/evaluation/how_to_guides/evaluation/evaluate_llm_application#use-custom-evaluators"  # noqa: E501
+        )
         raise ValueError(msg)
 
-    def wrapper(run: Run, example: Example) -> _RUNNABLE_OUTPUT:
-        args = (example.inputs, run.outputs or {}, example.outputs or {})
-        if has_positional_var:
-            return func(*args)
-        else:
-            return func(*args[:num_positional])
+    if inspect.iscoroutinefunction(func):
 
-    wrapper.__name__ = (
-        getattr(func, "__name__") if hasattr(func, "__name__") else wrapper.__name__
-    )
-    return wrapper
+        async def awrapper(run: Run, example: Example) -> _RUNNABLE_OUTPUT:
+            args = (example.inputs, run.outputs or {}, example.outputs or {})
+            if has_positional_var:
+                return await func(*args)
+            else:
+                return await func(*args[:num_positional])
+
+        awrapper.__name__ = (
+            getattr(func, "__name__")
+            if hasattr(func, "__name__")
+            else awrapper.__name__
+        )
+        return awrapper  # type: ignore[return-value]
+
+    else:
+
+        def wrapper(run: Run, example: Example) -> _RUNNABLE_OUTPUT:
+            args = (example.inputs, run.outputs or {}, example.outputs or {})
+            if has_positional_var:
+                return func(*args)
+            else:
+                return func(*args[:num_positional])
+
+        wrapper.__name__ = (
+            getattr(func, "__name__")
+            if hasattr(func, "__name__")
+            else wrapper.__name__
+        )
+        return wrapper  # type: ignore[return-value]
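
The net effect of this hunk is that _normalize_evaluator_func now dispatches on inspect.iscoroutinefunction, so custom evaluators written against the unpacked (inputs, outputs, reference_outputs) signature are wrapped correctly whether they are sync or async. Below is a minimal standalone sketch of that dispatch, simplified to the three-argument case and using hypothetical FakeRun/FakeExample stand-ins rather than the real langsmith Run and Example schemas:

import asyncio
import inspect
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class FakeRun:  # hypothetical stand-in for langsmith.schemas.Run
    outputs: Optional[dict] = None


@dataclass
class FakeExample:  # hypothetical stand-in for langsmith.schemas.Example
    inputs: dict = field(default_factory=dict)
    outputs: Optional[dict] = None


def normalize(func):
    # Mirror the diff: branch on iscoroutinefunction so async evaluators are awaited.
    if inspect.iscoroutinefunction(func):
        async def awrapper(run, example):
            args = (example.inputs, run.outputs or {}, example.outputs or {})
            return await func(*args)
        return awrapper
    else:
        def wrapper(run, example):
            args = (example.inputs, run.outputs or {}, example.outputs or {})
            return func(*args)
        return wrapper


def sync_eval(inputs, outputs, reference_outputs):
    return {"score": float(outputs == reference_outputs)}


async def async_eval(inputs, outputs, reference_outputs):
    return {"score": float(outputs == reference_outputs)}


run = FakeRun(outputs={"answer": "42"})
example = FakeExample(inputs={"q": "?"}, outputs={"answer": "42"})
print(normalize(sync_eval)(run, example))                 # {'score': 1.0}
print(asyncio.run(normalize(async_eval)(run, example)))   # {'score': 1.0}

The real implementation additionally handles the 2-argument and *args forms via num_positional and has_positional_var, which this sketch omits.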
47 changes: 39 additions & 8 deletions python/tests/unit_tests/evaluation/test_runner.py
@@ -184,11 +184,26 @@ def score_value_first(run, example):
         ordering_of_stuff.append("evaluate")
         return {"score": 0.3}
 
+    def score_unpacked_inputs_outputs(inputs: dict, outputs: dict):
+        ordering_of_stuff.append("evaluate")
+        return {"score": outputs["output"]}
+
+    def score_unpacked_inputs_outputs_reference(
+        inputs: dict, outputs: dict, reference_outputs: dict
+    ):
+        ordering_of_stuff.append("evaluate")
+        return {"score": reference_outputs["answer"]}
+
+    evaluators = [
+        score_value_first,
+        score_unpacked_inputs_outputs,
+        score_unpacked_inputs_outputs_reference,
+    ]
     results = evaluate(
         predict,
         client=client,
         data=dev_split,
-        evaluators=[score_value_first],
+        evaluators=evaluators,
         num_repetitions=NUM_REPETITIONS,
         blocking=blocking,
     )
@@ -219,14 +234,14 @@ def score_value_first(run, example):
     assert fake_request.created_session
     _wait_until(lambda: fake_request.runs)
     N_PREDS = SPLIT_SIZE * NUM_REPETITIONS
-    _wait_until(lambda: len(ordering_of_stuff) == N_PREDS * 2)
+    _wait_until(lambda: len(ordering_of_stuff) == (N_PREDS * (len(evaluators) + 1)))
     _wait_until(lambda: slow_index is not None)
     # Want it to be interleaved
     assert ordering_of_stuff != ["predict"] * N_PREDS + ["evaluate"] * N_PREDS
     assert ordering_of_stuff[:N_PREDS] != ["predict"] * N_PREDS
 
-    # It's delayed, so it'll be the penultimate event
-    assert slow_index == (N_PREDS * 2) - 2
+    # Will run all other preds and evals, then this, then the last eval
+    assert slow_index == (len(evaluators) + 1) * (N_PREDS - 1)
 
     def score_value(run, example):
         return {"score": 0.7}
@@ -347,11 +362,27 @@ async def score_value_first(run, example):
         ordering_of_stuff.append("evaluate")
         return {"score": 0.3}
 
+    async def score_unpacked_inputs_outputs(inputs: dict, outputs: dict):
+        ordering_of_stuff.append("evaluate")
+        return {"score": outputs["output"]}
+
+    async def score_unpacked_inputs_outputs_reference(
+        inputs: dict, outputs: dict, reference_outputs: dict
+    ):
+        ordering_of_stuff.append("evaluate")
+        return {"score": reference_outputs["answer"]}
+
+    evaluators = [
+        score_value_first,
+        score_unpacked_inputs_outputs,
+        score_unpacked_inputs_outputs_reference,
+    ]
+
     results = await aevaluate(
         predict,
         client=client,
         data=dev_split,
-        evaluators=[score_value_first],
+        evaluators=evaluators,
         num_repetitions=NUM_REPETITIONS,
         blocking=blocking,
     )
@@ -387,14 +418,14 @@ async def score_value_first(run, example):
     assert fake_request.created_session
     _wait_until(lambda: fake_request.runs)
     N_PREDS = SPLIT_SIZE * NUM_REPETITIONS
-    _wait_until(lambda: len(ordering_of_stuff) == N_PREDS * 2)
+    _wait_until(lambda: len(ordering_of_stuff) == N_PREDS * (len(evaluators) + 1))
     _wait_until(lambda: slow_index is not None)
     # Want it to be interleaved
     assert ordering_of_stuff != ["predict"] * N_PREDS + ["evaluate"] * N_PREDS
     assert ordering_of_stuff[:N_PREDS] != ["predict"] * N_PREDS
     assert slow_index is not None
-    # It's delayed, so it'll be the penultimate event
-    assert slow_index == (N_PREDS * 2) - 2
+    # Will run all other preds and evals, then this, then the last eval
+    assert slow_index == (N_PREDS - 1) * (len(evaluators) + 1)
 
     assert fake_request.created_session["name"]
 
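To make the updated expectations concrete: each prediction now produces one "predict" entry plus one "evaluate" entry per evaluator, so the total event count becomes N_PREDS * (len(evaluators) + 1), and the deliberately delayed call is expected at index (N_PREDS - 1) * (len(evaluators) + 1), i.e. after every other prediction and its evaluations. A quick arithmetic sketch, using hypothetical values for SPLIT_SIZE and NUM_REPETITIONS (the real constants are defined elsewhere in test_runner.py):

# Hypothetical values; SPLIT_SIZE and NUM_REPETITIONS live elsewhere in test_runner.py.
SPLIT_SIZE = 3
NUM_REPETITIONS = 4
N_EVALUATORS = 3  # score_value_first + the two unpacked-signature evaluators

N_PREDS = SPLIT_SIZE * NUM_REPETITIONS
total_events = N_PREDS * (N_EVALUATORS + 1)       # one predict + N_EVALUATORS evals per prediction
slow_index = (N_PREDS - 1) * (N_EVALUATORS + 1)   # events before the delayed prediction, if all others finish first
print(N_PREDS, total_events, slow_index)          # 12 48 44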
