diff --git a/core/src/main/java/org/elasticsearch/search/suggest/phrase/LaplaceScorer.java b/core/src/main/java/org/elasticsearch/search/suggest/phrase/LaplaceScorer.java index 04d98c3827d91..678f3082bac9a 100644 --- a/core/src/main/java/org/elasticsearch/search/suggest/phrase/LaplaceScorer.java +++ b/core/src/main/java/org/elasticsearch/search/suggest/phrase/LaplaceScorer.java @@ -27,14 +27,14 @@ import java.io.IOException; //TODO public for tests public final class LaplaceScorer extends WordScorer { - + public static final WordScorerFactory FACTORY = new WordScorer.WordScorerFactory() { @Override public WordScorer newScorer(IndexReader reader, Terms terms, String field, double realWordLikelyhood, BytesRef separator) throws IOException { return new LaplaceScorer(reader, terms, field, realWordLikelyhood, separator, 0.5); } }; - + private double alpha; public LaplaceScorer(IndexReader reader, Terms terms, String field, @@ -42,7 +42,11 @@ public LaplaceScorer(IndexReader reader, Terms terms, String field, super(reader, terms, field, realWordLikelyhood, separator); this.alpha = alpha; } - + + double alpha() { + return this.alpha; + } + @Override protected double scoreBigram(Candidate word, Candidate w_1) throws IOException { SuggestUtils.join(separator, spare, w_1.term, word.term); diff --git a/core/src/main/java/org/elasticsearch/search/suggest/phrase/LinearInterpoatingScorer.java b/core/src/main/java/org/elasticsearch/search/suggest/phrase/LinearInterpoatingScorer.java index d2b1ba48b1360..368d461fc5334 100644 --- a/core/src/main/java/org/elasticsearch/search/suggest/phrase/LinearInterpoatingScorer.java +++ b/core/src/main/java/org/elasticsearch/search/suggest/phrase/LinearInterpoatingScorer.java @@ -41,7 +41,19 @@ public LinearInterpoatingScorer(IndexReader reader, Terms terms, String field, this.bigramLambda = bigramLambda / sum; this.trigramLambda = trigramLambda / sum; } - + + double trigramLambda() { + return this.trigramLambda; + } + + double bigramLambda() { 
+ return this.bigramLambda; + } + + double unigramLambda() { + return this.unigramLambda; + } + @Override protected double scoreBigram(Candidate word, Candidate w_1) throws IOException { SuggestUtils.join(separator, spare, w_1.term, word.term); diff --git a/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestParser.java b/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestParser.java index 0b904a95720d2..c226d061047d1 100644 --- a/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestParser.java +++ b/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestParser.java @@ -36,6 +36,8 @@ import org.elasticsearch.search.suggest.SuggestContextParser; import org.elasticsearch.search.suggest.SuggestUtils; import org.elasticsearch.search.suggest.SuggestionSearchContext; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.Laplace; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.StupidBackoff; import org.elasticsearch.search.suggest.phrase.PhraseSuggestionContext.DirectCandidateGenerator; import java.io.IOException; @@ -265,7 +267,7 @@ public WordScorer newScorer(IndexReader reader, Terms terms, String field, doubl }); } else if ("laplace".equals(fieldName)) { ensureNoSmoothing(suggestion); - double theAlpha = 0.5; + double theAlpha = Laplace.DEFAULT_LAPLACE_ALPHA; while ((token = parser.nextToken()) != Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { @@ -286,7 +288,7 @@ public WordScorer newScorer(IndexReader reader, Terms terms, String field, doubl } else if ("stupid_backoff".equals(fieldName) || "stupidBackoff".equals(fieldName)) { ensureNoSmoothing(suggestion); - double theDiscount = 0.4; + double theDiscount = StupidBackoff.DEFAULT_BACKOFF_DISCOUNT; while ((token = parser.nextToken()) != Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { fieldName = parser.currentName(); diff --git 
a/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java b/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java index 1055fbe83fce8..0e1fec6c7b281 100644 --- a/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java +++ b/core/src/main/java/org/elasticsearch/search/suggest/phrase/PhraseSuggestionBuilder.java @@ -18,10 +18,23 @@ */ package org.elasticsearch.search.suggest.phrase; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.Terms; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.ParseFieldMatcher; +import org.elasticsearch.common.ParsingException; +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentParser.Token; +import org.elasticsearch.index.query.QueryParseContext; import org.elasticsearch.script.Template; import org.elasticsearch.search.suggest.SuggestBuilder.SuggestionBuilder; +import org.elasticsearch.search.suggest.phrase.WordScorer.WordScorerFactory; import java.io.IOException; import java.util.ArrayList; @@ -29,6 +42,7 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Objects; import java.util.Set; /** @@ -284,7 +298,14 @@ public static DirectCandidateGenerator candidateGenerator(String field) { *

*/ public static final class StupidBackoff extends SmoothingModel { - private final double discount; + /** + * Default discount parameter for {@link StupidBackoff} smoothing + */ + public static final double DEFAULT_BACKOFF_DISCOUNT = 0.4; + private double discount = DEFAULT_BACKOFF_DISCOUNT; + static final StupidBackoff PROTOTYPE = new StupidBackoff(DEFAULT_BACKOFF_DISCOUNT); + private static final String NAME = "stupid_backoff"; + private static final ParseField DISCOUNT_FIELD = new ParseField("discount"); /** * Creates a Stupid-Backoff smoothing model. @@ -293,15 +314,70 @@ public static final class StupidBackoff extends SmoothingModel { * the discount given to lower order ngrams if the higher order ngram doesn't exits */ public StupidBackoff(double discount) { - super("stupid_backoff"); this.discount = discount; } + /** + * @return the discount parameter of the model + */ + public double getDiscount() { + return this.discount; + } + @Override protected XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("discount", discount); + builder.field(DISCOUNT_FIELD.getPreferredName(), discount); return builder; } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(discount); + } + + @Override + public StupidBackoff readFrom(StreamInput in) throws IOException { + return new StupidBackoff(in.readDouble()); + } + + @Override + protected boolean doEquals(SmoothingModel other) { + StupidBackoff otherModel = (StupidBackoff) other; + return Objects.equals(discount, otherModel.discount); + } + + @Override + public final int hashCode() { + return Objects.hash(discount); + } + + @Override + public SmoothingModel fromXContent(QueryParseContext parseContext) throws IOException { + XContentParser parser = parseContext.parser(); + XContentParser.Token token; + String fieldName = null; + double discount = 
DEFAULT_BACKOFF_DISCOUNT; + while ((token = parser.nextToken()) != Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + fieldName = parser.currentName(); + } + if (token.isValue() && parseContext.parseFieldMatcher().match(fieldName, DISCOUNT_FIELD)) { + discount = parser.doubleValue(); + } + } + return new StupidBackoff(discount); + } + + @Override + public WordScorerFactory buildWordScorerFactory() { + return (IndexReader reader, Terms terms, String field, double realWordLikelyhood, BytesRef separator) + -> new StupidBackoffScorer(reader, terms, field, realWordLikelyhood, separator, discount); + } } /** @@ -314,39 +390,119 @@ protected XContentBuilder innerToXContent(XContentBuilder builder, Params params *

*/ public static final class Laplace extends SmoothingModel { - private final double alpha; + private double alpha = DEFAULT_LAPLACE_ALPHA; + private static final String NAME = "laplace"; + private static final ParseField ALPHA_FIELD = new ParseField("alpha"); + /** + * Default alpha parameter for laplace smoothing + */ + public static final double DEFAULT_LAPLACE_ALPHA = 0.5; + static final Laplace PROTOTYPE = new Laplace(DEFAULT_LAPLACE_ALPHA); + /** * Creates a Laplace smoothing model. * */ public Laplace(double alpha) { - super("laplace"); this.alpha = alpha; } + /** + * @return the laplace model alpha parameter + */ + public double getAlpha() { + return this.alpha; + } + @Override protected XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("alpha", alpha); + builder.field(ALPHA_FIELD.getPreferredName(), alpha); return builder; } - } + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(alpha); + } - public static abstract class SmoothingModel implements ToXContent { - private final String type; + @Override + public SmoothingModel readFrom(StreamInput in) throws IOException { + return new Laplace(in.readDouble()); + } - protected SmoothingModel(String type) { - this.type = type; + @Override + protected boolean doEquals(SmoothingModel other) { + Laplace otherModel = (Laplace) other; + return Objects.equals(alpha, otherModel.alpha); + } + + @Override + public final int hashCode() { + return Objects.hash(alpha); } + @Override + public SmoothingModel fromXContent(QueryParseContext parseContext) throws IOException { + XContentParser parser = parseContext.parser(); + XContentParser.Token token; + String fieldName = null; + double alpha = DEFAULT_LAPLACE_ALPHA; + while ((token = parser.nextToken()) != Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + fieldName = parser.currentName(); 
+ } + if (token.isValue() && parseContext.parseFieldMatcher().match(fieldName, ALPHA_FIELD)) { + alpha = parser.doubleValue(); + } + } + return new Laplace(alpha); + } + + @Override + public WordScorerFactory buildWordScorerFactory() { + return (IndexReader reader, Terms terms, String field, double realWordLikelyhood, BytesRef separator) + -> new LaplaceScorer(reader, terms, field, realWordLikelyhood, separator, alpha); + } + } + + + public static abstract class SmoothingModel implements NamedWriteable, ToXContent { + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(type); + builder.startObject(getWriteableName()); innerToXContent(builder,params); builder.endObject(); return builder; } + @Override + public final boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + @SuppressWarnings("unchecked") + SmoothingModel other = (SmoothingModel) obj; + return doEquals(other); + } + + public abstract SmoothingModel fromXContent(QueryParseContext parseContext) throws IOException; + + public abstract WordScorerFactory buildWordScorerFactory(); + + /** + * subtype specific implementation of "equals". + */ + protected abstract boolean doEquals(SmoothingModel other); + protected abstract XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException; } @@ -359,9 +515,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws *

