Skip to content

Commit 375fc77

Browse files
authored
[ML] update truncation default & adding field output when input is truncated (elastic#79942)
This commit makes the following two changes (along with some refactoring):
- NLP results now indicate whether the input was truncated.
- The default truncation is now `none` instead of `first`.
1 parent 8a8d868 commit 375fc77

25 files changed

+429
-113
lines changed

docs/reference/ml/ml-shared.asciidoc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -927,7 +927,7 @@ end::inference-config-nlp-tokenization-bert-do-lower-case[]
927927

928928
tag::inference-config-nlp-tokenization-bert-truncate[]
929929
Indicates how tokens are truncated when they exceed `max_sequence_length`.
930-
The default value is `first`.
930+
The default value is `none`.
931931
+
932932
--
933933
* `none`: No truncation occurs; the inference request receives an error.

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
2424
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
2525
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
26+
import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults;
2627
import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
2728
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
2829
import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults;
@@ -498,7 +499,13 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
498499
new NamedWriteableRegistry.Entry(InferenceResults.class, PyTorchPassThroughResults.NAME, PyTorchPassThroughResults::new)
499500
);
500501
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, TextEmbeddingResults.NAME, TextEmbeddingResults::new));
501-
502+
namedWriteables.add(
503+
new NamedWriteableRegistry.Entry(
504+
InferenceResults.class,
505+
NlpClassificationInferenceResults.NAME,
506+
NlpClassificationInferenceResults::new
507+
)
508+
);
502509
// Inference Configs
503510
namedWriteables.add(
504511
new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME.getPreferredName(), ClassificationConfig::new)

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,41 +10,27 @@
1010
import org.elasticsearch.common.io.stream.StreamInput;
1111
import org.elasticsearch.common.io.stream.StreamOutput;
1212
import org.elasticsearch.xcontent.XContentBuilder;
13-
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
1413

1514
import java.io.IOException;
16-
import java.util.LinkedHashMap;
1715
import java.util.List;
1816
import java.util.Map;
1917
import java.util.Objects;
2018

21-
public class FillMaskResults extends ClassificationInferenceResults {
19+
public class FillMaskResults extends NlpClassificationInferenceResults {
2220

2321
public static final String NAME = "fill_mask_result";
2422

2523
private final String predictedSequence;
2624

2725
public FillMaskResults(
28-
double value,
2926
String classificationLabel,
3027
String predictedSequence,
3128
List<TopClassEntry> topClasses,
32-
String topNumClassesField,
3329
String resultsField,
34-
Double predictionProbability
30+
Double predictionProbability,
31+
boolean isTruncated
3532
) {
36-
super(
37-
value,
38-
classificationLabel,
39-
topClasses,
40-
List.of(),
41-
topNumClassesField,
42-
resultsField,
43-
PredictionFieldType.STRING,
44-
0,
45-
predictionProbability,
46-
null
47-
);
33+
super(classificationLabel, topClasses, resultsField, predictionProbability, isTruncated);
4834
this.predictedSequence = predictedSequence;
4935
}
5036

@@ -54,8 +40,8 @@ public FillMaskResults(StreamInput in) throws IOException {
5440
}
5541

5642
@Override
57-
public void writeTo(StreamOutput out) throws IOException {
58-
super.writeTo(out);
43+
public void doWriteTo(StreamOutput out) throws IOException {
44+
super.doWriteTo(out);
5945
out.writeString(predictedSequence);
6046
}
6147

@@ -64,11 +50,9 @@ public String getPredictedSequence() {
6450
}
6551

6652
@Override
67-
public Map<String, Object> asMap() {
68-
Map<String, Object> map = new LinkedHashMap<>();
53+
void addMapFields(Map<String, Object> map) {
54+
super.addMapFields(map);
6955
map.put(resultsField + "_sequence", predictedSequence);
70-
map.putAll(super.asMap());
71-
return map;
7256
}
7357

7458
@Override
@@ -77,8 +61,9 @@ public String getWriteableName() {
7761
}
7862

7963
@Override
80-
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
81-
return super.toXContent(builder, params).field(resultsField + "_sequence", predictedSequence);
64+
public void doXContentBody(XContentBuilder builder, Params params) throws IOException {
65+
super.doXContentBody(builder, params);
66+
builder.field(resultsField + "_sequence", predictedSequence);
8267
}
8368

8469
@Override

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import java.util.Objects;
2121
import java.util.stream.Collectors;
2222

23-
public class NerResults implements InferenceResults {
23+
public class NerResults extends NlpInferenceResults {
2424

2525
public static final String NAME = "ner_result";
2626
public static final String ENTITY_FIELD = "entities";
@@ -30,27 +30,28 @@ public class NerResults implements InferenceResults {
3030

3131
private final List<EntityGroup> entityGroups;
3232

33-
public NerResults(String resultsField, String annotatedResult, List<EntityGroup> entityGroups) {
33+
public NerResults(String resultsField, String annotatedResult, List<EntityGroup> entityGroups, boolean isTruncated) {
34+
super(isTruncated);
3435
this.entityGroups = Objects.requireNonNull(entityGroups);
3536
this.resultsField = Objects.requireNonNull(resultsField);
3637
this.annotatedResult = Objects.requireNonNull(annotatedResult);
3738
}
3839

3940
public NerResults(StreamInput in) throws IOException {
41+
super(in);
4042
entityGroups = in.readList(EntityGroup::new);
4143
resultsField = in.readString();
4244
annotatedResult = in.readString();
4345
}
4446

4547
@Override
46-
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
48+
void doXContentBody(XContentBuilder builder, Params params) throws IOException {
4749
builder.field(resultsField, annotatedResult);
4850
builder.startArray("entities");
4951
for (EntityGroup entity : entityGroups) {
5052
entity.toXContent(builder, params);
5153
}
5254
builder.endArray();
53-
return builder;
5455
}
5556

5657
@Override
@@ -59,18 +60,16 @@ public String getWriteableName() {
5960
}
6061

6162
@Override
62-
public void writeTo(StreamOutput out) throws IOException {
63+
void doWriteTo(StreamOutput out) throws IOException {
6364
out.writeList(entityGroups);
6465
out.writeString(resultsField);
6566
out.writeString(annotatedResult);
6667
}
6768

6869
@Override
69-
public Map<String, Object> asMap() {
70-
Map<String, Object> map = new LinkedHashMap<>();
70+
void addMapFields(Map<String, Object> map) {
7171
map.put(resultsField, annotatedResult);
7272
map.put(ENTITY_FIELD, entityGroups.stream().map(EntityGroup::toMap).collect(Collectors.toList()));
73-
return map;
7473
}
7574

7675
@Override
@@ -95,15 +94,16 @@ public String getAnnotatedResult() {
9594
public boolean equals(Object o) {
9695
if (this == o) return true;
9796
if (o == null || getClass() != o.getClass()) return false;
97+
if (super.equals(o) == false) return false;
9898
NerResults that = (NerResults) o;
99-
return Objects.equals(entityGroups, that.entityGroups)
100-
&& Objects.equals(resultsField, that.resultsField)
101-
&& Objects.equals(annotatedResult, that.annotatedResult);
99+
return Objects.equals(resultsField, that.resultsField)
100+
&& Objects.equals(annotatedResult, that.annotatedResult)
101+
&& Objects.equals(entityGroups, that.entityGroups);
102102
}
103103

104104
@Override
105105
public int hashCode() {
106-
return Objects.hash(entityGroups, resultsField, annotatedResult);
106+
return Objects.hash(super.hashCode(), resultsField, annotatedResult, entityGroups);
107107
}
108108

109109
public static class EntityGroup implements ToXContentObject, Writeable {
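
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java

Lines changed: 129 additions & 0 deletions
Original file line number | Diff line number | Diff line change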
@@ -0,0 +1,129 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
package org.elasticsearch.xpack.core.ml.inference.results;
8+
9+
import org.elasticsearch.common.io.stream.StreamInput;
10+
import org.elasticsearch.common.io.stream.StreamOutput;
11+
import org.elasticsearch.xcontent.XContentBuilder;
12+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
13+
14+
import java.io.IOException;
15+
import java.util.Collections;
16+
import java.util.List;
17+
import java.util.Map;
18+
import java.util.Objects;
19+
import java.util.stream.Collectors;
20+
21+
public class NlpClassificationInferenceResults extends NlpInferenceResults {
22+
23+
public static final String NAME = "nlp_classification";
24+
25+
// Accessed in sub-classes
26+
protected final String resultsField;
27+
private final String classificationLabel;
28+
private final Double predictionProbability;
29+
private final List<TopClassEntry> topClasses;
30+
31+
public NlpClassificationInferenceResults(
32+
String classificationLabel,
33+
List<TopClassEntry> topClasses,
34+
String resultsField,
35+
Double predictionProbability,
36+
boolean isTruncated
37+
) {
38+
super(isTruncated);
39+
this.classificationLabel = Objects.requireNonNull(classificationLabel);
40+
this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses);
41+
this.resultsField = resultsField;
42+
this.predictionProbability = predictionProbability;
43+
}
44+
45+
public NlpClassificationInferenceResults(StreamInput in) throws IOException {
46+
super(in);
47+
this.classificationLabel = in.readString();
48+
this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new));
49+
this.resultsField = in.readString();
50+
this.predictionProbability = in.readOptionalDouble();
51+
}
52+
53+
public String getClassificationLabel() {
54+
return classificationLabel;
55+
}
56+
57+
public List<TopClassEntry> getTopClasses() {
58+
return topClasses;
59+
}
60+
61+
@Override
62+
public void doWriteTo(StreamOutput out) throws IOException {
63+
out.writeString(classificationLabel);
64+
out.writeCollection(topClasses);
65+
out.writeString(resultsField);
66+
out.writeOptionalDouble(predictionProbability);
67+
}
68+
69+
@Override
70+
public boolean equals(Object o) {
71+
if (this == o) return true;
72+
if (o == null || getClass() != o.getClass()) return false;
73+
if (super.equals(o) == false) return false;
74+
NlpClassificationInferenceResults that = (NlpClassificationInferenceResults) o;
75+
return Objects.equals(resultsField, that.resultsField)
76+
&& Objects.equals(classificationLabel, that.classificationLabel)
77+
&& Objects.equals(predictionProbability, that.predictionProbability)
78+
&& Objects.equals(topClasses, that.topClasses);
79+
}
80+
81+
@Override
82+
public int hashCode() {
83+
return Objects.hash(super.hashCode(), resultsField, classificationLabel, predictionProbability, topClasses);
84+
}
85+
86+
public Double getPredictionProbability() {
87+
return predictionProbability;
88+
}
89+
90+
@Override
91+
public String getResultsField() {
92+
return resultsField;
93+
}
94+
95+
@Override
96+
public Object predictedValue() {
97+
return classificationLabel;
98+
}
99+
100+
@Override
101+
void addMapFields(Map<String, Object> map) {
102+
map.put(resultsField, classificationLabel);
103+
if (topClasses.isEmpty() == false) {
104+
map.put(
105+
NlpConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD,
106+
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList())
107+
);
108+
}
109+
if (predictionProbability != null) {
110+
map.put(PREDICTION_PROBABILITY, predictionProbability);
111+
}
112+
}
113+
114+
@Override
115+
public String getWriteableName() {
116+
return NAME;
117+
}
118+
119+
@Override
120+
public void doXContentBody(XContentBuilder builder, Params params) throws IOException {
121+
builder.field(resultsField, classificationLabel);
122+
if (topClasses.size() > 0) {
123+
builder.field(NlpConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD, topClasses);
124+
}
125+
if (predictionProbability != null) {
126+
builder.field(PREDICTION_PROBABILITY, predictionProbability);
127+
}
128+
}
129+
}

0 commit comments

Comments
 (0)