Skip to content

Commit b2b3c9b

Browse files
committed
[ML] Allow NLP truncate option to be updated when span is set (elastic#91224)
1 parent 69cf4d2 commit b2b3c9b

File tree

9 files changed

+407
-189
lines changed

9 files changed

+407
-189
lines changed

docs/changelog/91224.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 91224
2+
summary: Allow NLP truncate option to be updated when span is set
3+
area: Machine Learning
4+
type: bug
5+
issues: []
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
9+
10+
import org.elasticsearch.Version;
11+
import org.elasticsearch.common.io.stream.StreamInput;
12+
import org.elasticsearch.common.io.stream.StreamOutput;
13+
import org.elasticsearch.core.Nullable;
14+
import org.elasticsearch.xcontent.ConstructingObjectParser;
15+
import org.elasticsearch.xcontent.XContentBuilder;
16+
17+
import java.io.IOException;
18+
import java.util.Objects;
19+
20+
public abstract class AbstractTokenizationUpdate implements TokenizationUpdate {
21+
22+
private final Tokenization.Truncate truncate;
23+
private final Integer span;
24+
25+
protected static void declareCommonParserFields(ConstructingObjectParser<? extends AbstractTokenizationUpdate, Void> parser) {
26+
parser.declareString(ConstructingObjectParser.optionalConstructorArg(), Tokenization.TRUNCATE);
27+
parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), Tokenization.SPAN);
28+
}
29+
30+
public AbstractTokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) {
31+
this.truncate = truncate;
32+
this.span = span;
33+
}
34+
35+
public AbstractTokenizationUpdate(StreamInput in) throws IOException {
36+
this.truncate = in.readOptionalEnum(Tokenization.Truncate.class);
37+
if (in.getVersion().onOrAfter(Version.V_8_2_0)) {
38+
this.span = in.readOptionalInt();
39+
} else {
40+
this.span = null;
41+
}
42+
}
43+
44+
@Override
45+
public boolean isNoop() {
46+
return truncate == null && span == null;
47+
}
48+
49+
@Override
50+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
51+
builder.startObject();
52+
if (truncate != null) {
53+
builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
54+
}
55+
if (span != null) {
56+
builder.field(Tokenization.SPAN.getPreferredName(), span);
57+
}
58+
builder.endObject();
59+
return builder;
60+
}
61+
62+
@Override
63+
public void writeTo(StreamOutput out) throws IOException {
64+
out.writeOptionalEnum(truncate);
65+
if (out.getVersion().onOrAfter(Version.V_8_2_0)) {
66+
out.writeOptionalInt(span);
67+
}
68+
}
69+
70+
public Integer getSpan() {
71+
return span;
72+
}
73+
74+
public Tokenization.Truncate getTruncate() {
75+
return truncate;
76+
}
77+
78+
@Override
79+
public boolean equals(Object o) {
80+
if (this == o) return true;
81+
if (o instanceof AbstractTokenizationUpdate == false) {
82+
return false;
83+
}
84+
AbstractTokenizationUpdate that = (AbstractTokenizationUpdate) o;
85+
return Objects.equals(truncate, that.truncate) && Objects.equals(span, that.span);
86+
}
87+
88+
@Override
89+
public int hashCode() {
90+
return Objects.hash(truncate, span);
91+
}
92+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationUpdate.java

Lines changed: 21 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,17 @@
77

88
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
99

10-
import org.elasticsearch.Version;
1110
import org.elasticsearch.common.io.stream.StreamInput;
12-
import org.elasticsearch.common.io.stream.StreamOutput;
1311
import org.elasticsearch.core.Nullable;
1412
import org.elasticsearch.xcontent.ConstructingObjectParser;
1513
import org.elasticsearch.xcontent.ParseField;
16-
import org.elasticsearch.xcontent.XContentBuilder;
1714
import org.elasticsearch.xcontent.XContentParser;
1815
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1916

2017
import java.io.IOException;
21-
import java.util.Objects;
2218
import java.util.Optional;
2319

24-
public class BertTokenizationUpdate implements TokenizationUpdate {
20+
public class BertTokenizationUpdate extends AbstractTokenizationUpdate {
2521

2622
public static final ParseField NAME = BertTokenization.NAME;
2723

@@ -31,29 +27,19 @@ public class BertTokenizationUpdate implements TokenizationUpdate {
3127
);
3228

3329
static {
34-
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), Tokenization.TRUNCATE);
35-
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), Tokenization.SPAN);
30+
declareCommonParserFields(PARSER);
3631
}
3732

