Commit b9aebc8
[ML] fix NLP tokenization never_split handling around punctuation (#82982)
When multiple characters in a row might be part of a never_split token, we erroneously tokenized them. This commit handles that scenario, so `[[UNK]` is now tokenized as `[`, `[UNK]`.
1 parent 002f506 commit b9aebc8

File tree

2 files changed: 12 additions & 1 deletion


x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizer.java

Lines changed: 7 additions & 1 deletion
@@ -128,7 +128,13 @@ private List<DelimitedToken> mergeNeverSplitTokens(String originalText, List<Del
                     matchingTokens = new ArrayList<>();
                     current = neverSplitTokenTrieRoot;
                 }
-                mergedTokens.add(token);
+                childNode = current.getChild(token.getToken());
+                if (childNode == null) {
+                    mergedTokens.add(token);
+                } else {
+                    matchingTokens.add(token);
+                    current = childNode;
+                }
             } else if (childNode.isLeaf()) {
                 matchingTokens.add(token);
                 DelimitedToken mergedToken = DelimitedToken.mergeTokens(matchingTokens);
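
The shape of the fix: when the trie walk fails on a token, the merge loop now re-checks that token against the trie root instead of emitting it unconditionally, so after a stray `[` a second `[` can still begin a fresh `[UNK]` match. Below is a minimal, self-contained sketch of that merge strategy, assuming a simplified trie keyed on token strings and plain String tokens; the names loosely mirror BasicTokenizer, but this is an illustration, not the actual implementation (which works on DelimitedToken objects and merges offsets as well).

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Sketch: merge pre-split tokens back into never_split entries via a trie.
class NeverSplitMergeSketch {

    static class TrieNode {
        final Map<String, TrieNode> children = new HashMap<>();
        boolean leaf;

        TrieNode getChild(String token) {
            return children.get(token);
        }
    }

    // Build a trie where each never_split entry is a path of sub-tokens,
    // e.g. "[UNK]" pre-split into "[", "UNK", "]".
    static TrieNode buildTrie(List<List<String>> neverSplitEntries) {
        TrieNode root = new TrieNode();
        for (List<String> entry : neverSplitEntries) {
            TrieNode node = root;
            for (String part : entry) {
                node = node.children.computeIfAbsent(part, k -> new TrieNode());
            }
            node.leaf = true;
        }
        return root;
    }

    static List<String> merge(List<String> tokens, TrieNode root) {
        List<String> merged = new ArrayList<>();
        List<String> matching = new ArrayList<>();
        TrieNode current = root;
        for (String token : tokens) {
            TrieNode child = current.getChild(token);
            if (child == null) {
                // Partial match failed: flush it as ordinary tokens...
                merged.addAll(matching);
                matching.clear();
                current = root;
                // ...then re-check THIS token from the root. This is the bug
                // the commit fixes: emitting the token unconditionally here
                // meant the second "[" in "[[UNK]" could never start a match.
                child = current.getChild(token);
                if (child == null) {
                    merged.add(token);
                } else if (child.leaf) {
                    merged.add(token); // entry is this single token
                } else {
                    matching.add(token);
                    current = child;
                }
            } else if (child.leaf) {
                // Full never_split entry matched: emit it as one token.
                matching.add(token);
                merged.add(String.join("", matching));
                matching.clear();
                current = root;
            } else {
                matching.add(token);
                current = child;
            }
        }
        merged.addAll(matching); // a trailing partial match stays unmerged
        return merged;
    }

    public static void main(String[] args) {
        TrieNode root = buildTrie(List.of(List.of("[", "UNK", "]")));
        // "Hello~[[UNK]" basic-tokenizes to: Hello, ~, [, [, UNK, ]
        System.out.println(merge(List.of("Hello", "~", "[", "[", "UNK", "]"), root));
        // prints: [Hello, ~, [, [UNK]]
    }
}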

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenizerTests.java

Lines changed: 5 additions & 0 deletions
@@ -79,6 +79,11 @@ public void testNeverSplit_GivenNoLowerCase() {
         assertThat(tokenStrings(tokens), contains("Hello", "-", "[UNK]"));
         tokens = tokenizer.tokenize("Hello~[UNK][UNK]");
         assertThat(tokenStrings(tokens), contains("Hello", "~", "[UNK]", "[UNK]"));
+        assertThat(tokenStrings(tokenizer.tokenize("Hello~[[UNK]")), contains("Hello", "~", "[", "[UNK]"));
+        assertThat(tokenStrings(tokenizer.tokenize("Hello~[[[UNK]")), contains("Hello", "~", "[", "[", "[UNK]"));
+        assertThat(tokenStrings(tokenizer.tokenize("Hello~[UNK]]")), contains("Hello", "~", "[UNK]", "]"));
+        assertThat(tokenStrings(tokenizer.tokenize("Hello~[UNK]]]")), contains("Hello", "~", "[UNK]", "]", "]"));
+        assertThat(tokenStrings(tokenizer.tokenize("Hello~[[UNK]]")), contains("Hello", "~", "[", "[UNK]", "]"));
         tokens = tokenizer.tokenize("Hello-[unk]");
         assertThat(tokenStrings(tokens), contains("Hello", "-", "[", "unk", "]"));
     }
