Skip to content

Commit 785dfde

Browse files
timgrein authored and matthewabbott committed
[Inference API] Propagate infer trace context to EIS (elastic#113407)
1 parent d26165c commit 785dfde

File tree

7 files changed

+99
-12
lines changed

7 files changed

+99
-12
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreator.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
1414
import org.elasticsearch.xpack.inference.services.ServiceComponents;
1515
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;
16+
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
1617

1718
import java.util.Objects;
1819

@@ -24,14 +25,17 @@ public class ElasticInferenceServiceActionCreator implements ElasticInferenceSer
2425

2526
private final ServiceComponents serviceComponents;
2627

27-
public ElasticInferenceServiceActionCreator(Sender sender, ServiceComponents serviceComponents) {
28+
private final TraceContext traceContext;
29+
30+
public ElasticInferenceServiceActionCreator(Sender sender, ServiceComponents serviceComponents, TraceContext traceContext) {
2831
this.sender = Objects.requireNonNull(sender);
2932
this.serviceComponents = Objects.requireNonNull(serviceComponents);
33+
this.traceContext = traceContext;
3034
}
3135

3236
@Override
3337
public ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model) {
34-
var requestManager = new ElasticInferenceServiceSparseEmbeddingsRequestManager(model, serviceComponents);
38+
var requestManager = new ElasticInferenceServiceSparseEmbeddingsRequestManager(model, serviceComponents, traceContext);
3539
var errorMessage = constructFailedToSendRequestMessage(model.uri(), "Elastic Inference Service sparse embeddings");
3640
return new SenderExecutableAction(sender, requestManager, errorMessage);
3741
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceSparseEmbeddingsResponseEntity;
2020
import org.elasticsearch.xpack.inference.services.ServiceComponents;
2121
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;
22+
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
2223

2324
import java.util.List;
2425
import java.util.function.Supplier;
@@ -35,6 +36,8 @@ public class ElasticInferenceServiceSparseEmbeddingsRequestManager extends Elast
3536

3637
private final Truncator truncator;
3738

39+
private final TraceContext traceContext;
40+
3841
private static ResponseHandler createSparseEmbeddingsHandler() {
3942
return new ElasticInferenceServiceResponseHandler(
4043
"Elastic Inference Service sparse embeddings",
@@ -44,11 +47,13 @@ private static ResponseHandler createSparseEmbeddingsHandler() {
4447

4548
public ElasticInferenceServiceSparseEmbeddingsRequestManager(
4649
ElasticInferenceServiceSparseEmbeddingsModel model,
47-
ServiceComponents serviceComponents
50+
ServiceComponents serviceComponents,
51+
TraceContext traceContext
4852
) {
4953
super(serviceComponents.threadPool(), model);
5054
this.model = model;
5155
this.truncator = serviceComponents.truncator();
56+
this.traceContext = traceContext;
5257
}
5358

5459
@Override
@@ -64,7 +69,8 @@ public void execute(
6469
ElasticInferenceServiceSparseEmbeddingsRequest request = new ElasticInferenceServiceSparseEmbeddingsRequest(
6570
truncator,
6671
truncatedInput,
67-
model
72+
model,
73+
traceContext
6874
);
6975
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
7076
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212
import org.apache.http.entity.ByteArrayEntity;
1313
import org.apache.http.message.BasicHeader;
1414
import org.elasticsearch.common.Strings;
15+
import org.elasticsearch.tasks.Task;
1516
import org.elasticsearch.xcontent.XContentType;
1617
import org.elasticsearch.xpack.inference.common.Truncator;
1718
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
1819
import org.elasticsearch.xpack.inference.external.request.Request;
1920
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;
21+
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
2022

2123
import java.net.URI;
2224
import java.nio.charset.StandardCharsets;
@@ -31,15 +33,19 @@ public class ElasticInferenceServiceSparseEmbeddingsRequest implements ElasticIn
3133
private final Truncator.TruncationResult truncationResult;
3234
private final Truncator truncator;
3335

36+
private final TraceContext traceContext;
37+
3438
public ElasticInferenceServiceSparseEmbeddingsRequest(
3539
Truncator truncator,
3640
Truncator.TruncationResult truncationResult,
37-
ElasticInferenceServiceSparseEmbeddingsModel model
41+
ElasticInferenceServiceSparseEmbeddingsModel model,
42+
TraceContext traceContext
3843
) {
3944
this.truncator = truncator;
4045
this.truncationResult = truncationResult;
4146
this.model = Objects.requireNonNull(model);
4247
this.uri = model.uri();
48+
this.traceContext = traceContext;
4349
}
4450

4551
@Override
@@ -50,6 +56,10 @@ public HttpRequest createHttpRequest() {
5056
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
5157
httpPost.setEntity(byteEntity);
5258

59+
if (traceContext != null) {
60+
propagateTraceContext(httpPost);
61+
}
62+
5363
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
5464

5565
return new HttpRequest(httpPost, getInferenceEntityId());
@@ -65,16 +75,32 @@ public URI getURI() {
6575
return this.uri;
6676
}
6777

78+
public TraceContext getTraceContext() {
79+
return traceContext;
80+
}
81+
6882
@Override
6983
public Request truncate() {
7084
var truncatedInput = truncator.truncate(truncationResult.input());
7185

72-
return new ElasticInferenceServiceSparseEmbeddingsRequest(truncator, truncatedInput, model);
86+
return new ElasticInferenceServiceSparseEmbeddingsRequest(truncator, truncatedInput, model, traceContext);
7387
}
7488

7589
@Override
7690
public boolean[] getTruncationInfo() {
7791
return truncationResult.truncated().clone();
7892
}
7993

94+
private void propagateTraceContext(HttpPost httpPost) {
95+
var traceParent = traceContext.traceParent();
96+
var traceState = traceContext.traceState();
97+
98+
if (traceParent != null) {
99+
httpPost.setHeader(Task.TRACE_PARENT_HTTP_HEADER, traceParent);
100+
}
101+
102+
if (traceState != null) {
103+
httpPost.setHeader(Task.TRACE_STATE, traceState);
104+
}
105+
}
80106
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.inference.ModelSecrets;
2424
import org.elasticsearch.inference.TaskType;
2525
import org.elasticsearch.rest.RestStatus;
26+
import org.elasticsearch.tasks.Task;
2627
import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
2728
import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
2829
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
@@ -34,6 +35,7 @@
3435
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
3536
import org.elasticsearch.xpack.inference.services.SenderService;
3637
import org.elasticsearch.xpack.inference.services.ServiceComponents;
38+
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
3739

3840
import java.util.List;
3941
import java.util.Map;
@@ -75,8 +77,13 @@ protected void doInfer(
7577
return;
7678
}
7779

80+
// We extract the trace context here as it's sufficient to propagate the trace information of the REST request,
81+
// which handles the request to the inference API overall (including the outgoing request, which is started in a new thread
82+
// generating a different "traceparent" as every task and every REST request creates a new span).
83+
var currentTraceInfo = getCurrentTraceInfo();
84+
7885
ElasticInferenceServiceModel elasticInferenceServiceModel = (ElasticInferenceServiceModel) model;
79-
var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents());
86+
var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), currentTraceInfo);
8087

8188
var action = elasticInferenceServiceModel.accept(actionCreator, taskSettings);
8289
action.execute(inputs, timeout, listener);
@@ -258,4 +265,13 @@ private ElasticInferenceServiceSparseEmbeddingsModel updateModelWithEmbeddingDet
258265

259266
return new ElasticInferenceServiceSparseEmbeddingsModel(model, serviceSettings);
260267
}
268+
269+
private TraceContext getCurrentTraceInfo() {
270+
var threadPool = getServiceComponents().threadPool();
271+
272+
var traceParent = threadPool.getThreadContext().getHeader(Task.TRACE_PARENT);
273+
var traceState = threadPool.getThreadContext().getHeader(Task.TRACE_STATE);
274+
275+
return new TraceContext(traceParent, traceState);
276+
}
261277
}
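x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/TraceContext.java

Lines changed: 10 additions & 0 deletions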
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.telemetry;
9+
10+
public record TraceContext(String traceParent, String traceState) {}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreatorTests.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
2626
import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests;
2727
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModelTests;
28+
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
2829
import org.junit.After;
2930
import org.junit.Before;
3031

@@ -89,7 +90,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce
8990
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
9091

9192
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer));
92-
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool));
93+
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
9394
var action = actionCreator.create(model);
9495

9596
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
@@ -145,7 +146,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx
145146
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
146147

147148
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer));
148-
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool));
149+
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
149150
var action = actionCreator.create(model);
150151

151152
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
@@ -197,7 +198,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc
197198
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
198199

199200
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer));
200-
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool));
201+
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
201202
var action = actionCreator.create(model);
202203

203204
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
@@ -257,7 +258,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException {
257258

258259
// truncated to 1 token = 3 characters
259260
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), 1);
260-
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool));
261+
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
261262
var action = actionCreator.create(model);
262263

263264
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
@@ -286,4 +287,8 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException {
286287
}
287288
}
288289

290+
private TraceContext createTraceContext() {
291+
return new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10));
292+
}
293+
289294
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestTests.java

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99

1010
import org.apache.http.HttpHeaders;
1111
import org.apache.http.client.methods.HttpPost;
12+
import org.elasticsearch.tasks.Task;
1213
import org.elasticsearch.test.ESTestCase;
1314
import org.elasticsearch.xcontent.XContentType;
1415
import org.elasticsearch.xpack.inference.common.Truncator;
1516
import org.elasticsearch.xpack.inference.common.TruncatorTests;
1617
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModelTests;
18+
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
1719

1820
import java.io.IOException;
1921
import java.util.List;
@@ -42,6 +44,23 @@ public void testCreateHttpRequest() throws IOException {
4244
assertThat(requestMap.get("input"), is(List.of(input)));
4345
}
4446

47+
public void testTraceContextPropagatedThroughHTTPHeaders() {
48+
var url = "http://eis-gateway.com";
49+
var input = "input";
50+
51+
var request = createRequest(url, input);
52+
var httpRequest = request.createHttpRequest();
53+
54+
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
55+
var httpPost = (HttpPost) httpRequest.httpRequestBase();
56+
57+
var traceParent = request.getTraceContext().traceParent();
58+
var traceState = request.getTraceContext().traceState();
59+
60+
assertThat(httpPost.getLastHeader(Task.TRACE_PARENT_HTTP_HEADER).getValue(), is(traceParent));
61+
assertThat(httpPost.getLastHeader(Task.TRACE_STATE).getValue(), is(traceState));
62+
}
63+
4564
public void testTruncate_ReducesInputTextSizeByHalf() throws IOException {
4665
var url = "http://eis-gateway.com";
4766
var input = "abcd";
@@ -75,7 +94,8 @@ public ElasticInferenceServiceSparseEmbeddingsRequest createRequest(String url,
7594
return new ElasticInferenceServiceSparseEmbeddingsRequest(
7695
TruncatorTests.createTruncator(),
7796
new Truncator.TruncationResult(List.of(input), new boolean[] { false }),
78-
embeddingsModel
97+
embeddingsModel,
98+
new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10))
7999
);
80100
}
81101
}

0 commit comments

Comments
 (0)