Skip to content

Commit d85a3f3

Browse files
[ML] Refactor OpenAI to use ConstructingObjectParser and consolidate class locations (#124380) (#124561)
* Switching openai to ConstructingObjectParser * Moving files * Fixing package errors
1 parent 93a950e commit d85a3f3

File tree

33 files changed

+213
-214
lines changed

33 files changed

+213
-214
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
1717
import org.elasticsearch.xpack.inference.external.http.sender.TruncatingRequestManager;
1818
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
19+
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseEntity;
1920
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler;
21+
import org.elasticsearch.xpack.inference.external.openai.OpenAiEmbeddingsRequest;
22+
import org.elasticsearch.xpack.inference.external.openai.OpenAiEmbeddingsResponseEntity;
2023
import org.elasticsearch.xpack.inference.external.openai.OpenAiResponseHandler;
21-
import org.elasticsearch.xpack.inference.external.request.openai.OpenAiEmbeddingsRequest;
22-
import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest;
23-
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity;
24-
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiEmbeddingsResponseEntity;
24+
import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedChatCompletionRequest;
2525
import org.elasticsearch.xpack.inference.services.ServiceComponents;
2626
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel;
2727
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
import org.elasticsearch.xpack.inference.external.azureopenai.AzureOpenAiResponseHandler;
1717
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
1818
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
19+
import org.elasticsearch.xpack.inference.external.openai.OpenAiEmbeddingsResponseEntity;
1920
import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiEmbeddingsRequest;
20-
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiEmbeddingsResponseEntity;
2121
import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
2222

2323
import java.util.List;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceUnifiedCompletionRequestManager.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceUnifiedChatCompletionResponseHandler;
1717
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
1818
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
19+
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseEntity;
1920
import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceUnifiedChatCompletionRequest;
20-
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity;
2121
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
2222
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
2323

Lines changed: 42 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
* 2.0.
66
*/
77

8-
package org.elasticsearch.xpack.inference.external.response.openai;
8+
package org.elasticsearch.xpack.inference.external.openai;
99

10-
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
10+
import org.elasticsearch.xcontent.ConstructingObjectParser;
11+
import org.elasticsearch.xcontent.ParseField;
1112
import org.elasticsearch.xcontent.XContentFactory;
12-
import org.elasticsearch.xcontent.XContentParser;
1313
import org.elasticsearch.xcontent.XContentParserConfiguration;
1414
import org.elasticsearch.xcontent.XContentType;
1515
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
@@ -19,14 +19,10 @@
1919
import java.io.IOException;
2020
import java.util.List;
2121

22-
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
23-
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
24-
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;
22+
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
2523

2624
public class OpenAiChatCompletionResponseEntity {
2725

28-
private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in OpenAI chat completions response";
29-
3026
/**
3127
* Parses the OpenAI chat completion response.
3228
* For a request like:
@@ -71,32 +67,51 @@ public class OpenAiChatCompletionResponseEntity {
7167
*/
7268

7369
public static ChatCompletionResults fromResponse(Request request, HttpResult response) throws IOException {
74-
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
75-
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
76-
moveToFirstToken(jsonParser);
77-
78-
XContentParser.Token token = jsonParser.currentToken();
79-
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);
80-
81-
positionParserAtTokenAfterField(jsonParser, "choices", FAILED_TO_FIND_FIELD_TEMPLATE);
70+
try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) {
71+
return CompletionResult.PARSER.apply(p, null).toChatCompletionResults();
72+
}
73+
}
8274

83-
jsonParser.nextToken();
84-
ensureExpectedToken(XContentParser.Token.START_OBJECT, jsonParser.currentToken(), jsonParser);
75+
public record CompletionResult(List<Choice> choices) {
76+
@SuppressWarnings("unchecked")
77+
public static final ConstructingObjectParser<CompletionResult, Void> PARSER = new ConstructingObjectParser<>(
78+
CompletionResult.class.getSimpleName(),
79+
true,
80+
args -> new CompletionResult((List<Choice>) args[0])
81+
);
8582

86-
positionParserAtTokenAfterField(jsonParser, "message", FAILED_TO_FIND_FIELD_TEMPLATE);
83+
static {
84+
PARSER.declareObjectArray(constructorArg(), Choice.PARSER::apply, new ParseField("choices"));
85+
}
8786

88-
token = jsonParser.currentToken();
87+
public ChatCompletionResults toChatCompletionResults() {
88+
return new ChatCompletionResults(
89+
choices.stream().map(choice -> new ChatCompletionResults.Result(choice.message.content)).toList()
90+
);
91+
}
92+
}
8993

