Skip to content

Commit 132825d

Browse files
committed
InferenceService refactoring (now InferenceRunner)
1 parent ce38aa7 commit 132825d

File tree

12 files changed

+226
-82
lines changed

12 files changed

+226
-82
lines changed

Diff for: x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java

+6-6
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
7070
import org.elasticsearch.xpack.esql.index.EsIndex;
7171
import org.elasticsearch.xpack.esql.inference.InferenceResolution;
72-
import org.elasticsearch.xpack.esql.inference.InferenceService;
72+
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
7373
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
7474
import org.elasticsearch.xpack.esql.parser.QueryParam;
7575
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
@@ -381,18 +381,18 @@ public static LogicalOptimizerContext unboundLogicalOptimizerContext() {
381381
mock(ClusterService.class),
382382
mock(IndexNameExpressionResolver.class),
383383
null,
384-
mockInferenceService()
384+
mockInferenceRunner()
385385
);
386386

387387
@SuppressWarnings("unchecked")
388-
private static InferenceService mockInferenceService() {
389-
InferenceService inferenceService = mock(InferenceService.class);
388+
private static InferenceRunner mockInferenceRunner() {
389+
InferenceRunner inferenceRunner = mock(InferenceRunner.class);
390390
doAnswer(i -> {
391391
i.getArgument(1, ActionListener.class).onResponse(emptyInferenceResolution());
392392
return null;
393-
}).when(inferenceService).resolveInferences(any(), any());
393+
}).when(inferenceRunner).resolveInferenceIds(any(), any());
394394

395-
return inferenceService;
395+
return inferenceRunner;
396396
}
397397

398398
private EsqlTestUtils() {}

Diff for: x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolution.java

+4
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ public Collection<ResolvedInference> resolvedInferences() {
3737
return resolvedInferences.values();
3838
}
3939

