Skip to content

Commit c722ceb

Browse files
authored
Fix score count validation in reranker response (#111212)
* Fix rerank score validation
* Update docs/changelog/111212.yaml
* Add test case for invalid document indices in reranker result
* Preemptive top_n config check
* Reorg code + refine tests
* Add support for Google Vertex AI task settings
* Spotless
* Make top N eval async
* Update test
* Fix broken unit test
* Clean up tests
* Spotless
* Add size check + compare against rankWindowSize
* Fix import
1 parent a4e6cf9 commit c722ceb

File tree

5 files changed

+188
-54
lines changed

5 files changed

+188
-54
lines changed

docs/changelog/111212.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 111212
2+
summary: Fix score count validation in reranker response
3+
area: Ranking
4+
type: bug
5+
issues:
6+
- 111202

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java

Lines changed: 68 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@
1414
import org.elasticsearch.inference.TaskType;
1515
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
1616
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
17+
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
1718
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1819
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
20+
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
21+
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
1922

2023
import java.util.Arrays;
2124
import java.util.Comparator;
@@ -53,24 +56,77 @@ public TextSimilarityRankFeaturePhaseRankCoordinatorContext(
5356
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
5457
// Wrap the provided scoreListener in an ActionListener that handles the response from the inference service
5558
// and then pass the results
56-
final ActionListener<InferenceAction.Response> actionListener = scoreListener.delegateFailureAndWrap((l, r) -> {
57-
float[] scores = extractScoresFromResponse(r);
58-
if (scores.length != featureDocs.length) {
59+
final ActionListener<InferenceAction.Response> inferenceListener = scoreListener.delegateFailureAndWrap((l, r) -> {
60+
InferenceServiceResults results = r.getResults();
61+
assert results instanceof RankedDocsResults;
62+
63+
// Ensure we get exactly as many scores as the number of docs we passed, otherwise we may return incorrect results
64+
List<RankedDocsResults.RankedDoc> rankedDocs = ((RankedDocsResults) results).getRankedDocs();
65+
if (rankedDocs.size() != featureDocs.length) {
5966
l.onFailure(
60-
new IllegalStateException("Document and score count mismatch: [" + featureDocs.length + "] vs [" + scores.length + "]")
67+
new IllegalStateException(
68+
"Reranker input document count and returned score count mismatch: ["
69+
+ featureDocs.length
70+
+ "] vs ["
71+
+ rankedDocs.size()
72+
+ "]"
73+
)
6174
);
6275
} else {
76+
float[] scores = extractScoresFromRankedDocs(rankedDocs);
6377
l.onResponse(scores);
6478
}
6579
});
6680

67-
List<String> featureData = Arrays.stream(featureDocs).map(x -> x.featureData).toList();
68-
InferenceAction.Request request = generateRequest(featureData);
69-
try {
70-
client.execute(InferenceAction.INSTANCE, request, actionListener);
71-
} finally {
72-
request.decRef();
73-
}
81+
// top N listener
82+
ActionListener<GetInferenceModelAction.Response> topNListener = scoreListener.delegateFailureAndWrap((l, r) -> {
83+
// The rerank inference endpoint may have an override to return top N documents only, in that case let's fail fast to avoid
84+
// assigning scores to the wrong input
85+
Integer configuredTopN = null;
86+
if (r.getEndpoints().isEmpty() == false
87+
&& r.getEndpoints().get(0).getTaskSettings() instanceof CohereRerankTaskSettings cohereTaskSettings) {
88+
configuredTopN = cohereTaskSettings.getTopNDocumentsOnly();
89+
} else if (r.getEndpoints().isEmpty() == false
90+
&& r.getEndpoints().get(0).getTaskSettings() instanceof GoogleVertexAiRerankTaskSettings googleVertexAiTaskSettings) {
91+
configuredTopN = googleVertexAiTaskSettings.topN();
92+
}
93+
if (configuredTopN != null && configuredTopN < rankWindowSize) {
94+
l.onFailure(
95+
new IllegalArgumentException(
96+
"Inference endpoint ["
97+
+ inferenceId
98+
+ "] is configured to return the top ["
99+
+ configuredTopN
100+
+ "] results, but rank_window_size is ["
101+
+ rankWindowSize
102+
+ "]. Reduce rank_window_size to be less than or equal to the configured top N value."
103+
)
104+
);
105+
return;
106+
}
107+
List<String> featureData = Arrays.stream(featureDocs).map(x -> x.featureData).toList();
108+
InferenceAction.Request inferenceRequest = generateRequest(featureData);
109+
try {
110+
client.execute(InferenceAction.INSTANCE, inferenceRequest, inferenceListener);
111+
} finally {
112+
inferenceRequest.decRef();
113+
}
114+
});
115+
116+
GetInferenceModelAction.Request getModelRequest = new GetInferenceModelAction.Request(inferenceId, TaskType.RERANK);
117+
client.execute(GetInferenceModelAction.INSTANCE, getModelRequest, topNListener);
118+
}
119+
120+
/**
121+
* Sorts documents by score descending and discards those with a score less than minScore.
122+
* @param originalDocs documents to process
123+
*/
124+
@Override
125+
protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) {
126+
return Arrays.stream(originalDocs)
127+
.filter(doc -> minScore == null || doc.score >= minScore)
128+
.sorted(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed())
129+
.toArray(RankFeatureDoc[]::new);
74130
}
75131

76132
protected InferenceAction.Request generateRequest(List<String> docFeatures) {
@@ -85,28 +141,12 @@ protected InferenceAction.Request generateRequest(List<String> docFeatures) {
85141
);
86142
}
87143

88-
private float[] extractScoresFromResponse(InferenceAction.Response response) {
89-
InferenceServiceResults results = response.getResults();
90-
assert results instanceof RankedDocsResults;
91-
92-
List<RankedDocsResults.RankedDoc> rankedDocs = ((RankedDocsResults) results).getRankedDocs();
144+
private float[] extractScoresFromRankedDocs(List<RankedDocsResults.RankedDoc> rankedDocs) {
93145
float[] scores = new float[rankedDocs.size()];
94146
for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) {
95147
scores[rankedDoc.index()] = rankedDoc.relevanceScore();
96148
}
97149

98150
return scores;
99151
}
100-
101-
/**
102-
* Sorts documents by score descending and discards those with a score less than minScore.
103-
* @param originalDocs documents to process
104-
*/
105-
@Override
106-
protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) {
107-
return Arrays.stream(originalDocs)
108-
.filter(doc -> minScore == null || doc.score >= minScore)
109-
.sorted(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed())
110-
.toArray(RankFeatureDoc[]::new);
111-
}
112152
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContextTests.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import org.elasticsearch.inference.TaskType;
1313
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
1414
import org.elasticsearch.test.ESTestCase;
15-
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
15+
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
1616

1717
import static org.mockito.ArgumentMatchers.any;
1818
import static org.mockito.ArgumentMatchers.argThat;
@@ -54,10 +54,9 @@ public void onFailure(Exception e) {
5454
fail();
5555
}
5656
});
57-
5857
verify(mockClient).execute(
59-
eq(InferenceAction.INSTANCE),
60-
argThat(actionRequest -> ((InferenceAction.Request) actionRequest).getTaskType().equals(TaskType.RERANK)),
58+
eq(GetInferenceModelAction.INSTANCE),
59+
argThat(actionRequest -> ((GetInferenceModelAction.Request) actionRequest).getTaskType().equals(TaskType.RERANK)),
6160
any()
6261
);
6362
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.inference.rank.textsimilarity;
99

