19
19
20
20
package org .elasticsearch .index .similarity ;
21
21
22
+ import org .apache .logging .log4j .LogManager ;
23
+ import org .apache .lucene .index .FieldInvertState ;
24
+ import org .apache .lucene .index .IndexOptions ;
25
+ import org .apache .lucene .search .CollectionStatistics ;
26
+ import org .apache .lucene .search .Explanation ;
27
+ import org .apache .lucene .search .TermStatistics ;
22
28
import org .apache .lucene .search .similarities .BM25Similarity ;
23
29
import org .apache .lucene .search .similarities .BooleanSimilarity ;
24
30
import org .apache .lucene .search .similarities .ClassicSimilarity ;
25
31
import org .apache .lucene .search .similarities .PerFieldSimilarityWrapper ;
26
32
import org .apache .lucene .search .similarities .Similarity ;
33
+ import org .apache .lucene .search .similarities .Similarity .SimScorer ;
34
+ import org .apache .lucene .util .BytesRef ;
27
35
import org .elasticsearch .Version ;
28
36
import org .elasticsearch .common .TriFunction ;
29
37
import org .elasticsearch .common .logging .DeprecationLogger ;
30
- import org .elasticsearch .common .logging .Loggers ;
31
38
import org .elasticsearch .common .settings .Settings ;
32
39
import org .elasticsearch .index .AbstractIndexComponent ;
33
40
import org .elasticsearch .index .IndexModule ;
44
51
45
52
public final class SimilarityService extends AbstractIndexComponent {
46
53
47
- private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger (Loggers .getLogger (SimilarityService .class ));
54
+ private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger (LogManager .getLogger (SimilarityService .class ));
48
55
public static final String DEFAULT_SIMILARITY = "BM25" ;
49
56
private static final String CLASSIC_SIMILARITY = "classic" ;
50
57
private static final Map <String , Function <Version , Supplier <Similarity >>> DEFAULTS ;
@@ -131,8 +138,14 @@ public SimilarityService(IndexSettings indexSettings, ScriptService scriptServic
131
138
}
132
139
TriFunction <Settings , Version , ScriptService , Similarity > defaultFactory = BUILT_IN .get (typeName );
133
140
TriFunction <Settings , Version , ScriptService , Similarity > factory = similarities .getOrDefault (typeName , defaultFactory );
134
- final Similarity similarity = factory .apply (providerSettings , indexSettings .getIndexVersionCreated (), scriptService );
135
- providers .put (name , () -> similarity );
141
+ Similarity similarity = factory .apply (providerSettings , indexSettings .getIndexVersionCreated (), scriptService );
142
+ validateSimilarity (indexSettings .getIndexVersionCreated (), similarity );
143
+ if (BUILT_IN .containsKey (typeName ) == false || "scripted" .equals (typeName )) {
144
+ // We don't trust custom similarities
145
+ similarity = new NonNegativeScoresSimilarity (similarity );
146
+ }
147
+ final Similarity similarityF = similarity ; // like similarity but final
148
+ providers .put (name , () -> similarityF );
136
149
}
137
150
for (Map .Entry <String , Function <Version , Supplier <Similarity >>> entry : DEFAULTS .entrySet ()) {
138
151
providers .put (entry .getKey (), entry .getValue ().apply (indexSettings .getIndexVersionCreated ()));
@@ -151,7 +164,7 @@ public Similarity similarity(MapperService mapperService) {
151
164
defaultSimilarity ;
152
165
}
153
166
154
-
167
+
155
168
public SimilarityProvider getSimilarity (String name ) {
156
169
Supplier <Similarity > sim = similarities .get (name );
157
170
if (sim == null ) {
@@ -182,4 +195,80 @@ public Similarity get(String name) {
182
195
return (fieldType != null && fieldType .similarity () != null ) ? fieldType .similarity ().get () : defaultSimilarity ;
183
196
}
184
197
}
198
+
199
+ static void validateSimilarity (Version indexCreatedVersion , Similarity similarity ) {
200
+ validateScoresArePositive (indexCreatedVersion , similarity );
201
+ validateScoresDoNotDecreaseWithFreq (indexCreatedVersion , similarity );
202
+ validateScoresDoNotIncreaseWithNorm (indexCreatedVersion , similarity );
203
+ }
204
+
205
+ private static void validateScoresArePositive (Version indexCreatedVersion , Similarity similarity ) {
206
+ CollectionStatistics collectionStats = new CollectionStatistics ("some_field" , 1200 , 1100 , 3000 , 2000 );
207
+ TermStatistics termStats = new TermStatistics (new BytesRef ("some_value" ), 100 , 130 );
208
+ SimScorer scorer = similarity .scorer (2f , collectionStats , termStats );
209
+ FieldInvertState state = new FieldInvertState (indexCreatedVersion .major , "some_field" ,
210
+ IndexOptions .DOCS_AND_FREQS , 20 , 20 , 0 , 50 , 10 , 3 ); // length = 20, no overlap
211
+ final long norm = similarity .computeNorm (state );
212
+ for (int freq = 1 ; freq <= 10 ; ++freq ) {
213
+ float score = scorer .score (freq , norm );
214
+ if (score < 0 ) {
215
+ fail (indexCreatedVersion , "Similarities should not return negative scores:\n " +
216
+ scorer .explain (Explanation .match (freq , "term freq" ), norm ));
217
+ }
218
+ }
219
+ }
220
+
221
+ private static void validateScoresDoNotDecreaseWithFreq (Version indexCreatedVersion , Similarity similarity ) {
222
+ CollectionStatistics collectionStats = new CollectionStatistics ("some_field" , 1200 , 1100 , 3000 , 2000 );
223
+ TermStatistics termStats = new TermStatistics (new BytesRef ("some_value" ), 100 , 130 );
224
+ SimScorer scorer = similarity .scorer (2f , collectionStats , termStats );
225
+ FieldInvertState state = new FieldInvertState (indexCreatedVersion .major , "some_field" ,
226
+ IndexOptions .DOCS_AND_FREQS , 20 , 20 , 0 , 50 , 10 , 3 ); // length = 20, no overlap
227
+ final long norm = similarity .computeNorm (state );
228
+ float previousScore = 0 ;
229
+ for (int freq = 1 ; freq <= 10 ; ++freq ) {
230
+ float score = scorer .score (freq , norm );
231
+ if (score < previousScore ) {
232
+ fail (indexCreatedVersion , "Similarity scores should not decrease when term frequency increases:\n " +
233
+ scorer .explain (Explanation .match (freq - 1 , "term freq" ), norm ) + "\n " +
234
+ scorer .explain (Explanation .match (freq , "term freq" ), norm ));
235
+ }
236
+ previousScore = score ;
237
+ }
238
+ }
239
+
240
+ private static void validateScoresDoNotIncreaseWithNorm (Version indexCreatedVersion , Similarity similarity ) {
241
+ CollectionStatistics collectionStats = new CollectionStatistics ("some_field" , 1200 , 1100 , 3000 , 2000 );
242
+ TermStatistics termStats = new TermStatistics (new BytesRef ("some_value" ), 100 , 130 );
243
+ SimScorer scorer = similarity .scorer (2f , collectionStats , termStats );
244
+
245
+ long previousNorm = 0 ;
246
+ float previousScore = Float .MAX_VALUE ;
247
+ for (int length = 1 ; length <= 10 ; ++length ) {
248
+ FieldInvertState state = new FieldInvertState (indexCreatedVersion .major , "some_field" ,
249
+ IndexOptions .DOCS_AND_FREQS , length , length , 0 , 50 , 10 , 3 ); // length = 20, no overlap
250
+ final long norm = similarity .computeNorm (state );
251
+ if (Long .compareUnsigned (previousNorm , norm ) > 0 ) {
252
+ // esoteric similarity, skip this check
253
+ break ;
254
+ }
255
+ float score = scorer .score (1 , norm );
256
+ if (score > previousScore ) {
257
+ fail (indexCreatedVersion , "Similarity scores should not increase when norm increases:\n " +
258
+ scorer .explain (Explanation .match (1 , "term freq" ), norm - 1 ) + "\n " +
259
+ scorer .explain (Explanation .match (1 , "term freq" ), norm ));
260
+ }
261
+ previousScore = score ;
262
+ previousNorm = norm ;
263
+ }
264
+ }
265
+
266
+ private static void fail (Version indexCreatedVersion , String message ) {
267
+ if (indexCreatedVersion .onOrAfter (Version .V_7_0_0_alpha1 )) {
268
+ throw new IllegalArgumentException (message );
269
+ } else if (indexCreatedVersion .onOrAfter (Version .V_6_5_0 )) {
270
+ DEPRECATION_LOGGER .deprecated (message );
271
+ }
272
+ }
273
+
185
274
}
0 commit comments