Skip to content

Commit c4261ba

Browse files
authored
Add minimal sanity checks to custom/scripted similarities. (#33564)
Add minimal sanity checks to custom/scripted similarities.

Lucene 8 introduced more constraints on similarities, in particular:
- scores must not be negative,
- scores must not decrease when term freq increases,
- scores must not increase when norm (interpreted as an unsigned long) increases.

We can't check every single case, but could at least run some sanity checks.

Relates #33309
1 parent 7f473b6 commit c4261ba

File tree

6 files changed

+339
-10
lines changed

6 files changed

+339
-10
lines changed
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.elasticsearch.index.similarity;
21+
22+
import org.apache.lucene.index.FieldInvertState;
23+
import org.apache.lucene.search.CollectionStatistics;
24+
import org.apache.lucene.search.Explanation;
25+
import org.apache.lucene.search.TermStatistics;
26+
import org.apache.lucene.search.similarities.Similarity;
27+
28+
/**
29+
* A {@link Similarity} that rejects negative scores. This class exists so that users get
30+
* an error instead of silently corrupt top hits. It should be applied to any custom or
31+
* scripted similarity.
32+
*/
33+
// public for testing
34+
public final class NonNegativeScoresSimilarity extends Similarity {
35+
36+
// Escape hatch
37+
private static final String ES_ENFORCE_POSITIVE_SCORES = "es.enforce.positive.scores";
38+
private static final boolean ENFORCE_POSITIVE_SCORES;
39+
static {
40+
String enforcePositiveScores = System.getProperty(ES_ENFORCE_POSITIVE_SCORES);
41+
if (enforcePositiveScores == null) {
42+
ENFORCE_POSITIVE_SCORES = true;
43+
} else if ("false".equals(enforcePositiveScores)) {
44+
ENFORCE_POSITIVE_SCORES = false;
45+
} else {
46+
throw new IllegalArgumentException(ES_ENFORCE_POSITIVE_SCORES + " may only be unset or set to [false], but got [" +
47+
enforcePositiveScores + "]");
48+
}
49+
}
50+
51+
private final Similarity in;
52+
53+
public NonNegativeScoresSimilarity(Similarity in) {
54+
this.in = in;
55+
}
56+
57+
public Similarity getDelegate() {
58+
return in;
59+
}
60+
61+
@Override
62+
public long computeNorm(FieldInvertState state) {
63+
return in.computeNorm(state);
64+
}
65+
66+
@Override
67+
public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
68+
final SimScorer inScorer = in.scorer(boost, collectionStats, termStats);
69+
return new SimScorer() {
70+
71+
@Override
72+
public float score(float freq, long norm) {
73+
float score = inScorer.score(freq, norm);
74+
if (score < 0f) {
75+
if (ENFORCE_POSITIVE_SCORES) {
76+
throw new IllegalArgumentException("Similarities must not produce negative scores, but got:\n" +
77+
inScorer.explain(Explanation.match(freq, "term frequency"), norm));
78+
} else {
79+
return 0f;
80+
}
81+
}
82+
return score;
83+
}
84+
85+
@Override
86+
public Explanation explain(Explanation freq, long norm) {
87+
Explanation expl = inScorer.explain(freq, norm);
88+
if (expl.isMatch() && expl.getValue().floatValue() < 0) {
89+
expl = Explanation.match(0f, "max of:",
90+
expl, Explanation.match(0f, "Minimum allowed score"));
91+
}
92+
return expl;
93+
}
94+
};
95+
}
96+
}

server/src/main/java/org/elasticsearch/index/similarity/SimilarityService.java

Lines changed: 94 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,22 @@
1919

2020
package org.elasticsearch.index.similarity;
2121

22+
import org.apache.logging.log4j.LogManager;
23+
import org.apache.lucene.index.FieldInvertState;
24+
import org.apache.lucene.index.IndexOptions;
25+
import org.apache.lucene.search.CollectionStatistics;
26+
import org.apache.lucene.search.Explanation;
27+
import org.apache.lucene.search.TermStatistics;
2228
import org.apache.lucene.search.similarities.BM25Similarity;
2329
import org.apache.lucene.search.similarities.BooleanSimilarity;
2430
import org.apache.lucene.search.similarities.ClassicSimilarity;
2531
import org.apache.lucene.search.similarities.PerFieldSimilarityWrapper;
2632
import org.apache.lucene.search.similarities.Similarity;
33+
import org.apache.lucene.search.similarities.Similarity.SimScorer;
34+
import org.apache.lucene.util.BytesRef;
2735
import org.elasticsearch.Version;
2836
import org.elasticsearch.common.TriFunction;
2937
import org.elasticsearch.common.logging.DeprecationLogger;
30-
import org.elasticsearch.common.logging.Loggers;
3138
import org.elasticsearch.common.settings.Settings;
3239
import org.elasticsearch.index.AbstractIndexComponent;
3340
import org.elasticsearch.index.IndexModule;
@@ -44,7 +51,7 @@
4451

