@@ -778,31 +778,43 @@ async def awith_predictions_and_evaluators(
778
778
"""
779
779
evaluators = _resolve_evaluators (evaluators )
780
780
781
- if not hasattr (self , "_evaluator_executor" ):
782
- self ._evaluator_executor = cf .ThreadPoolExecutor (max_workers = 4 )
781
+ if not hasattr (self , "_evaluation_feedback_executor" ):
782
+ self ._evaluation_feedback_executor = cf .ThreadPoolExecutor (max_workers = 4 )
783
+
784
+ traceable_target = _ensure_async_traceable (target )
785
+
786
+ async def process_example (example : schemas .Example ):
787
+ # Yield the coroutine to be awaited later
788
+ pred = await _aforward (
789
+ traceable_target ,
790
+ self ._get_example_with_readers (example ),
791
+ self .experiment_name ,
792
+ self ._metadata ,
793
+ self .client ,
794
+ _include_attachments (target ),
795
+ )
796
+ example , run = pred ["example" ], pred ["run" ]
797
+ result = await self ._arun_evaluators (
798
+ evaluators ,
799
+ {
800
+ "run" : run ,
801
+ "example" : example ,
802
+ "evaluation_results" : {"results" : []},
803
+ },
804
+ feedback_executor = self ._evaluation_feedback_executor ,
805
+ )
806
+ return result
783
807
784
808
async def process_examples ():
785
809
"""Create a single task per example.
786
810
787
811
That task is to run the target function and all the evaluators
788
812
sequentially.
789
813
"""
790
- async for pred in self ._apredict (
791
- target ,
792
- max_concurrency = max_concurrency ,
793
- include_attachments = _include_attachments (target ),
794
- ):
795
- example , run = pred ["example" ], pred ["run" ]
796
- result = self ._arun_evaluators (
797
- evaluators ,
798
- {
799
- "run" : run ,
800
- "example" : example ,
801
- "evaluation_results" : {"results" : []},
802
- },
803
- executor = self ._evaluator_executor ,
804
- )
805
- yield result
814
+ async for example in await self .aget_examples ():
815
+ yield process_example (example )
816
+
817
+ await self ._aend ()
806
818
807
819
# Run the per-example tasks with max-concurrency
808
820
# This guarantees that max_concurrency is the upper limit
@@ -944,13 +956,13 @@ async def _ascore(
944
956
evaluators : Sequence [RunEvaluator ],
945
957
max_concurrency : Optional [int ] = None ,
946
958
) -> AsyncIterator [ExperimentResultRow ]:
947
- with cf .ThreadPoolExecutor (max_workers = 4 ) as executor :
959
+ with cf .ThreadPoolExecutor (max_workers = 4 ) as feedback_executor :
948
960
949
961
async def score_all ():
950
962
async for current_results in self .aget_results ():
951
963
# Yield the coroutine to be awaited later in aiter_with_concurrency
952
964
yield self ._arun_evaluators (
953
- evaluators , current_results , executor = executor
965
+ evaluators , current_results , feedback_executor = feedback_executor
954
966
)
955
967
956
968
async for result in aitertools .aiter_with_concurrency (
@@ -962,7 +974,7 @@ async def _arun_evaluators(
962
974
self ,
963
975
evaluators : Sequence [RunEvaluator ],
964
976
current_results : ExperimentResultRow ,
965
- executor : cf .ThreadPoolExecutor ,
977
+ feedback_executor : cf .ThreadPoolExecutor ,
966
978
) -> ExperimentResultRow :
967
979
current_context = rh .get_tracing_context ()
968
980
metadata = {
@@ -996,7 +1008,7 @@ async def _run_single_evaluator(evaluator: RunEvaluator):
996
1008
997
1009
if self ._upload_results :
998
1010
self .client ._log_evaluation_feedback (
999
- evaluator_response , run = run , _executor = executor
1011
+ evaluator_response , run = run , _executor = feedback_executor
1000
1012
)
1001
1013
return selected_results
1002
1014
except Exception as e :
@@ -1019,7 +1031,7 @@ async def _run_single_evaluator(evaluator: RunEvaluator):
1019
1031
)
1020
1032
if self ._upload_results :
1021
1033
self .client ._log_evaluation_feedback (
1022
- error_response , run = run , _executor = executor
1034
+ error_response , run = run , _executor = feedback_executor
1023
1035
)
1024
1036
return selected_results
1025
1037
except Exception as e2 :
0 commit comments