18
18
*/
19
19
package org .elasticsearch .search .suggest .phrase ;
20
20
21
+ import org .apache .lucene .index .IndexReader ;
22
+ import org .apache .lucene .index .Terms ;
23
+ import org .apache .lucene .util .BytesRef ;
21
24
import org .elasticsearch .common .ParseField ;
22
25
import org .elasticsearch .common .ParseFieldMatcher ;
26
+ import org .elasticsearch .common .ParsingException ;
23
27
import org .elasticsearch .common .io .stream .NamedWriteable ;
24
28
import org .elasticsearch .common .io .stream .StreamInput ;
25
29
import org .elasticsearch .common .io .stream .StreamOutput ;
30
34
import org .elasticsearch .index .query .QueryParseContext ;
31
35
import org .elasticsearch .script .Template ;
32
36
import org .elasticsearch .search .suggest .SuggestBuilder .SuggestionBuilder ;
37
+ import org .elasticsearch .search .suggest .phrase .WordScorer .WordScorerFactory ;
33
38
34
39
import java .io .IOException ;
35
40
import java .util .ArrayList ;
@@ -50,7 +55,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
50
55
private Float confidence ;
51
56
private final Map <String , List <CandidateGenerator >> generators = new HashMap <>();
52
57
private Integer gramSize ;
53
- private SmoothingModel <?> model ;
58
+ private SmoothingModel model ;
54
59
private Boolean forceUnigrams ;
55
60
private Integer tokenLimit ;
56
61
private String preTag ;
@@ -159,7 +164,7 @@ public PhraseSuggestionBuilder forceUnigrams(boolean forceUnigrams) {
159
164
* Sets an explicit smoothing model used for this suggester. The default is
160
165
* {@link PhraseSuggestionBuilder.StupidBackoff}.
161
166
*/
162
- public PhraseSuggestionBuilder smoothingModel (SmoothingModel <?> model ) {
167
+ public PhraseSuggestionBuilder smoothingModel (SmoothingModel model ) {
163
168
this .model = model ;
164
169
return this ;
165
170
}
@@ -292,7 +297,7 @@ public static DirectCandidateGenerator candidateGenerator(String field) {
292
297
* Smoothing</a> for details.
293
298
* </p>
294
299
*/
295
- public static final class StupidBackoff extends SmoothingModel < StupidBackoff > {
300
+ public static final class StupidBackoff extends SmoothingModel {
296
301
/**
297
302
* Default discount parameter for {@link StupidBackoff} smoothing
298
303
*/
@@ -341,8 +346,9 @@ public StupidBackoff readFrom(StreamInput in) throws IOException {
341
346
}
342
347
343
348
@ Override
344
- protected boolean doEquals (StupidBackoff other ) {
345
- return Objects .equals (discount , other .discount );
349
+ protected boolean doEquals (SmoothingModel other ) {
350
+ StupidBackoff otherModel = (StupidBackoff ) other ;
351
+ return Objects .equals (discount , otherModel .discount );
346
352
}
347
353
348
354
@ Override
@@ -351,7 +357,7 @@ public final int hashCode() {
351
357
}
352
358
353
359
@ Override
354
- public StupidBackoff fromXContent (QueryParseContext parseContext ) throws IOException {
360
+ public SmoothingModel fromXContent (QueryParseContext parseContext ) throws IOException {
355
361
XContentParser parser = parseContext .parser ();
356
362
XContentParser .Token token ;
357
363
String fieldName = null ;
@@ -366,6 +372,12 @@ public StupidBackoff fromXContent(QueryParseContext parseContext) throws IOExcep
366
372
}
367
373
return new StupidBackoff (discount );
368
374
}
375
+
376
+ @ Override
377
+ public WordScorerFactory buildWordScorerFactory () {
378
+ return (IndexReader reader , Terms terms , String field , double realWordLikelyhood , BytesRef separator )
379
+ -> new StupidBackoffScorer (reader , terms , field , realWordLikelyhood , separator , discount );
380
+ }
369
381
}
370
382
371
383
/**
@@ -377,7 +389,7 @@ public StupidBackoff fromXContent(QueryParseContext parseContext) throws IOExcep
377
389
* Smoothing</a> for details.
378
390
* </p>
379
391
*/
380
- public static final class Laplace extends SmoothingModel < Laplace > {
392
+ public static final class Laplace extends SmoothingModel {
381
393
private double alpha = DEFAULT_LAPLACE_ALPHA ;
382
394
private static final String NAME = "laplace" ;
383
395
private static final ParseField ALPHA_FIELD = new ParseField ("alpha" );
@@ -419,13 +431,14 @@ public void writeTo(StreamOutput out) throws IOException {
419
431
}
420
432
421
433
@ Override
422
- public Laplace readFrom (StreamInput in ) throws IOException {
434
+ public SmoothingModel readFrom (StreamInput in ) throws IOException {
423
435
return new Laplace (in .readDouble ());
424
436
}
425
437
426
438
@ Override
427
- protected boolean doEquals (Laplace other ) {
428
- return Objects .equals (alpha , other .alpha );
439
+ protected boolean doEquals (SmoothingModel other ) {
440
+ Laplace otherModel = (Laplace ) other ;
441
+ return Objects .equals (alpha , otherModel .alpha );
429
442
}
430
443
431
444
@ Override
@@ -434,7 +447,7 @@ public final int hashCode() {
434
447
}
435
448
436
449
@ Override
437
- public Laplace fromXContent (QueryParseContext parseContext ) throws IOException {
450
+ public SmoothingModel fromXContent (QueryParseContext parseContext ) throws IOException {
438
451
XContentParser parser = parseContext .parser ();
439
452
XContentParser .Token token ;
440
453
String fieldName = null ;
@@ -449,10 +462,16 @@ public Laplace fromXContent(QueryParseContext parseContext) throws IOException {
449
462
}
450
463
return new Laplace (alpha );
451
464
}
465
+
466
+ @ Override
467
+ public WordScorerFactory buildWordScorerFactory () {
468
+ return (IndexReader reader , Terms terms , String field , double realWordLikelyhood , BytesRef separator )
469
+ -> new LaplaceScorer (reader , terms , field , realWordLikelyhood , separator , alpha );
470
+ }
452
471
}
453
472
454
473
455
- public static abstract class SmoothingModel < SM extends SmoothingModel <?>> implements NamedWriteable <SM >, ToXContent {
474
+ public static abstract class SmoothingModel implements NamedWriteable <SmoothingModel >, ToXContent {
456
475
457
476
@ Override
458
477
public XContentBuilder toXContent (XContentBuilder builder , Params params ) throws IOException {
@@ -471,16 +490,18 @@ public final boolean equals(Object obj) {
471
490
return false ;
472
491
}
473
492
@ SuppressWarnings ("unchecked" )
474
- SM other = (SM ) obj ;
493
+ SmoothingModel other = (SmoothingModel ) obj ;
475
494
return doEquals (other );
476
495
}
477
496
478
- public abstract SM fromXContent (QueryParseContext parseContext ) throws IOException ;
497
+ public abstract SmoothingModel fromXContent (QueryParseContext parseContext ) throws IOException ;
498
+
499
+ public abstract WordScorerFactory buildWordScorerFactory ();
479
500
480
501
/**
481
502
* Subtype-specific implementation of {@code equals}.
482
503
*/
483
- protected abstract boolean doEquals (SM other );
504
+ protected abstract boolean doEquals (SmoothingModel other );
484
505
485
506
protected abstract XContentBuilder innerToXContent (XContentBuilder builder , Params params ) throws IOException ;
486
507
}
@@ -493,7 +514,7 @@ public final boolean equals(Object obj) {
493
514
* Smoothing</a> for details.
494
515
* </p>
495
516
*/
496
- public static final class LinearInterpolation extends SmoothingModel < LinearInterpolation > {
517
+ public static final class LinearInterpolation extends SmoothingModel {
497
518
private static final String NAME = "linear" ;
498
519
static final LinearInterpolation PROTOTYPE = new LinearInterpolation (0.8 , 0.1 , 0.1 );
499
520
private final double trigramLambda ;
@@ -563,10 +584,11 @@ public LinearInterpolation readFrom(StreamInput in) throws IOException {
563
584
}
564
585
565
586
@ Override
566
- protected boolean doEquals (LinearInterpolation other ) {
567
- return Objects .equals (trigramLambda , other .trigramLambda ) &&
568
- Objects .equals (bigramLambda , other .bigramLambda ) &&
569
- Objects .equals (unigramLambda , other .unigramLambda );
587
+ protected boolean doEquals (SmoothingModel other ) {
588
+ final LinearInterpolation otherModel = (LinearInterpolation ) other ;
589
+ return Objects .equals (trigramLambda , otherModel .trigramLambda ) &&
590
+ Objects .equals (bigramLambda , otherModel .bigramLambda ) &&
591
+ Objects .equals (unigramLambda , otherModel .unigramLambda );
570
592
}
571
593
572
594
@ Override
@@ -579,35 +601,45 @@ public LinearInterpolation fromXContent(QueryParseContext parseContext) throws I
579
601
XContentParser parser = parseContext .parser ();
580
602
XContentParser .Token token ;
581
603
String fieldName = null ;
582
- final double [] lambdas = new double [3 ];
604
+ double trigramLambda = 0.0 ;
605
+ double bigramLambda = 0.0 ;
606
+ double unigramLambda = 0.0 ;
583
607
ParseFieldMatcher matcher = parseContext .parseFieldMatcher ();
584
608
while ((token = parser .nextToken ()) != Token .END_OBJECT ) {
585
609
if (token == XContentParser .Token .FIELD_NAME ) {
586
610
fieldName = parser .currentName ();
587
- }
588
- if (token .isValue ()) {
611
+ } else if (token .isValue ()) {
589
612
if (matcher .match (fieldName , TRIGRAM_FIELD )) {
590
- lambdas [ 0 ] = parser .doubleValue ();
591
- if (lambdas [ 0 ] < 0 ) {
613
+ trigramLambda = parser .doubleValue ();
614
+ if (trigramLambda < 0 ) {
592
615
throw new IllegalArgumentException ("trigram_lambda must be positive" );
593
616
}
594
617
} else if (matcher .match (fieldName , BIGRAM_FIELD )) {
595
- lambdas [ 1 ] = parser .doubleValue ();
596
- if (lambdas [ 1 ] < 0 ) {
618
+ bigramLambda = parser .doubleValue ();
619
+ if (bigramLambda < 0 ) {
597
620
throw new IllegalArgumentException ("bigram_lambda must be positive" );
598
621
}
599
622
} else if (matcher .match (fieldName , UNIGRAM_FIELD )) {
600
- lambdas [ 2 ] = parser .doubleValue ();
601
- if (lambdas [ 2 ] < 0 ) {
623
+ unigramLambda = parser .doubleValue ();
624
+ if (unigramLambda < 0 ) {
602
625
throw new IllegalArgumentException ("unigram_lambda must be positive" );
603
626
}
604
627
} else {
605
628
throw new IllegalArgumentException (
606
629
"suggester[phrase][smoothing][linear] doesn't support field [" + fieldName + "]" );
607
630
}
631
+ } else {
632
+ throw new ParsingException (parser .getTokenLocation (), "[" + NAME + "] unknown token [" + token + "] after [" + fieldName + "]" );
608
633
}
609
634
}
610
- return new LinearInterpolation (lambdas [0 ], lambdas [1 ], lambdas [2 ]);
635
+ return new LinearInterpolation (trigramLambda , bigramLambda , unigramLambda );
636
+ }
637
+
638
+ @ Override
639
+ public WordScorerFactory buildWordScorerFactory () {
640
+ return (IndexReader reader , Terms terms , String field , double realWordLikelyhood , BytesRef separator ) ->
641
+ new LinearInterpoatingScorer (reader , terms , field , realWordLikelyhood , separator , trigramLambda , bigramLambda ,
642
+ unigramLambda );
611
643
}
612
644
}
613
645
0 commit comments