Skip to content

Commit aefdee1

Browse files
committed
Adding builder method to SmoothingModel implementations
Adds a method that emits a WordScorerFactory to all of the three SmoothingModel implementatins that will be needed when we switch to parsing the PhraseSuggestion on the coordinating node and need to delay creating the WordScorer on the shards.
1 parent 513f4e6 commit aefdee1

File tree

8 files changed

+185
-62
lines changed

8 files changed

+185
-62
lines changed

core/src/main/java/org/elasticsearch/search/suggest/phrase/LaplaceScorer.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,22 +27,26 @@
2727
import java.io.IOException;
2828
//TODO public for tests
2929
public final class LaplaceScorer extends WordScorer {
30-
30+
3131
public static final WordScorerFactory FACTORY = new WordScorer.WordScorerFactory() {
3232
@Override
3333
public WordScorer newScorer(IndexReader reader, Terms terms, String field, double realWordLikelyhood, BytesRef separator) throws IOException {
3434
return new LaplaceScorer(reader, terms, field, realWordLikelyhood, separator, 0.5);
3535
}
3636
};
37-
37+
3838
private double alpha;
3939

4040
public LaplaceScorer(IndexReader reader, Terms terms, String field,
4141
double realWordLikelyhood, BytesRef separator, double alpha) throws IOException {
4242
super(reader, terms, field, realWordLikelyhood, separator);
4343
this.alpha = alpha;
4444
}
45-
45+
46+
double alpha() {
47+
return this.alpha;
48+
}
49+
4650
@Override
4751
protected double scoreBigram(Candidate word, Candidate w_1) throws IOException {
4852
SuggestUtils.join(separator, spare, w_1.term, word.term);

core/src/main/java/org/elasticsearch/search/suggest/phrase/LinearInterpoatingScorer.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,19 @@ public LinearInterpoatingScorer(IndexReader reader, Terms terms, String field,
4141
this.bigramLambda = bigramLambda / sum;
4242
this.trigramLambda = trigramLambda / sum;
4343
}
44-
44+
45+
double trigramLambda() {
46+
return this.trigramLambda;
47+
}
48+
49+
double bigramLambda() {
50+
return this.bigramLambda;
51+
}
52+
53+
double unigramLambda() {
54+
return this.unigramLambda;
55+
}
56+
4557
@Override
4658
protected double scoreBigram(Candidate word, Candidate w_1) throws IOException {
4759
SuggestUtils.join(separator, spare, w_1.term, word.term);

core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java

Lines changed: 62 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,12 @@
1818
*/
1919
package org.elasticsearch.search.suggest.phrase;
2020

21+
import org.apache.lucene.index.IndexReader;
22+
import org.apache.lucene.index.Terms;
23+
import org.apache.lucene.util.BytesRef;
2124
import org.elasticsearch.common.ParseField;
2225
import org.elasticsearch.common.ParseFieldMatcher;
26+
import org.elasticsearch.common.ParsingException;
2327
import org.elasticsearch.common.io.stream.NamedWriteable;
2428
import org.elasticsearch.common.io.stream.StreamInput;
2529
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -30,6 +34,7 @@
3034
import org.elasticsearch.index.query.QueryParseContext;
3135
import org.elasticsearch.script.Template;
3236
import org.elasticsearch.search.suggest.SuggestBuilder.SuggestionBuilder;
37+
import org.elasticsearch.search.suggest.phrase.WordScorer.WordScorerFactory;
3338

3439
import java.io.IOException;
3540
import java.util.ArrayList;
@@ -50,7 +55,7 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
5055
private Float confidence;
5156
private final Map<String, List<CandidateGenerator>> generators = new HashMap<>();
5257
private Integer gramSize;
53-
private SmoothingModel<?> model;
58+
private SmoothingModel model;
5459
private Boolean forceUnigrams;
5560
private Integer tokenLimit;
5661
private String preTag;
@@ -159,7 +164,7 @@ public PhraseSuggestionBuilder forceUnigrams(boolean forceUnigrams) {
159164
* Sets an explicit smoothing model used for this suggester. The default is
160165
* {@link PhraseSuggestionBuilder.StupidBackoff}.
161166
*/
162-
public PhraseSuggestionBuilder smoothingModel(SmoothingModel<?> model) {
167+
public PhraseSuggestionBuilder smoothingModel(SmoothingModel model) {
163168
this.model = model;
164169
return this;
165170
}
@@ -292,7 +297,7 @@ public static DirectCandidateGenerator candidateGenerator(String field) {
292297
* Smoothing</a> for details.
293298
* </p>
294299
*/
295-
public static final class StupidBackoff extends SmoothingModel<StupidBackoff> {
300+
public static final class StupidBackoff extends SmoothingModel {
296301
/**
297302
* Default discount parameter for {@link StupidBackoff} smoothing
298303
*/
@@ -341,8 +346,9 @@ public StupidBackoff readFrom(StreamInput in) throws IOException {
341346
}
342347

343348
@Override
344-
protected boolean doEquals(StupidBackoff other) {
345-
return Objects.equals(discount, other.discount);
349+
protected boolean doEquals(SmoothingModel other) {
350+
StupidBackoff otherModel = (StupidBackoff) other;
351+
return Objects.equals(discount, otherModel.discount);
346352
}
347353

348354
@Override
@@ -351,7 +357,7 @@ public final int hashCode() {
351357
}
352358

353359
@Override
354-
public StupidBackoff fromXContent(QueryParseContext parseContext) throws IOException {
360+
public SmoothingModel fromXContent(QueryParseContext parseContext) throws IOException {
355361
XContentParser parser = parseContext.parser();
356362
XContentParser.Token token;
357363
String fieldName = null;
@@ -366,6 +372,12 @@ public StupidBackoff fromXContent(QueryParseContext parseContext) throws IOExcep
366372
}
367373
return new StupidBackoff(discount);
368374
}
375+
376+
@Override
377+
public WordScorerFactory buildWordScorerFactory() {
378+
return (IndexReader reader, Terms terms, String field, double realWordLikelyhood, BytesRef separator)
379+
-> new StupidBackoffScorer(reader, terms, field, realWordLikelyhood, separator, discount);
380+
}
369381
}
370382

371383
/**
@@ -377,7 +389,7 @@ public StupidBackoff fromXContent(QueryParseContext parseContext) throws IOExcep
377389
* Smoothing</a> for details.
378390
* </p>
379391
*/
380-
public static final class Laplace extends SmoothingModel<Laplace> {
392+
public static final class Laplace extends SmoothingModel {
381393
private double alpha = DEFAULT_LAPLACE_ALPHA;
382394
private static final String NAME = "laplace";
383395
private static final ParseField ALPHA_FIELD = new ParseField("alpha");
@@ -419,13 +431,14 @@ public void writeTo(StreamOutput out) throws IOException {
419431
}
420432

421433
@Override
422-
public Laplace readFrom(StreamInput in) throws IOException {
434+
public SmoothingModel readFrom(StreamInput in) throws IOException {
423435
return new Laplace(in.readDouble());
424436
}
425437

426438
@Override
427-
protected boolean doEquals(Laplace other) {
428-
return Objects.equals(alpha, other.alpha);
439+
protected boolean doEquals(SmoothingModel other) {
440+
Laplace otherModel = (Laplace) other;
441+
return Objects.equals(alpha, otherModel.alpha);
429442
}
430443

431444
@Override
@@ -434,7 +447,7 @@ public final int hashCode() {
434447
}
435448

436449
@Override
437-
public Laplace fromXContent(QueryParseContext parseContext) throws IOException {
450+
public SmoothingModel fromXContent(QueryParseContext parseContext) throws IOException {
438451
XContentParser parser = parseContext.parser();
439452
XContentParser.Token token;
440453
String fieldName = null;
@@ -449,10 +462,16 @@ public Laplace fromXContent(QueryParseContext parseContext) throws IOException {
449462
}
450463
return new Laplace(alpha);
451464
}
465+
466+
@Override
467+
public WordScorerFactory buildWordScorerFactory() {
468+
return (IndexReader reader, Terms terms, String field, double realWordLikelyhood, BytesRef separator)
469+
-> new LaplaceScorer(reader, terms, field, realWordLikelyhood, separator, alpha);
470+
}
452471
}
453472

454473

455-
public static abstract class SmoothingModel<SM extends SmoothingModel<?>> implements NamedWriteable<SM>, ToXContent {
474+
public static abstract class SmoothingModel implements NamedWriteable<SmoothingModel>, ToXContent {
456475

457476
@Override
458477
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
@@ -471,16 +490,18 @@ public final boolean equals(Object obj) {
471490
return false;
472491
}
473492
@SuppressWarnings("unchecked")
474-
SM other = (SM) obj;
493+
SmoothingModel other = (SmoothingModel) obj;
475494
return doEquals(other);
476495
}
477496

478-
public abstract SM fromXContent(QueryParseContext parseContext) throws IOException;
497+
public abstract SmoothingModel fromXContent(QueryParseContext parseContext) throws IOException;
498+
499+
public abstract WordScorerFactory buildWordScorerFactory();
479500

480501
/**
481502
* subtype specific implementation of "equals".
482503
*/
483-
protected abstract boolean doEquals(SM other);
504+
protected abstract boolean doEquals(SmoothingModel other);
484505

485506
protected abstract XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException;
486507
}
@@ -493,7 +514,7 @@ public final boolean equals(Object obj) {
493514
* Smoothing</a> for details.
494515
* </p>
495516
*/
496-
public static final class LinearInterpolation extends SmoothingModel<LinearInterpolation> {
517+
public static final class LinearInterpolation extends SmoothingModel {
497518
private static final String NAME = "linear";
498519
static final LinearInterpolation PROTOTYPE = new LinearInterpolation(0.8, 0.1, 0.1);
499520
private final double trigramLambda;
@@ -563,10 +584,11 @@ public LinearInterpolation readFrom(StreamInput in) throws IOException {
563584
}
564585

565586
@Override
566-
protected boolean doEquals(LinearInterpolation other) {
567-
return Objects.equals(trigramLambda, other.trigramLambda) &&
568-
Objects.equals(bigramLambda, other.bigramLambda) &&
569-
Objects.equals(unigramLambda, other.unigramLambda);
587+
protected boolean doEquals(SmoothingModel other) {
588+
final LinearInterpolation otherModel = (LinearInterpolation) other;
589+
return Objects.equals(trigramLambda, otherModel.trigramLambda) &&
590+
Objects.equals(bigramLambda, otherModel.bigramLambda) &&
591+
Objects.equals(unigramLambda, otherModel.unigramLambda);
570592
}
571593

572594
@Override
@@ -579,35 +601,45 @@ public LinearInterpolation fromXContent(QueryParseContext parseContext) throws I
579601
XContentParser parser = parseContext.parser();
580602
XContentParser.Token token;
581603
String fieldName = null;
582-
final double[] lambdas = new double[3];
604+
double trigramLambda = 0.0;
605+
double bigramLambda = 0.0;
606+
double unigramLambda = 0.0;
583607
ParseFieldMatcher matcher = parseContext.parseFieldMatcher();
584608
while ((token = parser.nextToken()) != Token.END_OBJECT) {
585609
if (token == XContentParser.Token.FIELD_NAME) {
586610
fieldName = parser.currentName();
587-
}
588-
if (token.isValue()) {
611+
} else if (token.isValue()) {
589612
if (matcher.match(fieldName, TRIGRAM_FIELD)) {
590-
lambdas[0] = parser.doubleValue();
591-
if (lambdas[0] < 0) {
613+
trigramLambda = parser.doubleValue();
614+
if (trigramLambda < 0) {
592615
throw new IllegalArgumentException("trigram_lambda must be positive");
593616
}
594617
} else if (matcher.match(fieldName, BIGRAM_FIELD)) {
595-
lambdas[1] = parser.doubleValue();
596-
if (lambdas[1] < 0) {
618+
bigramLambda = parser.doubleValue();
619+
if (bigramLambda < 0) {
597620
throw new IllegalArgumentException("bigram_lambda must be positive");
598621
}
599622
} else if (matcher.match(fieldName, UNIGRAM_FIELD)) {
600-
lambdas[2] = parser.doubleValue();
601-
if (lambdas[2] < 0) {
623+
unigramLambda = parser.doubleValue();
624+
if (unigramLambda < 0) {
602625
throw new IllegalArgumentException("unigram_lambda must be positive");
603626
}
604627
} else {
605628
throw new IllegalArgumentException(
606629
"suggester[phrase][smoothing][linear] doesn't support field [" + fieldName + "]");
607630
}
631+
} else {
632+
throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "] after [" + fieldName + "]");
608633
}
609634
}
610-
return new LinearInterpolation(lambdas[0], lambdas[1], lambdas[2]);
635+
return new LinearInterpolation(trigramLambda, bigramLambda, unigramLambda);
636+
}
637+
638+
@Override
639+
public WordScorerFactory buildWordScorerFactory() {
640+
return (IndexReader reader, Terms terms, String field, double realWordLikelyhood, BytesRef separator) ->
641+
new LinearInterpoatingScorer(reader, terms, field, realWordLikelyhood, separator, trigramLambda, bigramLambda,
642+
unigramLambda);
611643
}
612644
}
613645

core/src/main/java/org/elasticsearch/search/suggest/phrase/StupidBackoffScorer.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ public StupidBackoffScorer(IndexReader reader, Terms terms,String field, double
4242
this.discount = discount;
4343
}
4444

45+
double discount() {
46+
return this.discount;
47+
}
48+
4549
@Override
4650
protected double scoreBigram(Candidate word, Candidate w_1) throws IOException {
4751
SuggestUtils.join(separator, spare, w_1.term, word.term);

core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,30 @@
2020
package org.elasticsearch.search.suggest.phrase;
2121

2222
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.Laplace;
23+
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel;
2324

24-
public class LaplaceModelTests extends SmoothingModelTest<Laplace> {
25+
import static org.hamcrest.Matchers.instanceOf;
26+
27+
public class LaplaceModelTests extends SmoothingModelTestCase {
2528

2629
@Override
27-
protected Laplace createTestModel() {
30+
protected SmoothingModel createTestModel() {
2831
return new Laplace(randomDoubleBetween(0.0, 10.0, false));
2932
}
3033

3134
/**
3235
* mutate the given model so the returned smoothing model is different
3336
*/
3437
@Override
35-
protected Laplace createMutation(Laplace original) {
38+
protected Laplace createMutation(SmoothingModel input) {
39+
Laplace original = (Laplace) input;
3640
return new Laplace(original.getAlpha() + 0.1);
3741
}
42+
43+
@Override
44+
void assertWordScorer(WordScorer wordScorer, SmoothingModel input) {
45+
Laplace model = (Laplace) input;
46+
assertThat(wordScorer, instanceOf(LaplaceScorer.class));
47+
assertEquals(model.getAlpha(), ((LaplaceScorer) wordScorer).alpha(), Double.MIN_VALUE);
48+
}
3849
}

core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,18 @@
2020
package org.elasticsearch.search.suggest.phrase;
2121

2222
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.LinearInterpolation;
23+
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel;
2324

24-
public class LinearInterpolationModelTests extends SmoothingModelTest<LinearInterpolation> {
25+
import static org.hamcrest.Matchers.instanceOf;
26+
27+
public class LinearInterpolationModelTests extends SmoothingModelTestCase {
2528

2629
@Override
27-
protected LinearInterpolation createTestModel() {
30+
protected SmoothingModel createTestModel() {
2831
double trigramLambda = randomDoubleBetween(0.0, 10.0, false);
2932
double bigramLambda = randomDoubleBetween(0.0, 10.0, false);
3033
double unigramLambda = randomDoubleBetween(0.0, 10.0, false);
31-
// normalize
34+
// normalize so parameters sum to 1
3235
double sum = trigramLambda + bigramLambda + unigramLambda;
3336
return new LinearInterpolation(trigramLambda / sum, bigramLambda / sum, unigramLambda / sum);
3437
}
@@ -37,7 +40,8 @@ protected LinearInterpolation createTestModel() {
3740
* mutate the given model so the returned smoothing model is different
3841
*/
3942
@Override
40-
protected LinearInterpolation createMutation(LinearInterpolation original) {
43+
protected LinearInterpolation createMutation(SmoothingModel input) {
44+
LinearInterpolation original = (LinearInterpolation) input;
4145
// swap two values permute original lambda values
4246
switch (randomIntBetween(0, 2)) {
4347
case 0:
@@ -52,4 +56,14 @@ protected LinearInterpolation createMutation(LinearInterpolation original) {
5256
return new LinearInterpolation(original.getUnigramLambda(), original.getBigramLambda(), original.getTrigramLambda());
5357
}
5458
}
59+
60+
@Override
61+
void assertWordScorer(WordScorer wordScorer, SmoothingModel in) {
62+
LinearInterpolation testModel = (LinearInterpolation) in;
63+
LinearInterpoatingScorer testScorer = (LinearInterpoatingScorer) wordScorer;
64+
assertThat(wordScorer, instanceOf(LinearInterpoatingScorer.class));
65+
assertEquals(testModel.getTrigramLambda(), (testScorer).trigramLambda(), 1e-15);
66+
assertEquals(testModel.getBigramLambda(), (testScorer).bigramLambda(), 1e-15);
67+
assertEquals(testModel.getUnigramLambda(), (testScorer).unigramLambda(), 1e-15);
68+
}
5569
}

0 commit comments

Comments
 (0)