4552
public final class SimilarityService extends AbstractIndexComponent {
4653

47-
private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(Loggers.getLogger(SimilarityService.class));
54+
private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(LogManager.getLogger(SimilarityService.class));
4855
public static final String DEFAULT_SIMILARITY = "BM25";
4956
private static final String CLASSIC_SIMILARITY = "classic";
5057
private static final Map<String, Function<Version, Supplier<Similarity>>> DEFAULTS;
@@ -131,8 +138,14 @@ public SimilarityService(IndexSettings indexSettings, ScriptService scriptServic
131138
}
132139
TriFunction<Settings, Version, ScriptService, Similarity> defaultFactory = BUILT_IN.get(typeName);
133140
TriFunction<Settings, Version, ScriptService, Similarity> factory = similarities.getOrDefault(typeName, defaultFactory);
134-
final Similarity similarity = factory.apply(providerSettings, indexSettings.getIndexVersionCreated(), scriptService);
135-
providers.put(name, () -> similarity);
141+
Similarity similarity = factory.apply(providerSettings, indexSettings.getIndexVersionCreated(), scriptService);
142+
validateSimilarity(indexSettings.getIndexVersionCreated(), similarity);
143+
if (BUILT_IN.containsKey(typeName) == false || "scripted".equals(typeName)) {
144+
// We don't trust custom similarities
145+
similarity = new NonNegativeScoresSimilarity(similarity);
146+
}
147+
final Similarity similarityF = similarity; // like similarity but final
148+
providers.put(name, () -> similarityF);
136149
}
137150
for (Map.Entry<String, Function<Version, Supplier<Similarity>>> entry : DEFAULTS.entrySet()) {
138151
providers.put(entry.getKey(), entry.getValue().apply(indexSettings.getIndexVersionCreated()));
@@ -151,7 +164,7 @@ public Similarity similarity(MapperService mapperService) {
151164
defaultSimilarity;
152165
}
153166

154-
167+
155168
public SimilarityProvider getSimilarity(String name) {
156169
Supplier<Similarity> sim = similarities.get(name);
157170
if (sim == null) {
@@ -182,4 +195,80 @@ public Similarity get(String name) {
182195
return (fieldType != null && fieldType.similarity() != null) ? fieldType.similarity().get() : defaultSimilarity;
183196
}
184197
}
198+
199+
static void validateSimilarity(Version indexCreatedVersion, Similarity similarity) {
200+
validateScoresArePositive(indexCreatedVersion, similarity);
201+
validateScoresDoNotDecreaseWithFreq(indexCreatedVersion, similarity);
202+
validateScoresDoNotIncreaseWithNorm(indexCreatedVersion, similarity);
203+
}
204+
205+
private static void validateScoresArePositive(Version indexCreatedVersion, Similarity similarity) {
206+
CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
207+
TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
208+
SimScorer scorer = similarity.scorer(2f, collectionStats, termStats);
209+
FieldInvertState state = new FieldInvertState(indexCreatedVersion.major, "some_field",
210+
IndexOptions.DOCS_AND_FREQS, 20, 20, 0, 50, 10, 3); // length = 20, no overlap
211+
final long norm = similarity.computeNorm(state);
212+
for (int freq = 1; freq <= 10; ++freq) {
213+
float score = scorer.score(freq, norm);
214+
if (score < 0) {
215+
fail(indexCreatedVersion, "Similarities should not return negative scores:\n" +
216+
scorer.explain(Explanation.match(freq, "term freq"), norm));
217+
}
218+
}
219+
}
220+
221+
private static void validateScoresDoNotDecreaseWithFreq(Version indexCreatedVersion, Similarity similarity) {
222+
CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
223+
TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
224+
SimScorer scorer = similarity.scorer(2f, collectionStats, termStats);
225+
FieldInvertState state = new FieldInvertState(indexCreatedVersion.major, "some_field",
226+
IndexOptions.DOCS_AND_FREQS, 20, 20, 0, 50, 10, 3); // length = 20, no overlap
227+
final long norm = similarity.computeNorm(state);
228+
float previousScore = 0;
229+
for (int freq = 1; freq <= 10; ++freq) {
230+
float score = scorer.score(freq, norm);
231+
if (score < previousScore) {
232+
fail(indexCreatedVersion, "Similarity scores should not decrease when term frequency increases:\n" +
233+
scorer.explain(Explanation.match(freq - 1, "term freq"), norm) + "\n" +
234+
scorer.explain(Explanation.match(freq, "term freq"), norm));
235+
}
236+
previousScore = score;
237+
}
238+
}
239+
240+
private static void validateScoresDoNotIncreaseWithNorm(Version indexCreatedVersion, Similarity similarity) {
241+
CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
242+
TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
243+
SimScorer scorer = similarity.scorer(2f, collectionStats, termStats);
244+
245+
long previousNorm = 0;
246+
float previousScore = Float.MAX_VALUE;
247+
for (int length = 1; length <= 10; ++length) {
248+
FieldInvertState state = new FieldInvertState(indexCreatedVersion.major, "some_field",
249+
IndexOptions.DOCS_AND_FREQS, length, length, 0, 50, 10, 3); // length = 20, no overlap
250+
final long norm = similarity.computeNorm(state);
251+
if (Long.compareUnsigned(previousNorm, norm) > 0) {
252+
// esoteric similarity, skip this check
253+
break;
254+
}
255+
float score = scorer.score(1, norm);
256+
if (score > previousScore) {
257+
fail(indexCreatedVersion, "Similarity scores should not increase when norm increases:\n" +
258+
scorer.explain(Explanation.match(1, "term freq"), norm - 1) + "\n" +
259+
scorer.explain(Explanation.match(1, "term freq"), norm));
260+
}
261+
previousScore = score;
262+
previousNorm = norm;
263+
}
264+
}
265+
266+
private static void fail(Version indexCreatedVersion, String message) {
267+
if (indexCreatedVersion.onOrAfter(Version.V_7_0_0_alpha1)) {
268+
throw new IllegalArgumentException(message);
269+
} else if (indexCreatedVersion.onOrAfter(Version.V_6_5_0)) {
270+
DEPRECATION_LOGGER.deprecated(message);
271+
}
272+
}
273+
185274
}

server/src/test/java/org/elasticsearch/index/IndexModuleTests.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
import org.elasticsearch.index.shard.IndexingOperationListener;
6060
import org.elasticsearch.index.shard.SearchOperationListener;
6161
import org.elasticsearch.index.shard.ShardId;
62+
import org.elasticsearch.index.similarity.NonNegativeScoresSimilarity;
6263
import org.elasticsearch.index.similarity.SimilarityService;
6364
import org.elasticsearch.index.store.IndexStore;
6465
import org.elasticsearch.indices.IndicesModule;
@@ -77,6 +78,7 @@
7778
import org.elasticsearch.test.engine.MockEngineFactory;
7879
import org.elasticsearch.threadpool.TestThreadPool;
7980
import org.elasticsearch.threadpool.ThreadPool;
81+
import org.hamcrest.Matchers;
8082

8183
import java.io.IOException;
8284
import java.util.Collections;
@@ -295,10 +297,13 @@ public void testAddSimilarity() throws IOException {
295297

296298
IndexService indexService = newIndexService(module);
297299
SimilarityService similarityService = indexService.similarityService();
298-
assertNotNull(similarityService.getSimilarity("my_similarity"));
299-
assertTrue(similarityService.getSimilarity("my_similarity").get() instanceof TestSimilarity);
300+
Similarity similarity = similarityService.getSimilarity("my_similarity").get();
301+
assertNotNull(similarity);
302+
assertThat(similarity, Matchers.instanceOf(NonNegativeScoresSimilarity.class));
303+
similarity = ((NonNegativeScoresSimilarity) similarity).getDelegate();
304+
assertThat(similarity, Matchers.instanceOf(TestSimilarity.class));
300305
assertEquals("my_similarity", similarityService.getSimilarity("my_similarity").name());
301-
assertEquals("there is a key", ((TestSimilarity) similarityService.getSimilarity("my_similarity").get()).key);
306+
assertEquals("there is a key", ((TestSimilarity) similarity).key);
302307
indexService.close("simon says", false);
303308
}
304309

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.elasticsearch.index.similarity;
21+
22+
import org.apache.lucene.index.FieldInvertState;
23+
import org.apache.lucene.search.CollectionStatistics;
24+
import org.apache.lucene.search.TermStatistics;
25+
import org.apache.lucene.search.similarities.Similarity;
26+
import org.apache.lucene.search.similarities.Similarity.SimScorer;
27+
import org.elasticsearch.test.ESTestCase;
28+
import org.hamcrest.Matchers;
29+
30+
public class NonNegativeScoresSimilarityTests extends ESTestCase {
31+
32+
public void testBasics() {
33+
Similarity negativeScoresSim = new Similarity() {
34+
35+
@Override
36+
public long computeNorm(FieldInvertState state) {
37+
return state.getLength();
38+
}
39+
40+
@Override
41+
public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
42+
return new SimScorer() {
43+
@Override
44+
public float score(float freq, long norm) {
45+
return freq - 5;
46+
}
47+
};
48+
}
49+
};
50+
Similarity assertingSimilarity = new NonNegativeScoresSimilarity(negativeScoresSim);
51+
SimScorer scorer = assertingSimilarity.scorer(1f, null);
52+
assertEquals(2f, scorer.score(7f, 1L), 0f);
53+
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> scorer.score(2f, 1L));
54+
assertThat(e.getMessage(), Matchers.containsString("Similarities must not produce negative scores"));
55+
}
56+
57+
}

0 commit comments

Comments
 (0)