
Commit 9aa4df7

fix(python): Properly respect max_concurrency for aevaluate (#1613)
Co-authored-by: Bagatur <[email protected]>
Parent commit: 0999d18

File tree

2 files changed: +111 additions, -23 deletions
  python/langsmith/evaluation/_arunner.py
  python/tests/evaluation/test_evaluation.py


python/langsmith/evaluation/_arunner.py

Lines changed: 35 additions & 23 deletions
@@ -778,31 +778,43 @@ async def awith_predictions_and_evaluators(
         """
         evaluators = _resolve_evaluators(evaluators)

-        if not hasattr(self, "_evaluator_executor"):
-            self._evaluator_executor = cf.ThreadPoolExecutor(max_workers=4)
+        if not hasattr(self, "_evaluation_feedback_executor"):
+            self._evaluation_feedback_executor = cf.ThreadPoolExecutor(max_workers=4)
+
+        traceable_target = _ensure_async_traceable(target)
+
+        async def process_example(example: schemas.Example):
+            # Yield the coroutine to be awaited later
+            pred = await _aforward(
+                traceable_target,
+                self._get_example_with_readers(example),
+                self.experiment_name,
+                self._metadata,
+                self.client,
+                _include_attachments(target),
+            )
+            example, run = pred["example"], pred["run"]
+            result = await self._arun_evaluators(
+                evaluators,
+                {
+                    "run": run,
+                    "example": example,
+                    "evaluation_results": {"results": []},
+                },
+                feedback_executor=self._evaluation_feedback_executor,
+            )
+            return result

         async def process_examples():
             """Create a single task per example.

             That task is to run the target function and all the evaluators
             sequentially.
             """
-            async for pred in self._apredict(
-                target,
-                max_concurrency=max_concurrency,
-                include_attachments=_include_attachments(target),
-            ):
-                example, run = pred["example"], pred["run"]
-                result = self._arun_evaluators(
-                    evaluators,
-                    {
-                        "run": run,
-                        "example": example,
-                        "evaluation_results": {"results": []},
-                    },
-                    executor=self._evaluator_executor,
-                )
-                yield result
+            async for example in await self.aget_examples():
+                yield process_example(example)
+
+            await self._aend()

         # Run the per-example tasks with max-concurrency
         # This guarantees that max_concurrency is the upper limit
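The core of the fix is visible in this hunk: each example is wrapped in a single coroutine (process_example) that runs the target and then its evaluators back to back, and process_examples only yields those coroutines; they are awaited later by the concurrency-bounded runner referenced in the trailing comment, so max_concurrency now caps target calls and evaluations together rather than predictions alone. Below is a minimal, self-contained sketch of that pattern using an asyncio.Semaphore; it is an illustration of the idea only, not the langsmith aitertools implementation, and every name in it is a stand-in.

# Minimal sketch (illustration only): bound concurrency over yielded coroutines
# with a semaphore. Not the langsmith aitertools implementation.
import asyncio
from typing import AsyncIterator, Awaitable, TypeVar

T = TypeVar("T")


async def bounded_aiter(
    n: int, coros: AsyncIterator[Awaitable[T]]
) -> AsyncIterator[T]:
    semaphore = asyncio.Semaphore(n)

    async def run(coro: Awaitable[T]) -> T:
        async with semaphore:  # at most n coroutines are awaited at once
            return await coro

    tasks = [asyncio.create_task(run(c)) async for c in coros]
    for task in tasks:  # yield results in input order
        yield await task


async def main() -> None:
    async def process_example(i: int) -> int:
        await asyncio.sleep(0.1)  # stands in for the target call plus evaluators
        return i

    async def process_examples() -> AsyncIterator[Awaitable[int]]:
        for i in range(10):
            yield process_example(i)  # yield the coroutine, don't await it yet

    async for result in bounded_aiter(3, process_examples()):
        print(result)


asyncio.run(main())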
@@ -944,13 +956,13 @@ async def _ascore(
         evaluators: Sequence[RunEvaluator],
         max_concurrency: Optional[int] = None,
     ) -> AsyncIterator[ExperimentResultRow]:
-        with cf.ThreadPoolExecutor(max_workers=4) as executor:
+        with cf.ThreadPoolExecutor(max_workers=4) as feedback_executor:

             async def score_all():
                 async for current_results in self.aget_results():
                     # Yield the coroutine to be awaited later in aiter_with_concurrency
                     yield self._arun_evaluators(
-                        evaluators, current_results, executor=executor
+                        evaluators, current_results, feedback_executor=feedback_executor
                     )

             async for result in aitertools.aiter_with_concurrency(
@@ -962,7 +974,7 @@ async def _arun_evaluators(
         self,
         evaluators: Sequence[RunEvaluator],
         current_results: ExperimentResultRow,
-        executor: cf.ThreadPoolExecutor,
+        feedback_executor: cf.ThreadPoolExecutor,
     ) -> ExperimentResultRow:
         current_context = rh.get_tracing_context()
         metadata = {
@@ -996,7 +1008,7 @@ async def _run_single_evaluator(evaluator: RunEvaluator):

                 if self._upload_results:
                     self.client._log_evaluation_feedback(
-                        evaluator_response, run=run, _executor=executor
+                        evaluator_response, run=run, _executor=feedback_executor
                     )
                 return selected_results
             except Exception as e:
@@ -1019,7 +1031,7 @@ async def _run_single_evaluator(evaluator: RunEvaluator):
                     )
                     if self._upload_results:
                         self.client._log_evaluation_feedback(
-                            error_response, run=run, _executor=feedback_executor
+                            error_response, run=run, _executor=feedback_executor
                         )
                     return selected_results
                 except Exception as e2:
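The remaining hunks are a rename (executor becomes feedback_executor), making it explicit that this thread pool exists only to log evaluation feedback (see the _log_evaluation_feedback calls) without blocking the event loop. As a generic illustration of that pattern, not langsmith's _log_evaluation_feedback, and with a hypothetical blocking post_feedback call standing in for the real upload:

# Sketch (illustration only): offload blocking feedback uploads to a thread pool
# so the asyncio event loop can keep scheduling other evaluations.
import asyncio
import concurrent.futures as cf
import time


def post_feedback(score: float) -> None:
    # Hypothetical stand-in for a blocking HTTP request that uploads feedback.
    time.sleep(0.2)


async def log_feedback(executor: cf.ThreadPoolExecutor, score: float) -> None:
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(executor, post_feedback, score)


async def main() -> None:
    with cf.ThreadPoolExecutor(max_workers=4) as feedback_executor:
        # Four uploads run in the pool concurrently; the event loop is never blocked.
        await asyncio.gather(
            *(log_feedback(feedback_executor, s) for s in (0.1, 0.5, 0.9, 1.0))
        )


asyncio.run(main())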

python/tests/evaluation/test_evaluation.py

Lines changed: 76 additions & 0 deletions
@@ -472,3 +472,79 @@ async def predict(inputs: dict):
         predict,
         data=(_ for _ in range(0)),
     )
+
+
+async def test_aevaluate_large_dataset_and_concurrency():
+    client = Client()
+    _ = client.clone_public_dataset(
+        "https://smith.langchain.com/public/2bbf4a10-c3d5-4868-9e96-400df97fed69/d"
+    )
+    dataset_name = "Evaluate Examples"
+
+    async def mock_chat_completion(*, model, messages):
+        await asyncio.sleep(1)
+        return {
+            "role": "assistant",
+            "content": "Still thinking...",
+        }
+
+    def simulate_conversation_turn(*, existing, model_response):
+        return existing + [
+            model_response,
+            {"role": "human", "content": "Think harder!"},
+        ]
+
+    # Will be traced by default
+    async def target(inputs: dict) -> dict:
+        messages = [
+            {
+                "role": "system",
+                "content": "Come up with a math equation that solves the puzzle.",
+            },
+            # This dataset has inputs as a dict with a "statement" key
+            {"role": "user", "content": "foo"},
+        ]
+        res = await mock_chat_completion(model="gpt-4o-mini", messages=messages)
+        messages = simulate_conversation_turn(existing=messages, model_response=res)
+
+        return {"equation": res}
+
+    async def mock_evaluator_chat_completion(*, model, messages):
+        await asyncio.sleep(2)
+        return {
+            "role": "assistant",
+            "content": str(0.5),
+        }
+
+    async def mock_correctness_evaluator(outputs: dict, reference_outputs: dict):
+        messages = [
+            {"role": "system", "content": "Assign a score to the following output."},
+            {
+                "role": "user",
+                "content": f"""
+                Actual: {outputs["equation"]}
+                """,
+            },
+        ]
+        res = await mock_evaluator_chat_completion(model="o3-mini", messages=messages)
+        return {
+            "key": "correctness",
+            "score": float(res["content"]),
+            "comment": "The answer was a good attempt, but incorrect.",
+        }
+
+    client = Client()
+
+    start = time.time()
+
+    await client.aevaluate(
+        target,
+        data=client.list_examples(dataset_name=dataset_name, as_of="test_version"),
+        evaluators=[
+            mock_correctness_evaluator,
+        ],
+        max_concurrency=3,
+    )
+
+    finish_time = time.time()
+    assert (finish_time - start) <= 8.5
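For intuition on the 8.5 s bound: the mocked target sleeps 1 s and the mocked evaluator sleeps 2 s, so each example costs roughly 3 s end to end. Assuming, purely for illustration, six examples in the cloned dataset, max_concurrency=3 processes them in two waves of about 3 s each (roughly 6 s total), whereas a fully serial run of the same examples would take several times longer; the assertion therefore fails unless target calls and evaluations for different examples actually overlap up to the configured limit.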