10+
import org.elasticsearch.action.search.SearchPhaseExecutionException;
1011
import org.elasticsearch.client.internal.Client;
1112
import org.elasticsearch.index.query.QueryBuilders;
1213
import org.elasticsearch.inference.InputType;
@@ -29,22 +30,46 @@
2930
import java.util.Objects;
3031

3132
import static org.hamcrest.Matchers.containsString;
33+
import static org.hamcrest.Matchers.equalTo;
3234

3335
public class TextSimilarityRankTests extends ESSingleNodeTestCase {
3436

3537
/**
36-
* {@code TextSimilarityRankBuilder} that simulates an inference call that returns a different number of results as the input.
38+
* {@code TextSimilarityRankBuilder} that sets top_n in the inference endpoint's task settings.
39+
* See {@code TextSimilarityTestPlugin -> TestFilter -> handleGetInferenceModelActionRequest} for the logic that extracts the top_n
40+
* value.
3741
*/
38-
public static class InvalidInferenceResultCountProvidingTextSimilarityRankBuilder extends TextSimilarityRankBuilder {
42+
public static class TopNConfigurationAcceptingTextSimilarityRankBuilder extends TextSimilarityRankBuilder {
3943

40-
public InvalidInferenceResultCountProvidingTextSimilarityRankBuilder(
44+
public TopNConfigurationAcceptingTextSimilarityRankBuilder(
4145
String field,
4246
String inferenceId,
4347
String inferenceText,
4448
int rankWindowSize,
45-
Float minScore
49+
Float minScore,
50+
int topN
51+
) {
52+
super(field, inferenceId + "-task-settings-top-" + topN, inferenceText, rankWindowSize, minScore);
53+
}
54+
}
55+
56+
/**
57+
* {@code TextSimilarityRankBuilder} that simulates an inference call returning N results.
58+
*/
59+
public static class InferenceResultCountAcceptingTextSimilarityRankBuilder extends TextSimilarityRankBuilder {
60+
61+
private final int inferenceResultCount;
62+
63+
public InferenceResultCountAcceptingTextSimilarityRankBuilder(
64+
String field,
65+
String inferenceId,
66+
String inferenceText,
67+
int rankWindowSize,
68+
Float minScore,
69+
int inferenceResultCount
4670
) {
4771
super(field, inferenceId, inferenceText, rankWindowSize, minScore);
72+
this.inferenceResultCount = inferenceResultCount;
4873
}
4974

5075
@Override
@@ -62,10 +87,10 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo
6287
protected InferenceAction.Request generateRequest(List<String> docFeatures) {
6388
return new InferenceAction.Request(
6489
TaskType.RERANK,
65-
inferenceId,
90+
this.inferenceId,
6691
inferenceText,
6792
docFeatures,
68-
Map.of("invalidInferenceResultCount", true),
93+
Map.of("inferenceResultCount", inferenceResultCount),
6994
InputType.SEARCH,
7095
InferenceAction.Request.DEFAULT_TIMEOUT
7196
);
@@ -151,17 +176,38 @@ public void testRerankInferenceFailure() {
151176
);
152177
}
153178

154-
public void testRerankInferenceResultMismatch() {
155-
ElasticsearchAssertions.assertFailures(
179+
public void testRerankTopNConfigurationAndRankWindowSizeMismatch() {
180+
SearchPhaseExecutionException ex = expectThrows(
181+
SearchPhaseExecutionException.class,
156182
// Execute search with text similarity reranking
157183
client.prepareSearch()
158184
.setRankBuilder(
159-
new InvalidInferenceResultCountProvidingTextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f)
185+
// Simulate reranker configuration with top_n=3 in task_settings, which is different from rank_window_size=100
186+
// (Note: top_n comes from inferenceId, there's no other easy way of passing this to the mocked get model request)
187+
new TopNConfigurationAcceptingTextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f, 3)
160188
)
161-
.setQuery(QueryBuilders.matchAllQuery()),
162-
RestStatus.INTERNAL_SERVER_ERROR,
163-
containsString("Failed to execute phase [rank-feature], Computing updated ranks for results failed")
189+
.setQuery(QueryBuilders.matchAllQuery())
190+
);
191+
assertThat(ex.status(), equalTo(RestStatus.BAD_REQUEST));
192+
assertThat(
193+
ex.getDetailedMessage(),
194+
containsString("Reduce rank_window_size to be less than or equal to the configured top N value")
195+
);
196+
}
197+
198+
public void testRerankInputSizeAndInferenceResultsMismatch() {
199+
SearchPhaseExecutionException ex = expectThrows(
200+
SearchPhaseExecutionException.class,
201+
// Execute search with text similarity reranking
202+
client.prepareSearch()
203+
.setRankBuilder(
204+
// Simulate reranker returning different number of results from input
205+
new InferenceResultCountAcceptingTextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f, 4)
206+
)
207+
.setQuery(QueryBuilders.matchAllQuery())
164208
);
209+
assertThat(ex.status(), equalTo(RestStatus.INTERNAL_SERVER_ERROR));
210+
assertThat(ex.getDetailedMessage(), containsString("Reranker input document count and returned score count mismatch"));
165211
}
166212

167213
private static void assertHitHasRankScoreAndText(SearchHit hit, int expectedRank, float expectedScore, String expectedText) {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
2222
import org.elasticsearch.common.io.stream.StreamInput;
2323
import org.elasticsearch.common.io.stream.StreamOutput;
24+
import org.elasticsearch.inference.EmptyTaskSettings;
2425
import org.elasticsearch.inference.InputType;
26+
import org.elasticsearch.inference.ModelConfigurations;
2527
import org.elasticsearch.inference.TaskType;
2628
import org.elasticsearch.plugins.ActionPlugin;
2729
import org.elasticsearch.plugins.Plugin;
@@ -39,15 +41,21 @@
3941
import org.elasticsearch.xcontent.ParseField;
4042
import org.elasticsearch.xcontent.ToXContent;
4143
import org.elasticsearch.xcontent.XContentBuilder;
44+
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
4245
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
4346
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
47+
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
48+
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
49+
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
4450

4551
import java.io.IOException;
4652
import java.util.ArrayList;
4753
import java.util.Collection;
4854
import java.util.Collections;
4955
import java.util.List;
5056
import java.util.Map;
57+
import java.util.regex.Matcher;
58+
import java.util.regex.Pattern;
5159

5260
import static java.util.Collections.singletonList;
5361
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
@@ -100,31 +108,66 @@ public int order() {
100108
}
101109

102110
@Override
103-
@SuppressWarnings("unchecked")
104111
public <Request extends ActionRequest, Response extends ActionResponse> void apply(
105112
Task task,
106113
String action,
107114
Request request,
108115
ActionListener<Response> listener,
109116
ActionFilterChain<Request, Response> chain
110117
) {
111-
// For any other action than inference, execute normally
112-
if (action.equals(InferenceAction.INSTANCE.name()) == false) {
118+
if (action.equals(GetInferenceModelAction.INSTANCE.name())) {
119+
assert request instanceof GetInferenceModelAction.Request;
120+
handleGetInferenceModelActionRequest((GetInferenceModelAction.Request) request, listener);
121+
} else if (action.equals(InferenceAction.INSTANCE.name())) {
122+
assert request instanceof InferenceAction.Request;
123+
handleInferenceActionRequest((InferenceAction.Request) request, listener);
124+
} else {
125+
// For any other action than get model and inference, execute normally
113126
chain.proceed(task, action, request, listener);
114-
return;
115127
}
128+
}
116129

117-
assert request instanceof InferenceAction.Request;
118-
boolean shouldThrow = (boolean) ((InferenceAction.Request) request).getTaskSettings().getOrDefault("throwing", false);
119-
boolean hasInvalidInferenceResultCount = (boolean) ((InferenceAction.Request) request).getTaskSettings()
120-
.getOrDefault("invalidInferenceResultCount", false);
130+
@SuppressWarnings("unchecked")
131+
private <Response extends ActionResponse> void handleGetInferenceModelActionRequest(
132+
GetInferenceModelAction.Request request,
133+
ActionListener<Response> listener
134+
) {
135+
String inferenceEntityId = request.getInferenceEntityId();
136+
Integer topN = null;
137+
Matcher extractTopN = Pattern.compile(".*(task-settings-top-\\d+).*").matcher(inferenceEntityId);
138+
if (extractTopN.find()) {
139+
topN = Integer.parseInt(extractTopN.group(1).replaceAll("\\D", ""));
140+
}
141+
142+
ActionResponse response = new GetInferenceModelAction.Response(
143+
List.of(
144+
new ModelConfigurations(
145+
request.getInferenceEntityId(),
146+
request.getTaskType(),
147+
CohereService.NAME,
148+
new CohereRerankServiceSettings("uri", "model", null),
149+
topN == null ? new EmptyTaskSettings() : new CohereRerankTaskSettings(topN, null, null)
150+
)
151+
)
152+
);
153+
listener.onResponse((Response) response);
154+
}
155+
156+
@SuppressWarnings("unchecked")
157+
private <Response extends ActionResponse> void handleInferenceActionRequest(
158+
InferenceAction.Request request,
159+
ActionListener<Response> listener
160+
) {
161+
Map<String, Object> taskSettings = request.getTaskSettings();
162+
boolean shouldThrow = (boolean) taskSettings.getOrDefault("throwing", false);
163+
Integer inferenceResultCount = (Integer) taskSettings.get("inferenceResultCount");
121164

122165
if (shouldThrow) {
123166
listener.onFailure(new UnsupportedOperationException("simulated failure"));
124167
} else {
125168
List<RankedDocsResults.RankedDoc> rankedDocsResults = new ArrayList<>();
126-
List<String> inputs = ((InferenceAction.Request) request).getInput();
127-
int resultCount = hasInvalidInferenceResultCount ? inputs.size() - 1 : inputs.size();
169+
List<String> inputs = request.getInput();
170+
int resultCount = inferenceResultCount == null ? inputs.size() : inferenceResultCount;
128171
for (int i = 0; i < resultCount; i++) {
129172
rankedDocsResults.add(new RankedDocsResults.RankedDoc(i, Float.parseFloat(inputs.get(i)), inputs.get(i)));
130173
}

0 commit comments

Comments
 (0)