3833
public static BertTokenizationUpdate fromXContent(XContentParser parser) {
3934
return PARSER.apply(parser, null);
4035
}
4136

42-
private final Tokenization.Truncate truncate;
43-
private final Integer span;
44-
4537
public BertTokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) {
46-
this.truncate = truncate;
47-
this.span = span;
38+
super(truncate, span);
4839
}
4940

5041
public BertTokenizationUpdate(StreamInput in) throws IOException {
51-
this.truncate = in.readOptionalEnum(Tokenization.Truncate.class);
52-
if (in.getVersion().onOrAfter(Version.V_8_2_0)) {
53-
this.span = in.readOptionalInt();
54-
} else {
55-
this.span = null;
56-
}
42+
super(in);
5743
}
5844

5945
@Override
@@ -66,65 +52,41 @@ public Tokenization apply(Tokenization originalConfig) {
6652
);
6753
}
6854

55+
Tokenization.validateSpanAndTruncate(getTruncate(), getSpan());
56+
6957
if (isNoop()) {
7058
return originalConfig;
7159
}
7260

61+
if (getTruncate() != null && getTruncate().isInCompatibleWithSpan() == false) {
62+
// When truncate value is incompatible with span wipe out
63+
// the existing span setting to avoid an invalid combination of settings.
64+
// This avoids the user having to set span to the special unset value
65+
return new BertTokenization(
66+
originalConfig.doLowerCase(),
67+
originalConfig.withSpecialTokens(),
68+
originalConfig.maxSequenceLength(),
69+
getTruncate(),
70+
null
71+
);
72+
}
73+
7374
return new BertTokenization(
7475
originalConfig.doLowerCase(),
7576
originalConfig.withSpecialTokens(),
7677
originalConfig.maxSequenceLength(),
77-
Optional.ofNullable(this.truncate).orElse(originalConfig.getTruncate()),
78-
Optional.ofNullable(this.span).orElse(originalConfig.getSpan())
78+
Optional.ofNullable(getTruncate()).orElse(originalConfig.getTruncate()),
79+
Optional.ofNullable(getSpan()).orElse(originalConfig.getSpan())
7980
);
8081
}
8182

82-
@Override
83-
public boolean isNoop() {
84-
return truncate == null && span == null;
85-
}
86-
87-
@Override
88-
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
89-
builder.startObject();
90-
if (truncate != null) {
91-
builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
92-
}
93-
if (span != null) {
94-
builder.field(Tokenization.SPAN.getPreferredName(), span);
95-
}
96-
builder.endObject();
97-
return builder;
98-
}
99-
10083
@Override
10184
public String getWriteableName() {
10285
return BertTokenization.NAME.getPreferredName();
10386
}
10487

105-
@Override
106-
public void writeTo(StreamOutput out) throws IOException {
107-
out.writeOptionalEnum(truncate);
108-
if (out.getVersion().onOrAfter(Version.V_8_2_0)) {
109-
out.writeOptionalInt(span);
110-
}
111-
}
112-
11388
@Override
11489
public String getName() {
11590
return BertTokenization.NAME.getPreferredName();
11691
}
117-
118-
@Override
119-
public boolean equals(Object o) {
120-
if (this == o) return true;
121-
if (o == null || getClass() != o.getClass()) return false;
122-
BertTokenizationUpdate that = (BertTokenizationUpdate) o;
123-
return Objects.equals(truncate, that.truncate) && Objects.equals(span, that.span);
124-
}
125-
126-
@Override
127-
public int hashCode() {
128-
return Objects.hash(truncate, span);
129-
}
13092
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationUpdate.java

Lines changed: 19 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,17 @@
77

88
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
99

10-
import org.elasticsearch.Version;
1110
import org.elasticsearch.common.io.stream.StreamInput;
12-
import org.elasticsearch.common.io.stream.StreamOutput;
1311
import org.elasticsearch.core.Nullable;
1412
import org.elasticsearch.xcontent.ConstructingObjectParser;
1513
import org.elasticsearch.xcontent.ParseField;
16-
import org.elasticsearch.xcontent.XContentBuilder;
1714
import org.elasticsearch.xcontent.XContentParser;
1815
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1916