90-
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);
94+
public record Choice(Message message) {
95+
public static final ConstructingObjectParser<Choice, Void> PARSER = new ConstructingObjectParser<>(
96+
Choice.class.getSimpleName(),
97+
true,
98+
args -> new Choice((Message) args[0])
99+
);
91100

92-
positionParserAtTokenAfterField(jsonParser, "content", FAILED_TO_FIND_FIELD_TEMPLATE);
101+
static {
102+
PARSER.declareObject(constructorArg(), Message.PARSER::apply, new ParseField("message"));
103+
}
104+
}
93105

94-
XContentParser.Token contentToken = jsonParser.currentToken();
95-
ensureExpectedToken(XContentParser.Token.VALUE_STRING, contentToken, jsonParser);
96-
String content = jsonParser.text();
106+
public record Message(String content) {
107+
public static final ConstructingObjectParser<Message, Void> PARSER = new ConstructingObjectParser<>(
108+
Message.class.getSimpleName(),
109+
true,
110+
args -> new Message((String) args[0])
111+
);
97112

98-
return new ChatCompletionResults(List.of(new ChatCompletionResults.Result(content)));
113+
static {
114+
PARSER.declareString(constructorArg(), new ParseField("content"));
99115
}
100116
}
101-
102117
}
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* 2.0.
66
*/
77

8-
package org.elasticsearch.xpack.inference.external.request.openai;
8+
package org.elasticsearch.xpack.inference.external.openai;
99

1010
import org.apache.http.HttpHeaders;
1111
import org.apache.http.client.methods.HttpPost;
@@ -21,8 +21,8 @@
2121
import java.nio.charset.StandardCharsets;
2222
import java.util.Objects;
2323

24+
import static org.elasticsearch.xpack.inference.external.openai.OpenAiUtils.createOrgHeader;
2425
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
25-
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.createOrgHeader;
2626

2727
public class OpenAiEmbeddingsRequest implements Request {
2828

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* 2.0.
66
*/
77

8-
package org.elasticsearch.xpack.inference.external.request.openai;
8+
package org.elasticsearch.xpack.inference.external.openai;
99

1010
import org.elasticsearch.core.Nullable;
1111
import org.elasticsearch.xcontent.ToXContentObject;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.openai;
9+
10+
import org.elasticsearch.xcontent.ConstructingObjectParser;
11+
import org.elasticsearch.xcontent.ParseField;
12+
import org.elasticsearch.xcontent.XContentFactory;
13+
import org.elasticsearch.xcontent.XContentParserConfiguration;
14+
import org.elasticsearch.xcontent.XContentType;
15+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
16+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
17+
import org.elasticsearch.xpack.inference.external.request.Request;
18+
19+
import java.io.IOException;
20+
import java.util.List;
21+
22+
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
23+
24+
public class OpenAiEmbeddingsResponseEntity {
25+
/**
26+
* Parses the OpenAI json response.
27+
* For a request like:
28+
*
29+
* <pre>
30+
* <code>
31+
* {
32+
* "inputs": ["hello this is my name", "I wish I was there!"]
33+
* }
34+
* </code>
35+
* </pre>
36+
*
37+
* The response would look like:
38+
*
39+
* <pre>
40+
* <code>
41+
* {
42+
* "object": "list",
43+
* "data": [
44+
* {
45+
* "object": "embedding",
46+
* "embedding": [
47+
* -0.009327292,
48+
* .... (1536 floats total for ada-002)
49+
* -0.0028842222,
50+
* ],
51+
* "index": 0
52+
* },
53+
* {
54+
* "object": "embedding",
55+
* "embedding": [ ... ],
56+
* "index": 1
57+
* }
58+
* ],
59+
* "model": "text-embedding-ada-002",
60+
* "usage": {
61+
* "prompt_tokens": 8,
62+
* "total_tokens": 8
63+
* }
64+
* }
65+
* </code>
66+
* </pre>
67+
*/
68+
public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
69+
try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) {
70+
return EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults();
71+
}
72+
}
73+
74+
public record EmbeddingFloatResult(List<EmbeddingFloatResultEntry> embeddingResults) {
75+
@SuppressWarnings("unchecked")
76+
public static final ConstructingObjectParser<EmbeddingFloatResult, Void> PARSER = new ConstructingObjectParser<>(
77+
EmbeddingFloatResult.class.getSimpleName(),
78+
true,
79+
args -> new EmbeddingFloatResult((List<EmbeddingFloatResultEntry>) args[0])
80+
);
81+
82+
static {
83+
PARSER.declareObjectArray(constructorArg(), EmbeddingFloatResultEntry.PARSER::apply, new ParseField("data"));
84+
}
85+
86+
public TextEmbeddingFloatResults toTextEmbeddingFloatResults() {
87+
return new TextEmbeddingFloatResults(
88+
embeddingResults.stream().map(entry -> TextEmbeddingFloatResults.Embedding.of(entry.embedding)).toList()
89+
);
90+
}
91+
}
92+
93+
public record EmbeddingFloatResultEntry(List<Float> embedding) {
94+
@SuppressWarnings("unchecked")
95+
public static final ConstructingObjectParser<EmbeddingFloatResultEntry, Void> PARSER = new ConstructingObjectParser<>(
96+
EmbeddingFloatResultEntry.class.getSimpleName(),
97+
true,
98+
args -> new EmbeddingFloatResultEntry((List<Float>) args[0])
99+
);
100+
101+
static {
102+
PARSER.declareFloatArray(constructorArg(), new ParseField("embedding"));
103+
}
104+
}
105+
106+
private OpenAiEmbeddingsResponseEntity() {}
107+
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* 2.0.
66
*/
77