*/ public static final class LinearInterpolation extends SmoothingModel { + private static final String NAME = "linear"; + static final LinearInterpolation PROTOTYPE = new LinearInterpolation(0.8, 0.1, 0.1); private final double trigramLambda; private final double bigramLambda; private final double unigramLambda; + private static final ParseField TRIGRAM_FIELD = new ParseField("trigram_lambda"); + private static final ParseField BIGRAM_FIELD = new ParseField("bigram_lambda"); + private static final ParseField UNIGRAM_FIELD = new ParseField("unigram_lambda"); /** * Creates a linear interpolation smoothing model. @@ -376,19 +537,110 @@ public static final class LinearInterpolation extends SmoothingModel { * the unigram lambda */ public LinearInterpolation(double trigramLambda, double bigramLambda, double unigramLambda) { - super("linear"); + double sum = trigramLambda + bigramLambda + unigramLambda; + if (Math.abs(sum - 1.0) > 0.001) { + throw new IllegalArgumentException("linear smoothing lambdas must sum to 1"); + } this.trigramLambda = trigramLambda; this.bigramLambda = bigramLambda; this.unigramLambda = unigramLambda; } + public double getTrigramLambda() { + return this.trigramLambda; + } + + public double getBigramLambda() { + return this.bigramLambda; + } + + public double getUnigramLambda() { + return this.unigramLambda; + } + @Override protected XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("trigram_lambda", trigramLambda); - builder.field("bigram_lambda", bigramLambda); - builder.field("unigram_lambda", unigramLambda); + builder.field(TRIGRAM_FIELD.getPreferredName(), trigramLambda); + builder.field(BIGRAM_FIELD.getPreferredName(), bigramLambda); + builder.field(UNIGRAM_FIELD.getPreferredName(), unigramLambda); return builder; } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + 
out.writeDouble(trigramLambda); + out.writeDouble(bigramLambda); + out.writeDouble(unigramLambda); + } + + @Override + public LinearInterpolation readFrom(StreamInput in) throws IOException { + return new LinearInterpolation(in.readDouble(), in.readDouble(), in.readDouble()); + } + + @Override + protected boolean doEquals(SmoothingModel other) { + final LinearInterpolation otherModel = (LinearInterpolation) other; + return Objects.equals(trigramLambda, otherModel.trigramLambda) && + Objects.equals(bigramLambda, otherModel.bigramLambda) && + Objects.equals(unigramLambda, otherModel.unigramLambda); + } + + @Override + public final int hashCode() { + return Objects.hash(trigramLambda, bigramLambda, unigramLambda); + } + + @Override + public LinearInterpolation fromXContent(QueryParseContext parseContext) throws IOException { + XContentParser parser = parseContext.parser(); + XContentParser.Token token; + String fieldName = null; + double trigramLambda = 0.0; + double bigramLambda = 0.0; + double unigramLambda = 0.0; + ParseFieldMatcher matcher = parseContext.parseFieldMatcher(); + while ((token = parser.nextToken()) != Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + fieldName = parser.currentName(); + } else if (token.isValue()) { + if (matcher.match(fieldName, TRIGRAM_FIELD)) { + trigramLambda = parser.doubleValue(); + if (trigramLambda < 0) { + throw new IllegalArgumentException("trigram_lambda must be positive"); + } + } else if (matcher.match(fieldName, BIGRAM_FIELD)) { + bigramLambda = parser.doubleValue(); + if (bigramLambda < 0) { + throw new IllegalArgumentException("bigram_lambda must be positive"); + } + } else if (matcher.match(fieldName, UNIGRAM_FIELD)) { + unigramLambda = parser.doubleValue(); + if (unigramLambda < 0) { + throw new IllegalArgumentException("unigram_lambda must be positive"); + } + } else { + throw new IllegalArgumentException( + "suggester[phrase][smoothing][linear] doesn't support field [" + fieldName + "]"); + } 
+ } else { + throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "] after [" + fieldName + "]"); + } + } + return new LinearInterpolation(trigramLambda, bigramLambda, unigramLambda); + } + + @Override + public WordScorerFactory buildWordScorerFactory() { + return (IndexReader reader, Terms terms, String field, double realWordLikelyhood, BytesRef separator) -> + new LinearInterpoatingScorer(reader, terms, field, realWordLikelyhood, separator, trigramLambda, bigramLambda, + unigramLambda); + } } /** @@ -428,7 +680,7 @@ public static final class DirectCandidateGenerator extends CandidateGenerator { private Float minDocFreq; /** - * @param field Sets from what field to fetch the candidate suggestions from. + * @param field Sets from what field to fetch the candidate suggestions from. */ public DirectCandidateGenerator(String field) { super("direct_generator"); diff --git a/core/src/main/java/org/elasticsearch/search/suggest/phrase/StupidBackoffScorer.java b/core/src/main/java/org/elasticsearch/search/suggest/phrase/StupidBackoffScorer.java index fcf6064d2286b..5bd3d942b1afd 100644 --- a/core/src/main/java/org/elasticsearch/search/suggest/phrase/StupidBackoffScorer.java +++ b/core/src/main/java/org/elasticsearch/search/suggest/phrase/StupidBackoffScorer.java @@ -42,6 +42,10 @@ public StupidBackoffScorer(IndexReader reader, Terms terms,String field, double this.discount = discount; } + double discount() { + return this.discount; + } + @Override protected double scoreBigram(Candidate word, Candidate w_1) throws IOException { SuggestUtils.join(separator, spare, w_1.term, word.term); diff --git a/core/src/test/java/org/elasticsearch/common/geo/builders/AbstractShapeBuilderTestCase.java b/core/src/test/java/org/elasticsearch/common/geo/builders/AbstractShapeBuilderTestCase.java index 279e31aadd44f..9311db44da00a 100644 --- a/core/src/test/java/org/elasticsearch/common/geo/builders/AbstractShapeBuilderTestCase.java +++ 
b/core/src/test/java/org/elasticsearch/common/geo/builders/AbstractShapeBuilderTestCase.java @@ -89,7 +89,6 @@ public void testFromXContent() throws IOException { } XContentBuilder builder = testShape.toXContent(contentBuilder, ToXContent.EMPTY_PARAMS); XContentParser shapeParser = XContentHelper.createParser(builder.bytes()); - XContentHelper.createParser(builder.bytes()); shapeParser.nextToken(); ShapeBuilder parsedShape = ShapeBuilder.parse(shapeParser); assertNotSame(testShape, parsedShape); diff --git a/core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java new file mode 100644 index 0000000000000..87ad654e0cdc9 --- /dev/null +++ b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LaplaceModelTests.java @@ -0,0 +1,49 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.search.suggest.phrase; + +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.Laplace; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel; + +import static org.hamcrest.Matchers.instanceOf; + +public class LaplaceModelTests extends SmoothingModelTestCase { + + @Override + protected SmoothingModel createTestModel() { + return new Laplace(randomDoubleBetween(0.0, 10.0, false)); + } + + /** + * mutate the given model so the returned smoothing model is different + */ + @Override + protected Laplace createMutation(SmoothingModel input) { + Laplace original = (Laplace) input; + return new Laplace(original.getAlpha() + 0.1); + } + + @Override + void assertWordScorer(WordScorer wordScorer, SmoothingModel input) { + Laplace model = (Laplace) input; + assertThat(wordScorer, instanceOf(LaplaceScorer.class)); + assertEquals(model.getAlpha(), ((LaplaceScorer) wordScorer).alpha(), Double.MIN_VALUE); + } +} diff --git a/core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java new file mode 100644 index 0000000000000..1112b7a5ed7d3 --- /dev/null +++ b/core/src/test/java/org/elasticsearch/search/suggest/phrase/LinearInterpolationModelTests.java @@ -0,0 +1,69 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.suggest.phrase; + +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.LinearInterpolation; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel; + +import static org.hamcrest.Matchers.instanceOf; + +public class LinearInterpolationModelTests extends SmoothingModelTestCase { + + @Override + protected SmoothingModel createTestModel() { + double trigramLambda = randomDoubleBetween(0.0, 10.0, false); + double bigramLambda = randomDoubleBetween(0.0, 10.0, false); + double unigramLambda = randomDoubleBetween(0.0, 10.0, false); + // normalize so parameters sum to 1 + double sum = trigramLambda + bigramLambda + unigramLambda; + return new LinearInterpolation(trigramLambda / sum, bigramLambda / sum, unigramLambda / sum); + } + + /** + * mutate the given model so the returned smoothing model is different + */ + @Override + protected LinearInterpolation createMutation(SmoothingModel input) { + LinearInterpolation original = (LinearInterpolation) input; + // permute the original lambda values by swapping two of them + switch (randomIntBetween(0, 2)) { + case 0: + // swap first two + return new LinearInterpolation(original.getBigramLambda(), original.getTrigramLambda(), original.getUnigramLambda()); + case 1: + // swap last two + return new LinearInterpolation(original.getTrigramLambda(), original.getUnigramLambda(), original.getBigramLambda()); + case 2: + default: + // swap first and last + return new LinearInterpolation(original.getUnigramLambda(), 
original.getBigramLambda(), original.getTrigramLambda()); + } + } + + @Override + void assertWordScorer(WordScorer wordScorer, SmoothingModel in) { + LinearInterpolation testModel = (LinearInterpolation) in; + LinearInterpoatingScorer testScorer = (LinearInterpoatingScorer) wordScorer; + assertThat(wordScorer, instanceOf(LinearInterpoatingScorer.class)); + assertEquals(testModel.getTrigramLambda(), (testScorer).trigramLambda(), 1e-15); + assertEquals(testModel.getBigramLambda(), (testScorer).bigramLambda(), 1e-15); + assertEquals(testModel.getUnigramLambda(), (testScorer).unigramLambda(), 1e-15); + } +} diff --git a/core/src/test/java/org/elasticsearch/search/suggest/phrase/SmoothingModelTestCase.java b/core/src/test/java/org/elasticsearch/search/suggest/phrase/SmoothingModelTestCase.java new file mode 100644 index 0000000000000..e4a8ae72b911f --- /dev/null +++ b/core/src/test/java/org/elasticsearch/search/suggest/phrase/SmoothingModelTestCase.java @@ -0,0 +1,196 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.search.suggest.phrase; + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.analysis.core.WhitespaceAnalyzer; +import org.apache.lucene.analysis.miscellaneous.PerFieldAnalyzerWrapper; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.TextField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.MultiFields; +import org.apache.lucene.store.RAMDirectory; +import org.elasticsearch.common.ParseFieldMatcher; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.lucene.BytesRefs; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.query.QueryParseContext; +import org.elasticsearch.indices.query.IndicesQueriesRegistry; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.Laplace; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.LinearInterpolation; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.StupidBackoff; +import org.elasticsearch.test.ESTestCase; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import java.io.IOException; +import 
java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; + +public abstract class SmoothingModelTestCase extends ESTestCase { + + private static NamedWriteableRegistry namedWriteableRegistry; + + /** + * setup for the whole base test class + */ + @BeforeClass + public static void init() { + if (namedWriteableRegistry == null) { + namedWriteableRegistry = new NamedWriteableRegistry(); + namedWriteableRegistry.registerPrototype(SmoothingModel.class, Laplace.PROTOTYPE); + namedWriteableRegistry.registerPrototype(SmoothingModel.class, LinearInterpolation.PROTOTYPE); + namedWriteableRegistry.registerPrototype(SmoothingModel.class, StupidBackoff.PROTOTYPE); + } + } + + @AfterClass + public static void afterClass() throws Exception { + namedWriteableRegistry = null; + } + + /** + * create random model that is put under test + */ + protected abstract SmoothingModel createTestModel(); + + /** + * mutate the given model so the returned smoothing model is different + */ + protected abstract SmoothingModel createMutation(SmoothingModel original) throws IOException; + + /** + * Test that creates new smoothing model from a random test smoothing model and checks both for equality + */ + public void testFromXContent() throws IOException { + QueryParseContext context = new QueryParseContext(new IndicesQueriesRegistry(Settings.settingsBuilder().build(), Collections.emptyMap())); + context.parseFieldMatcher(new ParseFieldMatcher(Settings.EMPTY)); + + SmoothingModel testModel = createTestModel(); + XContentBuilder contentBuilder = XContentFactory.contentBuilder(randomFrom(XContentType.values())); + if (randomBoolean()) { + contentBuilder.prettyPrint(); + } + contentBuilder.startObject(); + testModel.innerToXContent(contentBuilder, ToXContent.EMPTY_PARAMS); + contentBuilder.endObject(); + XContentParser parser = XContentHelper.createParser(contentBuilder.bytes()); + 
+ context.reset(parser); + parser.nextToken(); // go to start token, real parsing would do that in the outer element parser + SmoothingModel prototype = (SmoothingModel) namedWriteableRegistry.getPrototype(SmoothingModel.class, + testModel.getWriteableName()); + SmoothingModel parsedModel = prototype.fromXContent(context); + assertNotSame(testModel, parsedModel); + assertEquals(testModel, parsedModel); + assertEquals(testModel.hashCode(), parsedModel.hashCode()); + } + + /** + * Test the WordScorer emitted by the smoothing model + */ + public void testBuildWordScorer() throws IOException { + SmoothingModel testModel = createTestModel(); + + Map mapping = new HashMap<>(); + mapping.put("field", new WhitespaceAnalyzer()); + PerFieldAnalyzerWrapper wrapper = new PerFieldAnalyzerWrapper(new WhitespaceAnalyzer(), mapping); + IndexWriter writer = new IndexWriter(new RAMDirectory(), new IndexWriterConfig(wrapper)); + Document doc = new Document(); + doc.add(new Field("field", "someText", TextField.TYPE_NOT_STORED)); + writer.addDocument(doc); + DirectoryReader ir = DirectoryReader.open(writer, false); + + WordScorer wordScorer = testModel.buildWordScorerFactory().newScorer(ir, MultiFields.getTerms(ir , "field"), "field", 0.9d, BytesRefs.toBytesRef(" ")); + assertWordScorer(wordScorer, testModel); + } + + /** + * implementation-dependent assertions on the WordScorer produced by the smoothing model under test + */ + abstract void assertWordScorer(WordScorer wordScorer, SmoothingModel testModel); + + /** + * Test serialization and deserialization of the tested model. 
+ */ + public void testSerialization() throws IOException { + SmoothingModel testModel = createTestModel(); + SmoothingModel deserializedModel = copyModel(testModel); + assertEquals(testModel, deserializedModel); + assertEquals(testModel.hashCode(), deserializedModel.hashCode()); + assertNotSame(testModel, deserializedModel); + } + + /** + * Test equality and hashCode properties + */ + @SuppressWarnings("unchecked") + public void testEqualsAndHashcode() throws IOException { + SmoothingModel firstModel = createTestModel(); + assertFalse("smoothing model is equal to null", firstModel.equals(null)); + assertFalse("smoothing model is equal to incompatible type", firstModel.equals("")); + assertTrue("smoothing model is not equal to self", firstModel.equals(firstModel)); + assertThat("same smoothing model's hashcode returns different values if called multiple times", firstModel.hashCode(), + equalTo(firstModel.hashCode())); + assertThat("different smoothing models should not be equal", createMutation(firstModel), not(equalTo(firstModel))); + + SmoothingModel secondModel = copyModel(firstModel); + assertTrue("smoothing model is not equal to self", secondModel.equals(secondModel)); + assertTrue("smoothing model is not equal to its copy", firstModel.equals(secondModel)); + assertTrue("equals is not symmetric", secondModel.equals(firstModel)); + assertThat("smoothing model copy's hashcode is different from original hashcode", secondModel.hashCode(), equalTo(firstModel.hashCode())); + + SmoothingModel thirdModel = copyModel(secondModel); + assertTrue("smoothing model is not equal to self", thirdModel.equals(thirdModel)); + assertTrue("smoothing model is not equal to its copy", secondModel.equals(thirdModel)); + assertThat("smoothing model copy's hashcode is different from original hashcode", secondModel.hashCode(), equalTo(thirdModel.hashCode())); + assertTrue("equals is not transitive", firstModel.equals(thirdModel)); + assertThat("smoothing model copy's hashcode is 
different from original hashcode", firstModel.hashCode(), equalTo(thirdModel.hashCode())); + assertTrue("equals is not symmetric", thirdModel.equals(secondModel)); + assertTrue("equals is not symmetric", thirdModel.equals(firstModel)); + } + + static SmoothingModel copyModel(SmoothingModel original) throws IOException { + try (BytesStreamOutput output = new BytesStreamOutput()) { + original.writeTo(output); + try (StreamInput in = new NamedWriteableAwareStreamInput(StreamInput.wrap(output.bytes()), namedWriteableRegistry)) { + SmoothingModel prototype = (SmoothingModel) namedWriteableRegistry.getPrototype(SmoothingModel.class, original.getWriteableName()); + return prototype.readFrom(in); + } + } + } + +} diff --git a/core/src/test/java/org/elasticsearch/search/suggest/phrase/StupidBackoffModelTests.java b/core/src/test/java/org/elasticsearch/search/suggest/phrase/StupidBackoffModelTests.java new file mode 100644 index 0000000000000..c3bd66d2a815c --- /dev/null +++ b/core/src/test/java/org/elasticsearch/search/suggest/phrase/StupidBackoffModelTests.java @@ -0,0 +1,49 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.search.suggest.phrase; + +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.SmoothingModel; +import org.elasticsearch.search.suggest.phrase.PhraseSuggestionBuilder.StupidBackoff; + +import static org.hamcrest.Matchers.instanceOf; + +public class StupidBackoffModelTests extends SmoothingModelTestCase { + + @Override + protected SmoothingModel createTestModel() { + return new StupidBackoff(randomDoubleBetween(0.0, 10.0, false)); + } + + /** + * mutate the given model so the returned smoothing model is different + */ + @Override + protected StupidBackoff createMutation(SmoothingModel input) { + StupidBackoff original = (StupidBackoff) input; + return new StupidBackoff(original.getDiscount() + 0.1); + } + + @Override + void assertWordScorer(WordScorer wordScorer, SmoothingModel input) { + assertThat(wordScorer, instanceOf(StupidBackoffScorer.class)); + StupidBackoff testModel = (StupidBackoff) input; + assertEquals(testModel.getDiscount(), ((StupidBackoffScorer) wordScorer).discount(), Double.MIN_VALUE); + } +}