2017
import java.io.IOException;
21-
import java.util.Objects;
2218
import java.util.Optional;
2319

24-
public class MPNetTokenizationUpdate implements TokenizationUpdate {
20+
public class MPNetTokenizationUpdate extends AbstractTokenizationUpdate {
2521

2622
public static final ParseField NAME = MPNetTokenization.NAME;
2723

@@ -31,29 +27,19 @@ public class MPNetTokenizationUpdate implements TokenizationUpdate {
3127
);
3228

3329
static {
34-
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), Tokenization.TRUNCATE);
35-
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), Tokenization.SPAN);
30+
declareCommonParserFields(PARSER);
3631
}
3732

3833
public static MPNetTokenizationUpdate fromXContent(XContentParser parser) {
3934
return PARSER.apply(parser, null);
4035
}
4136

42-
private final Tokenization.Truncate truncate;
43-
private final Integer span;
44-
4537
public MPNetTokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) {
46-
this.truncate = truncate;
47-
this.span = span;
38+
super(truncate, span);
4839
}
4940

5041
public MPNetTokenizationUpdate(StreamInput in) throws IOException {
51-
this.truncate = in.readOptionalEnum(Tokenization.Truncate.class);
52-
if (in.getVersion().onOrAfter(Version.V_8_2_0)) {
53-
this.span = in.readOptionalInt();
54-
} else {
55-
this.span = null;
56-
}
42+
super(in);
5743
}
5844

5945
@Override
@@ -70,61 +56,35 @@ public Tokenization apply(Tokenization originalConfig) {
7056
return originalConfig;
7157
}
7258

59+
if (getTruncate() != null && getTruncate().isInCompatibleWithSpan() == false) {
60+
// When truncate value is incompatible with span wipe out
61+
// the existing span setting to avoid an invalid combination of settings.
62+
// This avoids the user having to set span to the special unset value
63+
return new MPNetTokenization(
64+
originalConfig.doLowerCase(),
65+
originalConfig.withSpecialTokens(),
66+
originalConfig.maxSequenceLength(),
67+
getTruncate(),
68+
null
69+
);
70+
}
71+
7372
return new MPNetTokenization(
7473
originalConfig.doLowerCase(),
7574
originalConfig.withSpecialTokens(),
7675
originalConfig.maxSequenceLength(),
77-
Optional.ofNullable(this.truncate).orElse(originalConfig.getTruncate()),
78-
Optional.ofNullable(this.span).orElse(originalConfig.getSpan())
76+
Optional.ofNullable(this.getTruncate()).orElse(originalConfig.getTruncate()),
77+
Optional.ofNullable(this.getSpan()).orElse(originalConfig.getSpan())
7978
);
8079
}
8180

82-
@Override
83-
public boolean isNoop() {
84-
return truncate == null && span == null;
85-
}
86-
87-
@Override
88-
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
89-
builder.startObject();
90-
if (truncate != null) {
91-
builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
92-
}
93-
if (span != null) {
94-
builder.field(Tokenization.SPAN.getPreferredName(), span);
95-
}
96-
builder.endObject();
97-
return builder;
98-
}
99-
10081
@Override
10182
public String getWriteableName() {
10283
return MPNetTokenization.NAME.getPreferredName();
10384
}
10485

105-
@Override
106-
public void writeTo(StreamOutput out) throws IOException {
107-
out.writeOptionalEnum(truncate);
108-
if (out.getVersion().onOrAfter(Version.V_8_2_0)) {
109-
out.writeOptionalInt(span);
110-
}
111-
}
112-
11386
@Override
11487
public String getName() {
11588
return MPNetTokenization.NAME.getPreferredName();
11689
}
117-
118-
@Override
119-
public boolean equals(Object o) {
120-
if (this == o) return true;
121-
if (o == null || getClass() != o.getClass()) return false;
122-
MPNetTokenizationUpdate that = (MPNetTokenizationUpdate) o;
123-
return Objects.equals(truncate, that.truncate) && Objects.equals(span, that.span);
124-
}
125-
126-
@Override
127-
public int hashCode() {
128-
return Objects.hash(truncate, span);
129-
}
13090
}

0 commit comments

Comments
 (0)