8-
package org.elasticsearch.xpack.inference.external.request.openai;
8+
package org.elasticsearch.xpack.inference.external.openai;
99

1010
import org.elasticsearch.xpack.inference.external.request.Request;
1111

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* 2.0.
66
*/
77

8-
package org.elasticsearch.xpack.inference.external.request.openai;
8+
package org.elasticsearch.xpack.inference.external.openai;
99

1010
import org.apache.http.HttpHeaders;
1111
import org.apache.http.client.methods.HttpPost;
@@ -21,8 +21,8 @@
2121
import java.nio.charset.StandardCharsets;
2222
import java.util.Objects;
2323

24+
import static org.elasticsearch.xpack.inference.external.openai.OpenAiUtils.createOrgHeader;
2425
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
25-
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.createOrgHeader;
2626

2727
public class OpenAiUnifiedChatCompletionRequest implements Request {
2828

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* 2.0.
66
*/
77

8-
package org.elasticsearch.xpack.inference.external.request.openai;
8+
package org.elasticsearch.xpack.inference.external.openai;
99

1010
import org.elasticsearch.common.Strings;
1111
import org.elasticsearch.xcontent.ToXContentObject;
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* 2.0.
66
*/
77

8-
package org.elasticsearch.xpack.inference.external.request.openai;
8+
package org.elasticsearch.xpack.inference.external.openai;
99

1010
import org.apache.http.Header;
1111
import org.apache.http.message.BasicHeader;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/azureaistudio/AzureAiStudioChatCompletionResponseEntity.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
import org.elasticsearch.xcontent.XContentType;
1616
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
1717
import org.elasticsearch.xpack.inference.external.http.HttpResult;
18+
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseEntity;
1819
import org.elasticsearch.xpack.inference.external.request.Request;
1920
import org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioChatCompletionRequest;
2021
import org.elasticsearch.xpack.inference.external.response.BaseResponseEntity;
21-
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity;
2222

2323
import java.io.IOException;
2424
import java.util.List;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/azureaistudio/AzureAiStudioEmbeddingsResponseEntity.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99

1010
import org.elasticsearch.inference.InferenceServiceResults;
1111
import org.elasticsearch.xpack.inference.external.http.HttpResult;
12+
import org.elasticsearch.xpack.inference.external.openai.OpenAiEmbeddingsResponseEntity;
1213
import org.elasticsearch.xpack.inference.external.request.Request;
1314
import org.elasticsearch.xpack.inference.external.response.BaseResponseEntity;
14-
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiEmbeddingsResponseEntity;
1515

1616
import java.io.IOException;
1717

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/mistral/MistralEmbeddingsResponseEntity.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99

1010
import org.elasticsearch.inference.InferenceServiceResults;
1111
import org.elasticsearch.xpack.inference.external.http.HttpResult;
12+
import org.elasticsearch.xpack.inference.external.openai.OpenAiEmbeddingsResponseEntity;
1213
import org.elasticsearch.xpack.inference.external.request.Request;
1314
import org.elasticsearch.xpack.inference.external.response.BaseResponseEntity;
14-
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiEmbeddingsResponseEntity;
1515

1616
import java.io.IOException;
1717

0 commit comments

Comments
 (0)