Skip to content

Commit a957866

Browse files
authored
Add minimal sanity checks to custom/scripted similarities. (backport) (#33893)
Add minimal sanity checks to custom/scripted similarities.

Lucene 8 introduced more constraints on similarities, in particular:
- scores must not be negative,
- scores must not decrease when term freq increases,
- scores must not increase when norm (interpreted as an unsigned long) increases.

We can't check every single case, but we can at least run some sanity checks.

Backport of #33564
1 parent 5ddd512 commit a957866

File tree

4 files changed

+310
-9
lines changed

4 files changed

+310
-9
lines changed

server/src/main/java/org/elasticsearch/index/similarity/SimilarityService.java

Lines changed: 254 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,35 @@
1919

2020
package org.elasticsearch.index.similarity;
2121

22+
import org.apache.logging.log4j.LogManager;
23+
import org.apache.lucene.index.BinaryDocValues;
24+
import org.apache.lucene.index.FieldInfos;
25+
import org.apache.lucene.index.FieldInvertState;
26+
import org.apache.lucene.index.Fields;
27+
import org.apache.lucene.index.LeafMetaData;
28+
import org.apache.lucene.index.LeafReader;
29+
import org.apache.lucene.index.NumericDocValues;
30+
import org.apache.lucene.index.PointValues;
31+
import org.apache.lucene.index.SortedDocValues;
32+
import org.apache.lucene.index.SortedNumericDocValues;
33+
import org.apache.lucene.index.SortedSetDocValues;
34+
import org.apache.lucene.index.StoredFieldVisitor;
35+
import org.apache.lucene.index.Terms;
36+
import org.apache.lucene.search.CollectionStatistics;
37+
import org.apache.lucene.search.Explanation;
38+
import org.apache.lucene.search.TermStatistics;
2239
import org.apache.lucene.search.similarities.BM25Similarity;
2340
import org.apache.lucene.search.similarities.BooleanSimilarity;
2441
import org.apache.lucene.search.similarities.ClassicSimilarity;
2542
import org.apache.lucene.search.similarities.PerFieldSimilarityWrapper;
2643
import org.apache.lucene.search.similarities.Similarity;
44+
import org.apache.lucene.search.similarities.Similarity.SimScorer;
45+
import org.apache.lucene.search.similarities.Similarity.SimWeight;
46+
import org.apache.lucene.util.Bits;
47+
import org.apache.lucene.util.BytesRef;
2748
import org.elasticsearch.Version;
2849
import org.elasticsearch.common.TriFunction;
2950
import org.elasticsearch.common.logging.DeprecationLogger;
30-
import org.elasticsearch.common.logging.Loggers;
3151
import org.elasticsearch.common.settings.Settings;
3252
import org.elasticsearch.index.AbstractIndexComponent;
3353
import org.elasticsearch.index.IndexModule;
@@ -36,6 +56,8 @@
3656
import org.elasticsearch.index.mapper.MapperService;
3757
import org.elasticsearch.script.ScriptService;
3858

59+
import java.io.IOException;
60+
import java.io.UncheckedIOException;
3961
import java.util.Collections;
4062
import java.util.HashMap;
4163
import java.util.Map;
@@ -44,7 +66,7 @@
4466

4567
public final class SimilarityService extends AbstractIndexComponent {
4668

47-
private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(Loggers.getLogger(SimilarityService.class));
69+
private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(LogManager.getLogger(SimilarityService.class));
4870
public static final String DEFAULT_SIMILARITY = "BM25";
4971
private static final String CLASSIC_SIMILARITY = "classic";
5072
private static final Map<String, Function<Version, Supplier<Similarity>>> DEFAULTS;
@@ -120,7 +142,8 @@ public SimilarityService(IndexSettings indexSettings, ScriptService scriptServic
120142
}
121143
TriFunction<Settings, Version, ScriptService, Similarity> defaultFactory = BUILT_IN.get(typeName);
122144
TriFunction<Settings, Version, ScriptService, Similarity> factory = similarities.getOrDefault(typeName, defaultFactory);
123-
final Similarity similarity = factory.apply(providerSettings, indexSettings.getIndexVersionCreated(), scriptService);
145+
Similarity similarity = factory.apply(providerSettings, indexSettings.getIndexVersionCreated(), scriptService);
146+
validateSimilarity(indexSettings.getIndexVersionCreated(), similarity);
124147
providers.put(name, () -> similarity);
125148
}
126149
for (Map.Entry<String, Function<Version, Supplier<Similarity>>> entry : DEFAULTS.entrySet()) {
@@ -140,7 +163,7 @@ public Similarity similarity(MapperService mapperService) {
140163
defaultSimilarity;
141164
}
142165

143-
166+
144167
public SimilarityProvider getSimilarity(String name) {
145168
Supplier<Similarity> sim = similarities.get(name);
146169
if (sim == null) {
@@ -171,4 +194,231 @@ public Similarity get(String name) {
171194
return (fieldType != null && fieldType.similarity() != null) ? fieldType.similarity().get() : defaultSimilarity;
172195
}
173196
}
197+
198+
/**
 * Runs the three sanity checks defined in this class against {@code similarity}:
 * scores are not negative, scores do not decrease when term frequency increases,
 * and scores do not increase when the norm increases.
 *
 * Each check reports a violation through {@code DEPRECATION_LOGGER} rather than throwing.
 *
 * @param indexCreatedVersion version the index was created with; passed through to each check
 * @param similarity the similarity to probe
 * @throws UncheckedIOException wrapping any {@link IOException} raised by one of the checks
 */
static void validateSimilarity(Version indexCreatedVersion, Similarity similarity) {
199+
try {
200+
validateScoresArePositive(indexCreatedVersion, similarity);
201+
validateScoresDoNotDecreaseWithFreq(indexCreatedVersion, similarity);
202+
validateScoresDoNotIncreaseWithNorm(indexCreatedVersion, similarity);
203+
} catch (IOException e) {
204+
throw new UncheckedIOException(e);
205+
}
206+
}
207+
208+
/**
 * Minimal {@link LeafReader} stub used by the validation checks in this class.
 *
 * It reports a single document ({@code numDocs() == maxDoc() == 1}) and its
 * {@link #getNormValues(String)} always yields the norm given at construction,
 * whatever field is requested. Most other data accessors throw
 * {@link UnsupportedOperationException}.
 */
private static class SingleNormLeafReader extends LeafReader {
209+
210+
// The norm value returned for every field by getNormValues().
private final long norm;
211+
212+
SingleNormLeafReader(long norm) {
213+
this.norm = norm;
214+
}
215+
216+
@Override
217+
public CacheHelper getCoreCacheHelper() {
218+
return null;
219+
}
220+
221+
@Override
222+
public Terms terms(String field) throws IOException {
223+
throw new UnsupportedOperationException();
224+
}
225+
226+
@Override
227+
public NumericDocValues getNumericDocValues(String field) throws IOException {
228+
throw new UnsupportedOperationException();
229+
}
230+
231+
@Override
232+
public BinaryDocValues getBinaryDocValues(String field) throws IOException {
233+
throw new UnsupportedOperationException();
234+
}
235+
236+
@Override
237+
public SortedDocValues getSortedDocValues(String field) throws IOException {
238+
throw new UnsupportedOperationException();
239+
}
240+
241+
@Override
242+
public SortedNumericDocValues getSortedNumericDocValues(String field) throws IOException {
243+
throw new UnsupportedOperationException();
244+
}
245+
246+
@Override
247+
public SortedSetDocValues getSortedSetDocValues(String field) throws IOException {
248+
throw new UnsupportedOperationException();
249+
}
250+
251+
// Returns an iterator over the single document (doc 0) whose value is always `norm`;
// the `field` argument is not consulted.
@Override
252+
public NumericDocValues getNormValues(String field) throws IOException {
253+
return new NumericDocValues() {
254+
255+
// Current position; -1 until the iterator is first advanced.
int doc = -1;
256+
257+
@Override
258+
public long longValue() throws IOException {
259+
return norm;
260+
}
261+
262+
@Override
263+
public boolean advanceExact(int target) throws IOException {
264+
doc = target;
265+
return true;
266+
}
267+
268+
@Override
269+
public int docID() {
270+
return doc;
271+
}
272+
273+
@Override
274+
public int nextDoc() throws IOException {
275+
return advance(doc + 1);
276+
}
277+
278+
@Override
279+
public int advance(int target) throws IOException {
280+
// Only doc 0 exists; any other target exhausts the iterator.
if (target == 0) {
281+
return doc = 0;
282+
} else {
283+
return doc = NO_MORE_DOCS;
284+
}
285+
}
286+
287+
@Override
288+
public long cost() {
289+
return 1;
290+
}
291+
292+
};
293+
}
294+
295+
@Override
296+
public FieldInfos getFieldInfos() {
297+
throw new UnsupportedOperationException();
298+
}
299+
300+
@Override
301+
public Bits getLiveDocs() {
302+
return null;
303+
}
304+
305+
@Override
306+
public PointValues getPointValues(String field) throws IOException {
307+
throw new UnsupportedOperationException();
308+
}
309+
310+
@Override
311+
public void checkIntegrity() throws IOException {}
312+
313+
@Override
314+
public LeafMetaData getMetaData() {
315+
return new LeafMetaData(
316+
org.apache.lucene.util.Version.LATEST.major,
317+
org.apache.lucene.util.Version.LATEST,
318+
null);
319+
}
320+
321+
@Override
322+
public Fields getTermVectors(int docID) throws IOException {
323+
throw new UnsupportedOperationException();
324+
}
325+
326+
@Override
327+
public int numDocs() {
328+
return 1;
329+
}
330+
331+
@Override
332+
public int maxDoc() {
333+
return 1;
334+
}
335+
336+
@Override
337+
public void document(int docID, StoredFieldVisitor visitor) throws IOException {
338+
throw new UnsupportedOperationException();
339+
}
340+
341+
@Override
342+
protected void doClose() throws IOException {
343+
}
344+
345+
// NOTE(review): unlike getCoreCacheHelper(), which returns null, this one throws — confirm the asymmetry is intended.
@Override
346+
public CacheHelper getReaderCacheHelper() {
347+
throw new UnsupportedOperationException();
348+
}
349+
350+
}
351+
352+
/**
 * Scores document 0 of a {@link SingleNormLeafReader} for term frequencies 1 to 10 using
 * fixed, made-up collection and term statistics, and logs a deprecation warning (with the
 * scorer's explanation) for the first frequency whose score is below zero.
 *
 * Despite the method name, a score of exactly zero is accepted: the test is {@code score < 0}.
 *
 * @param indexCreatedVersion supplies the Lucene major version for the {@link FieldInvertState}
 * @param similarity the similarity to probe
 * @throws IOException if the similarity or scorer throws it
 */
private static void validateScoresArePositive(Version indexCreatedVersion, Similarity similarity) throws IOException {
353+
CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
354+
TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
355+
SimWeight simWeight = similarity.computeWeight(2f, collectionStats, termStats);
356+
FieldInvertState state = new FieldInvertState(indexCreatedVersion.luceneVersion.major,
357+
"some_field", 20, 20, 0, 50); // length = 20, no overlap
358+
final long norm = similarity.computeNorm(state);
359+
LeafReader reader = new SingleNormLeafReader(norm);
360+
SimScorer scorer = similarity.simScorer(simWeight, reader.getContext());
361+
for (int freq = 1; freq <= 10; ++freq) {
362+
float score = scorer.score(0, freq);
363+
if (score < 0) {
364+
DEPRECATION_LOGGER.deprecated("Similarities should not return negative scores:\n" +
365+
scorer.explain(0, Explanation.match(freq, "term freq")));
366+
// Report only the first offending frequency.
break;
367+
}
368+
}
369+
}
370+
371+
/**
 * Scores document 0 of a {@link SingleNormLeafReader} for term frequencies 1 to 10 with a
 * fixed norm, and logs a deprecation warning the first time a score is lower than the score
 * obtained for the previous frequency. The warning contains the explanations for both
 * frequencies.
 *
 * @param indexCreatedVersion supplies the Lucene major version for the {@link FieldInvertState}
 * @param similarity the similarity to probe
 * @throws IOException if the similarity or scorer throws it
 */
private static void validateScoresDoNotDecreaseWithFreq(Version indexCreatedVersion, Similarity similarity) throws IOException {
372+
CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
373+
TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
374+
SimWeight simWeight = similarity.computeWeight(2f, collectionStats, termStats);
375+
FieldInvertState state = new FieldInvertState(indexCreatedVersion.luceneVersion.major,
376+
"some_field", 20, 20, 0, 50); // length = 20, no overlap
377+
final long norm = similarity.computeNorm(state);
378+
LeafReader reader = new SingleNormLeafReader(norm);
379+
SimScorer scorer = similarity.simScorer(simWeight, reader.getContext());
380+
// Starts at -infinity so the first frequency can never be reported as a decrease.
float previousScore = Float.NEGATIVE_INFINITY;
381+
for (int freq = 1; freq <= 10; ++freq) {
382+
float score = scorer.score(0, freq);
383+
if (score < previousScore) {
384+
DEPRECATION_LOGGER.deprecated("Similarity scores should not decrease when term frequency increases:\n" +
385+
scorer.explain(0, Explanation.match(freq - 1, "term freq")) + "\n" +
386+
scorer.explain(0, Explanation.match(freq, "term freq")));
387+
// Report only the first decrease.
break;
388+
}
389+
previousScore = score;
390+
}
391+
}
392+
393+
/**
 * For field lengths 1 to 10, computes the norm, scores document 0 of a
 * {@link SingleNormLeafReader} holding that norm with a term frequency of 1, and logs a
 * deprecation warning the first time the score is higher than the score obtained for the
 * previous length. The warning contains the explanations from both scorers.
 *
 * The check stops silently if a norm, compared as an unsigned long, is smaller than the
 * previous one.
 *
 * @param indexCreatedVersion supplies the Lucene major version for the {@link FieldInvertState}
 * @param similarity the similarity to probe
 * @throws IOException if the similarity or scorer throws it
 */
private static void validateScoresDoNotIncreaseWithNorm(Version indexCreatedVersion, Similarity similarity) throws IOException {
394+
CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
395+
TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
396+
SimWeight simWeight = similarity.computeWeight(2f, collectionStats, termStats);
397+
398+
SimScorer previousScorer = null;
399+
long previousNorm = 0;
400+
// Starts at +infinity so the first iteration can never enter the warning branch;
// previousScorer is therefore always non-null when it is dereferenced below.
float previousScore = Float.POSITIVE_INFINITY;
401+
for (int length = 1; length <= 10; ++length) {
402+
FieldInvertState state = new FieldInvertState(indexCreatedVersion.luceneVersion.major,
403+
"some_field", length, length, 0, 50); // length = loop variable (1..10), no overlap
404+
final long norm = similarity.computeNorm(state);
405+
if (Long.compareUnsigned(previousNorm, norm) > 0) {
406+
// esoteric similarity, skip this check:
// the norm (as an unsigned long) went down although the length went up
407+
break;
408+
}
409+
LeafReader reader = new SingleNormLeafReader(norm);
410+
SimScorer scorer = similarity.simScorer(simWeight, reader.getContext());
411+
float score = scorer.score(0, 1);
412+
if (score > previousScore) {
413+
DEPRECATION_LOGGER.deprecated("Similarity scores should not increase when norm increases:\n" +
414+
previousScorer.explain(0, Explanation.match(1, "term freq")) + "\n" +
415+
scorer.explain(0, Explanation.match(1, "term freq")));
416+
// Report only the first increase.
break;
417+
}
418+
previousScorer = scorer;
419+
previousScore = score;
420+
previousNorm = norm;
421+
}
422+
}
423+
174424
}

server/src/test/java/org/elasticsearch/index/IndexModuleTests.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
import org.elasticsearch.test.engine.MockEngineFactory;
7979
import org.elasticsearch.threadpool.TestThreadPool;
8080
import org.elasticsearch.threadpool.ThreadPool;
81+
import org.hamcrest.Matchers;
8182

8283
import java.io.IOException;
8384
import java.util.Collections;
@@ -297,10 +298,11 @@ public void testAddSimilarity() throws IOException {
297298

298299
IndexService indexService = newIndexService(module);
299300
SimilarityService similarityService = indexService.similarityService();
300-
assertNotNull(similarityService.getSimilarity("my_similarity"));
301-
assertTrue(similarityService.getSimilarity("my_similarity").get() instanceof TestSimilarity);
301+
Similarity similarity = similarityService.getSimilarity("my_similarity").get();
302+
assertNotNull(similarity);
303+
assertThat(similarity, Matchers.instanceOf(TestSimilarity.class));
302304
assertEquals("my_similarity", similarityService.getSimilarity("my_similarity").name());
303-
assertEquals("there is a key", ((TestSimilarity) similarityService.getSimilarity("my_similarity").get()).key);
305+
assertEquals("there is a key", ((TestSimilarity) similarity).key);
304306
indexService.close("simon says", false);
305307
}
306308

server/src/test/java/org/elasticsearch/index/similarity/SimilarityServiceTests.java

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
package org.elasticsearch.index.similarity;
2020

2121
import org.apache.lucene.search.similarities.BM25Similarity;
22+
import org.apache.lucene.search.similarities.BasicStats;
2223
import org.apache.lucene.search.similarities.BooleanSimilarity;
24+
import org.apache.lucene.search.similarities.Similarity;
25+
import org.apache.lucene.search.similarities.SimilarityBase;
26+
import org.elasticsearch.Version;
2327
import org.elasticsearch.common.settings.Settings;
2428
import org.elasticsearch.index.IndexSettings;
2529
import org.elasticsearch.test.ESTestCase;
@@ -56,4 +60,48 @@ public void testOverrideDefaultSimilarity() {
5660
SimilarityService service = new SimilarityService(indexSettings, null, Collections.emptyMap());
5761
assertTrue(service.getDefaultSimilarity() instanceof BooleanSimilarity);
5862
}
63+
64+
// Feeds three deliberately broken similarities to SimilarityService.validateSimilarity and
// asserts that each one produces the corresponding deprecation warning.
public void testSimilarityValidation() {
65+
// Case 1: every score is -1, which must trigger the "negative scores" warning.
Similarity negativeScoresSim = new SimilarityBase() {
66+
@Override
67+
public String toString() {
68+
return "negativeScoresSim";
69+
}
70+
@Override
71+
protected float score(BasicStats stats, float freq, float docLen) {
72+
return -1;
73+
}
74+
};
75+
SimilarityService.validateSimilarity(Version.V_6_5_0, negativeScoresSim);
76+
assertWarnings("Similarities should not return negative scores:\n-1.0 = score(, doc=0, freq=1.0), computed from:\n");
77+
78+
// Case 2: 1 / (freq + docLen) shrinks as freq grows, so scores decrease with term frequency.
Similarity decreasingScoresWithFreqSim = new SimilarityBase() {
79+
@Override
80+
public String toString() {
81+
return "decreasingScoresWithFreqSim";
82+
}
83+
@Override
84+
protected float score(BasicStats stats, float freq, float docLen) {
85+
return 1 / (freq + docLen);
86+
}
87+
};
88+
SimilarityService.validateSimilarity(Version.V_6_5_0, decreasingScoresWithFreqSim);
89+
assertWarnings("Similarity scores should not decrease when term frequency increases:\n0.04761905 = score(, doc=0, freq=1.0), " +
90+
"computed from:\n\n0.045454547 = score(, doc=0, freq=2.0), computed from:\n");
91+
92+
// Case 3: freq + docLen grows with docLen, so scores increase as the document gets longer.
Similarity increasingScoresWithNormSim = new SimilarityBase() {
93+
@Override
94+
public String toString() {
95+
return "increasingScoresWithNormSim";
96+
}
97+
@Override
98+
protected float score(BasicStats stats, float freq, float docLen) {
99+
return freq + docLen;
100+
}
101+
};
102+
SimilarityService.validateSimilarity(Version.V_6_5_0, increasingScoresWithNormSim);
103+
assertWarnings("Similarity scores should not increase when norm increases:\n2.0 = score(, doc=0, freq=1.0), " +
104+
"computed from:\n\n3.0 = score(, doc=0, freq=1.0), computed from:\n");
105+
}
106+
59107
}

0 commit comments

Comments
 (0)