40+
public boolean hasError() {
41+
return errors.isEmpty() == false;
42+
}
43+
4044
public String getError(String inferenceId) {
4145
final String error = errors.get(inferenceId);
4246
if (error != null) {

Diff for: x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java renamed to x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java

+20-19
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.esql.inference;
99

1010
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.action.support.CountDownActionListener;
1112
import org.elasticsearch.client.internal.Client;
1213
import org.elasticsearch.common.util.concurrent.ThreadContext;
1314
import org.elasticsearch.inference.TaskType;
@@ -18,34 +19,38 @@
1819

1920
import java.util.List;
2021
import java.util.Set;
21-
import java.util.concurrent.CountDownLatch;
2222
import java.util.stream.Collectors;
2323

24-
public class InferenceService {
24+
public class InferenceRunner {
2525

2626
private final Client client;
2727

28-
public InferenceService(Client client) {
28+
public InferenceRunner(Client client) {
2929
this.client = client;
3030
}
3131

3232
public ThreadContext getThreadContext() {
3333
return client.threadPool().getThreadContext();
3434
}
3535

36-
public void resolveInferences(List<InferencePlan> plans, ActionListener<InferenceResolution> listener) {
36+
public void resolveInferenceIds(List<InferencePlan> plans, ActionListener<InferenceResolution> listener) {
37+
resolveInferenceIds(plans.stream().map(InferenceRunner::planInferenceId).collect(Collectors.toSet()), listener);
3738

38-
if (plans.isEmpty()) {
39+
}
40+
41+
private void resolveInferenceIds(Set<String> inferenceIds, ActionListener<InferenceResolution> listener) {
42+
43+
if (inferenceIds.isEmpty()) {
3944
listener.onResponse(InferenceResolution.EMPTY);
4045
return;
4146
}
4247

43-
Set<String> inferenceIds = plans.stream()
44-
.map(p -> p.inferenceId().fold(FoldContext.small()).toString())
45-
.collect(Collectors.toSet());
48+
final InferenceResolution.Builder inferenceResolutionBuilder = InferenceResolution.builder();
4649

47-
CountDownLatch countDownLatch = new CountDownLatch(inferenceIds.size());
48-
InferenceResolution.Builder inferenceResolutionBuilder = InferenceResolution.builder();
50+
final CountDownActionListener countdownListener = new CountDownActionListener(
51+
inferenceIds.size(),
52+
ActionListener.wrap(_r -> listener.onResponse(inferenceResolutionBuilder.build()), listener::onFailure)
53+
);
4954

5055
for (var inferenceId : inferenceIds) {
5156
client.execute(
@@ -54,21 +59,17 @@ public void resolveInferences(List<InferencePlan> plans, ActionListener<Inferenc
5459
ActionListener.wrap(r -> {
5560
ResolvedInference resolvedInference = new ResolvedInference(inferenceId, r.getEndpoints().getFirst().getTaskType());
5661
inferenceResolutionBuilder.withResolvedInference(resolvedInference);
57-
countDownLatch.countDown();
62+
countdownListener.onResponse(null);
5863
}, e -> {
5964
inferenceResolutionBuilder.withError(inferenceId, e.getMessage());
60-
countDownLatch.countDown();
65+
countdownListener.onResponse(null);
6166
})
6267
);
6368
}
69+
}
6470

65-
try {
66-
countDownLatch.await();
67-
} catch (InterruptedException e) {
68-
throw new RuntimeException(e);
69-
}
70-
71-
listener.onResponse(inferenceResolutionBuilder.build());
71+
private static String planInferenceId(InferencePlan plan) {
72+
return plan.inferenceId().fold(FoldContext.small()).toString();
7273
}
7374

7475
public void doInference(InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {

Diff for: x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/RerankOperator.java

+8-8
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public class RerankOperator extends AsyncOperator<Page> {
3232
private static final int MAX_INFERENCE_WORKER = 10;
3333

3434
public record Factory(
35-
InferenceService inferenceService,
35+
InferenceRunner inferenceRunner,
3636
String inferenceId,
3737
String queryText,
3838
ExpressionEvaluator.Factory rowEncoderFactory,
@@ -48,7 +48,7 @@ public String describe() {
4848
public Operator get(DriverContext driverContext) {
4949
return new RerankOperator(
5050
driverContext,
51-
inferenceService,
51+
inferenceRunner,
5252
inferenceId,
5353
queryText,
5454
rowEncoderFactory().get(driverContext),
@@ -57,7 +57,7 @@ public Operator get(DriverContext driverContext) {
5757
}
5858
}
5959

60-
private final InferenceService inferenceService;
60+
private final InferenceRunner inferenceRunner;
6161
private final BlockFactory blockFactory;
6262
private final String inferenceId;
6363
private final String queryText;
@@ -66,18 +66,18 @@ public Operator get(DriverContext driverContext) {
6666

6767
public RerankOperator(
6868
DriverContext driverContext,
69-
InferenceService inferenceService,
69+
InferenceRunner inferenceRunner,
7070
String inferenceId,
7171
String queryText,
7272
ExpressionEvaluator rowEncoder,
7373
int scoreChannel
7474
) {
75-
super(driverContext, inferenceService.getThreadContext(), MAX_INFERENCE_WORKER);
75+
super(driverContext, inferenceRunner.getThreadContext(), MAX_INFERENCE_WORKER);
7676

77-
assert inferenceService.getThreadContext() != null;
77+
assert inferenceRunner.getThreadContext() != null;
7878

7979
this.blockFactory = driverContext.blockFactory();
80-
this.inferenceService = inferenceService;
80+
this.inferenceRunner = inferenceRunner;
8181
this.inferenceId = inferenceId;
8282
this.queryText = queryText;
8383
this.rowEncoder = rowEncoder;
@@ -90,7 +90,7 @@ protected void performAsync(Page inputPage, ActionListener<Page> listener) {
9090
final ActionListener<Page> outputListener = ActionListener.runAfter(listener, () -> { releasePageOnAnyThread(inputPage); });
9191

9292
try {
93-
inferenceService.doInference(
93+
inferenceRunner.doInference(
9494
buildInferenceRequest(inputPage),
9595
ActionListener.wrap(
9696
inferenceResponse -> outputListener.onResponse(buildOutput(inputPage, inferenceResponse)),

Diff for: x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java

+5-5
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@
8181
import org.elasticsearch.xpack.esql.evaluator.EvalMapper;
8282
import org.elasticsearch.xpack.esql.evaluator.command.GrokEvaluatorExtracter;
8383
import org.elasticsearch.xpack.esql.expression.Order;
84-
import org.elasticsearch.xpack.esql.inference.InferenceService;
84+
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
8585
import org.elasticsearch.xpack.esql.inference.RerankOperator;
8686
import org.elasticsearch.xpack.esql.inference.XContentRowEncoder;
8787
import org.elasticsearch.xpack.esql.plan.logical.Fork;
@@ -152,7 +152,7 @@ public class LocalExecutionPlanner {
152152
private final Supplier<ExchangeSink> exchangeSinkSupplier;
153153
private final EnrichLookupService enrichLookupService;
154154
private final LookupFromIndexService lookupFromIndexService;
155-
private final InferenceService inferenceService;
155+
private final InferenceRunner inferenceRunner;
156156
private final PhysicalOperationProviders physicalOperationProviders;
157157
private final List<ShardContext> shardContexts;
158158

@@ -168,7 +168,7 @@ public LocalExecutionPlanner(
168168
Supplier<ExchangeSink> exchangeSinkSupplier,
169169
EnrichLookupService enrichLookupService,
170170
LookupFromIndexService lookupFromIndexService,
171-
InferenceService inferenceService,
171+
InferenceRunner inferenceRunner,
172172
PhysicalOperationProviders physicalOperationProviders,
173173
List<ShardContext> shardContexts
174174
) {
@@ -184,7 +184,7 @@ public LocalExecutionPlanner(
184184
this.exchangeSinkSupplier = exchangeSinkSupplier;
185185
this.enrichLookupService = enrichLookupService;
186186
this.lookupFromIndexService = lookupFromIndexService;
187-
this.inferenceService = inferenceService;
187+
this.inferenceRunner = inferenceRunner;
188188
this.physicalOperationProviders = physicalOperationProviders;
189189
this.shardContexts = shardContexts;
190190
}
@@ -581,7 +581,7 @@ private PhysicalOperation planRerank(RerankExec rerank, LocalExecutionPlannerCon
581581
int scoreChannel = outputLayout.get(rerank.scoreAttribute().id()).channel();
582582

583583
return source.with(
584-
new RerankOperator.Factory(inferenceService, inferenceId, queryText, rowEncoderFactory, scoreChannel),
584+
new RerankOperator.Factory(inferenceRunner, inferenceId, queryText, rowEncoderFactory, scoreChannel),
585585
outputLayout
586586
);
587587
}

Diff for: x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java

+9-13
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
4747
import org.elasticsearch.xpack.esql.enrich.EnrichLookupService;
4848
import org.elasticsearch.xpack.esql.enrich.LookupFromIndexService;
49-
import org.elasticsearch.xpack.esql.inference.InferenceService;
49+
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
5050
import org.elasticsearch.xpack.esql.plan.physical.ExchangeSinkExec;
5151
import org.elasticsearch.xpack.esql.plan.physical.ExchangeSourceExec;
5252
import org.elasticsearch.xpack.esql.plan.physical.OutputExec;
@@ -123,7 +123,7 @@ public class ComputeService {
123123
private final DriverTaskRunner driverRunner;
124124
private final EnrichLookupService enrichLookupService;
125125
private final LookupFromIndexService lookupFromIndexService;
126-
private final InferenceService inferenceService;
126+
private final InferenceRunner inferenceRunner;
127127
private final ClusterService clusterService;
128128
private final AtomicLong childSessionIdGenerator = new AtomicLong();
129129
private final DataNodeComputeHandler dataNodeComputeHandler;
@@ -132,27 +132,24 @@ public class ComputeService {
132132

133133
@SuppressWarnings("this-escape")
134134
public ComputeService(
135-
SearchService searchService,
136-
TransportService transportService,
137-
ExchangeService exchangeService,
135+
TransportActionServices transportActionServices,
138136
EnrichLookupService enrichLookupService,
139137
LookupFromIndexService lookupFromIndexService,
140-
InferenceService inferenceService,
141-
ClusterService clusterService,
142138
ThreadPool threadPool,
143139
BigArrays bigArrays,
144140
BlockFactory blockFactory
145141
) {
146-
this.searchService = searchService;
147-
this.transportService = transportService;
142+
this.searchService = transportActionServices.searchService();
143+
this.transportService = transportActionServices.transportService();
144+
this.exchangeService = transportActionServices.exchangeService();
148145
this.bigArrays = bigArrays.withCircuitBreaking();
149146
this.blockFactory = blockFactory;
150147
var esqlExecutor = threadPool.executor(ThreadPool.Names.SEARCH);
151148
this.driverRunner = new DriverTaskRunner(transportService, esqlExecutor);
152149
this.enrichLookupService = enrichLookupService;
153150
this.lookupFromIndexService = lookupFromIndexService;
154-
this.inferenceService = inferenceService;
155-
this.clusterService = clusterService;
151+
this.inferenceRunner = transportActionServices.inferenceRunner();
152+
this.clusterService = transportActionServices.clusterService();
156153
this.dataNodeComputeHandler = new DataNodeComputeHandler(this, searchService, transportService, exchangeService, esqlExecutor);
157154
this.clusterComputeHandler = new ClusterComputeHandler(
158155
this,
@@ -161,7 +158,6 @@ public ComputeService(
161158
esqlExecutor,
162159
dataNodeComputeHandler
163160
);
164-
this.exchangeService = exchangeService;
165161
}
166162

167163
public void execute(
@@ -414,7 +410,7 @@ public SourceProvider createSourceProvider() {
414410
context.exchangeSinkSupplier(),
415411
enrichLookupService,
416412
lookupFromIndexService,
417-
inferenceService,
413+
inferenceRunner,
418414
new EsPhysicalOperationProviders(context.foldCtx(), contexts, searchService.getIndicesService().getAnalysis()),
419415
contexts
420416
);

Diff for: x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportActionServices.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import org.elasticsearch.search.SearchService;
1414
import org.elasticsearch.transport.TransportService;
1515
import org.elasticsearch.usage.UsageService;
16-
import org.elasticsearch.xpack.esql.inference.InferenceService;
16+
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
1717

1818
public record TransportActionServices(
1919
TransportService transportService,
@@ -22,5 +22,5 @@ public record TransportActionServices(
2222
ClusterService clusterService,
2323
IndexNameExpressionResolver indexNameExpressionResolver,
2424
UsageService usageService,
25-
InferenceService inferenceService
25+
InferenceRunner inferenceRunner
2626
) {}

Diff for: x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java

+13-16
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
import org.elasticsearch.xpack.esql.enrich.EnrichPolicyResolver;
5050
import org.elasticsearch.xpack.esql.enrich.LookupFromIndexService;
5151
import org.elasticsearch.xpack.esql.execution.PlanExecutor;
52-
import org.elasticsearch.xpack.esql.inference.InferenceService;
52+
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
5353
import org.elasticsearch.xpack.esql.session.Configuration;
5454
import org.elasticsearch.xpack.esql.session.EsqlSession.PlanRunner;
5555
import org.elasticsearch.xpack.esql.session.Result;
@@ -82,7 +82,6 @@ public class TransportEsqlQueryAction extends HandledTransportAction<EsqlQueryRe
8282
private final AsyncTaskManagementService<EsqlQueryRequest, EsqlQueryResponse, EsqlQueryTask> asyncTaskManagementService;
8383
private final RemoteClusterService remoteClusterService;
8484
private final UsageService usageService;
85-
private final InferenceService inferenceService;
8685
private final TransportActionServices services;
8786
private volatile boolean defaultAllowPartialResults;
8887

@@ -128,19 +127,7 @@ public TransportEsqlQueryAction(
128127
bigArrays,
129128
blockFactoryProvider.blockFactory()
130129
);
131-
this.inferenceService = new InferenceService(client);
132-
this.computeService = new ComputeService(
133-
searchService,
134-
transportService,
135-
exchangeService,
136-
enrichLookupService,
137-
lookupFromIndexService,
138-
inferenceService,
139-
clusterService,
140-
threadPool,
141-
bigArrays,
142-
blockFactoryProvider.blockFactory()
143-
);
130+
144131
this.asyncTaskManagementService = new AsyncTaskManagementService<>(
145132
XPackPlugin.ASYNC_RESULTS_INDEX,
146133
client,
@@ -164,8 +151,18 @@ public TransportEsqlQueryAction(
164151
clusterService,
165152
indexNameExpressionResolver,
166153
usageService,
167-
inferenceService
154+
new InferenceRunner(client)
168155
);
156+
157+
this.computeService = new ComputeService(
158+
services,
159+
enrichLookupService,
160+
lookupFromIndexService,
161+
threadPool,
162+
bigArrays,
163+
blockFactoryProvider.blockFactory()
164+
);
165+
169166
defaultAllowPartialResults = EsqlPlugin.QUERY_ALLOW_PARTIAL_RESULTS.get(clusterService.getSettings());
170167
clusterService.getClusterSettings()
171168
.addSettingsUpdateConsumer(EsqlPlugin.QUERY_ALLOW_PARTIAL_RESULTS, v -> defaultAllowPartialResults = v);

0 commit comments

Comments
 (0)