Skip to content

Commit 652ff11

Browse files
committed
Added Recall and Mean Average Precision + unit tests (Based off of the precision unit tests
1 parent c4cc68e commit 652ff11

File tree

6 files changed

+1029
-0
lines changed

6 files changed

+1029
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
package org.elasticsearch.index.rankeval;
2+
3+
import org.elasticsearch.common.ParseField;
4+
import org.elasticsearch.common.io.stream.StreamInput;
5+
import org.elasticsearch.common.io.stream.StreamOutput;
6+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
7+
import org.elasticsearch.common.xcontent.XContentBuilder;
8+
import org.elasticsearch.common.xcontent.XContentParser;
9+
import org.elasticsearch.search.SearchHit;
10+
11+
import javax.naming.directory.SearchResult;
12+
import java.io.IOException;
13+
import java.util.List;
14+
import java.util.Objects;
15+
import java.util.OptionalInt;
16+
17+
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
18+
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
19+
20+
public class MeanAveragePrecisionAtK implements EvaluationMetric {
21+
22+
public static final String NAME = "mean_average_precision";
23+
24+
private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold");
25+
private static final ParseField IGNORE_UNLABELED_FIELD = new ParseField("ignore_unlabeled");
26+
private static final ParseField K_FIELD = new ParseField("k");
27+
28+
private static final int DEFAULT_K = 10;
29+
30+
private final boolean ignoreUnlabeled;
31+
private final int relevantRatingThreshhold;
32+
private final int k;
33+
34+
/**
35+
* Metric implementing Recall@K.
36+
* @param threshold
37+
* ratings equal or above this value will be considered relevant.
38+
* @param ignoreUnlabeled
39+
* Controls how unlabeled documents in the search hits are treated.
40+
* Set to 'true', unlabeled documents are ignored and neither count
41+
* as true or false positives. Set to 'false', they are treated as
42+
* false positives.
43+
* @param k
44+
* controls the window size for the search results the metric takes into account
45+
*/
46+
public MeanAveragePrecisionAtK(int threshold, boolean ignoreUnlabeled, int k) {
47+
if (threshold < 0) {
48+
throw new IllegalArgumentException("Relevant rating threshold for mean average precision must be positive integer.");
49+
}
50+
if (k <= 0) {
51+
throw new IllegalArgumentException("Window size k must be positive.");
52+
}
53+
this.relevantRatingThreshhold = threshold;
54+
this.ignoreUnlabeled = ignoreUnlabeled;
55+
this.k = k;
56+
}
57+
58+
public MeanAveragePrecisionAtK() {
59+
this(1, false, DEFAULT_K);
60+
}
61+
62+
private static final ConstructingObjectParser<MeanAveragePrecisionAtK, Void> PARSER = new ConstructingObjectParser<>(NAME,
63+
args -> {
64+
Integer threshHold = (Integer) args[0];
65+
Boolean ignoreUnlabeled = (Boolean) args[1];
66+
Integer k = (Integer) args[2];
67+
return new MeanAveragePrecisionAtK(threshHold == null ? 1 : threshHold,
68+
ignoreUnlabeled == null ? false : ignoreUnlabeled,
69+
k == null ? DEFAULT_K : k);
70+
});
71+
72+
static {
73+
PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_FIELD);
74+
PARSER.declareBoolean(optionalConstructorArg(), IGNORE_UNLABELED_FIELD);
75+
PARSER.declareInt(optionalConstructorArg(), K_FIELD);
76+
}
77+
78+
MeanAveragePrecisionAtK(StreamInput in) throws IOException {
79+
relevantRatingThreshhold = in.readVInt();
80+
ignoreUnlabeled = in.readBoolean();
81+
k = in.readVInt();
82+
}
83+
84+
int getK() {
85+
return this.k;
86+
}
87+
88+
@Override
89+
public void writeTo(StreamOutput out) throws IOException {
90+
System.out.println(relevantRatingThreshhold+","+ignoreUnlabeled+","+k+" kaslfdjkalsdfjalsdkjfkalsjd ");
91+
out.writeVInt(relevantRatingThreshhold);
92+
out.writeBoolean(ignoreUnlabeled);
93+
out.writeVInt(k);
94+
}
95+
96+
@Override
97+
public String getWriteableName() {
98+
return NAME;
99+
}
100+
101+
/**
102+
* Return the rating threshold above which ratings are considered to be
103+
* "relevant" for this metric. Defaults to 1.
104+
*/
105+
public int getRelevantRatingThreshold() {
106+
return relevantRatingThreshhold;
107+
}
108+
109+
/**
110+
* Gets the 'ignore_unlabeled' parameter.
111+
*/
112+
public boolean getIgnoreUnlabeled() {
113+
return ignoreUnlabeled;
114+
}
115+
116+
@Override
117+
public OptionalInt forcedSearchSize() {
118+
return OptionalInt.of(k);
119+
}
120+
121+
public static MeanAveragePrecisionAtK fromXContent(XContentParser parser) {
122+
return PARSER.apply(parser, null);
123+
}
124+
125+
/**
126+
* Compute recallAtN based on provided relevant document IDs.
127+
*
128+
* @return recall at n for above {@link SearchResult} list.
129+
**/
130+
@Override
131+
public EvalQueryQuality evaluate(String taskId, SearchHit[] hits,
132+
List<RatedDocument> ratedDocs) {
133+
int truePositives = 0;
134+
int falsePositives = 0;
135+
List<RatedSearchHit> ratedSearchHits = EvaluationMetric.joinHitsWithRatings(hits, ratedDocs);
136+
int relevantDocs = ratedDocs.size();
137+
int numberOfRelevantDocs = ratedSearchHits.size();
138+
int currentPosition = 0;
139+
double meanAveragePrecision = 0.0;
140+
for (RatedSearchHit hit : ratedSearchHits) {
141+
currentPosition++;
142+
OptionalInt rating = hit.getRating();
143+
if (rating.isPresent()) {
144+
if (rating.getAsInt() >= this.relevantRatingThreshhold) {
145+
truePositives++;
146+
meanAveragePrecision += (double) truePositives/ (double) currentPosition;
147+
} else {
148+
falsePositives++;
149+
}
150+
} else if (ignoreUnlabeled == false) {
151+
falsePositives++;
152+
}
153+
}
154+
if (meanAveragePrecision>0.0) {
155+
meanAveragePrecision = meanAveragePrecision / (truePositives+falsePositives);
156+
}
157+
EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, meanAveragePrecision);
158+
evalQueryQuality.setMetricDetails(
159+
new MeanAveragePrecisionAtK.Detail(truePositives, truePositives + falsePositives));
160+
evalQueryQuality.addHitsAndRatings(ratedSearchHits);
161+
return evalQueryQuality;
162+
}
163+
164+
@Override
165+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
166+
builder.startObject();
167+
builder.startObject(NAME);
168+
builder.field(RELEVANT_RATING_FIELD.getPreferredName(), this.relevantRatingThreshhold);
169+
builder.field(IGNORE_UNLABELED_FIELD.getPreferredName(), this.ignoreUnlabeled);
170+
builder.field(K_FIELD.getPreferredName(), this.k);
171+
builder.endObject();
172+
builder.endObject();
173+
return builder;
174+
}
175+
176+
@Override
177+
public final boolean equals(Object obj) {
178+
if (this == obj) {
179+
return true;
180+
}
181+
if (obj == null || getClass() != obj.getClass()) {
182+
return false;
183+
}
184+
MeanAveragePrecisionAtK other = (MeanAveragePrecisionAtK) obj;
185+
return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold)
186+
&& Objects.equals(k, other.k)
187+
&& Objects.equals(ignoreUnlabeled, other.ignoreUnlabeled);
188+
}
189+
190+
@Override
191+
public final int hashCode() {
192+
return Objects.hash(relevantRatingThreshhold, ignoreUnlabeled, k);
193+
}
194+
195+
public static final class Detail implements MetricDetail {
196+
197+
private static final ParseField DOCS_RETRIEVED_FIELD = new ParseField("docs_retrieved");
198+
private static final ParseField RELEVANT_DOCS_RETRIEVED_FIELD = new ParseField("relevant_docs_retrieved");
199+
private int relevantRetrieved;
200+
private int retrieved;
201+
202+
Detail(int relevantRetrieved, int retrieved) {
203+
this.relevantRetrieved = relevantRetrieved;
204+
this.retrieved = retrieved;
205+
}
206+
207+
Detail(StreamInput in) throws IOException {
208+
this.relevantRetrieved = in.readVInt();
209+
this.retrieved = in.readVInt();
210+
}
211+
212+
@Override
213+
public XContentBuilder innerToXContent(XContentBuilder builder, Params params)
214+
throws IOException {
215+
builder.field(RELEVANT_DOCS_RETRIEVED_FIELD.getPreferredName(), relevantRetrieved);
216+
builder.field(DOCS_RETRIEVED_FIELD.getPreferredName(), retrieved);
217+
return builder;
218+
}
219+
220+
private static final ConstructingObjectParser<MeanAveragePrecisionAtK.Detail, Void> PARSER = new ConstructingObjectParser<>(NAME, true, args -> {
221+
return new MeanAveragePrecisionAtK.Detail((Integer) args[0], (Integer) args[1]);
222+
});
223+
224+
static {
225+
PARSER.declareInt(constructorArg(), RELEVANT_DOCS_RETRIEVED_FIELD);
226+
PARSER.declareInt(constructorArg(), DOCS_RETRIEVED_FIELD);
227+
}
228+
229+
public static MeanAveragePrecisionAtK.Detail fromXContent(XContentParser parser) {
230+
return PARSER.apply(parser, null);
231+
}
232+
233+
@Override
234+
public void writeTo(StreamOutput out) throws IOException {
235+
out.writeVInt(relevantRetrieved);
236+
out.writeVInt(retrieved);
237+
}
238+
239+
@Override
240+
public String getWriteableName() {
241+
return NAME;
242+
}
243+
244+
public int getRelevantRetrieved() {
245+
return relevantRetrieved;
246+
}
247+
248+
public int getRetrieved() {
249+
return retrieved;
250+
}
251+
252+
@Override
253+
public boolean equals(Object obj) {
254+
if (this == obj) {
255+
return true;
256+
}
257+
if (obj == null || getClass() != obj.getClass()) {
258+
return false;
259+
}
260+
MeanAveragePrecisionAtK.Detail other = (MeanAveragePrecisionAtK.Detail) obj;
261+
return Objects.equals(relevantRetrieved, other.relevantRetrieved)
262+
&& Objects.equals(retrieved, other.retrieved);
263+
}
264+
265+
@Override
266+
public int hashCode() {
267+
return Objects.hash(relevantRetrieved, retrieved);
268+
}
269+
}
270+
271+
}

modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java

+8
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
3333
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
3434
namedXContent.add(
3535
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionAtK.NAME), PrecisionAtK::fromXContent));
36+
namedXContent.add(
37+
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallAtK.NAME), RecallAtK::fromXContent));
38+
namedXContent.add(
39+
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(MeanAveragePrecisionAtK.NAME), MeanAveragePrecisionAtK::fromXContent));
3640
namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(MeanReciprocalRank.NAME),
3741
MeanReciprocalRank::fromXContent));
3842
namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(DiscountedCumulativeGain.NAME),
@@ -42,6 +46,10 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
4246

4347
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(PrecisionAtK.NAME),
4448
PrecisionAtK.Detail::fromXContent));
49+
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(RecallAtK.NAME),
50+
RecallAtK.Detail::fromXContent));
51+
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(MeanAveragePrecisionAtK.NAME),
52+
MeanAveragePrecisionAtK.Detail::fromXContent));
4553
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(MeanReciprocalRank.NAME),
4654
MeanReciprocalRank.Detail::fromXContent));
4755
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(DiscountedCumulativeGain.NAME),

modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java

+4
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,16 @@ public List<RestHandler> getRestHandlers(Settings settings, RestController restC
5858
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
5959
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
6060
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, PrecisionAtK.NAME, PrecisionAtK::new));
61+
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, RecallAtK.NAME, RecallAtK::new));
62+
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, MeanAveragePrecisionAtK.NAME, RecallAtK::new));
6163
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new));
6264
namedWriteables.add(
6365
new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new));
6466
namedWriteables.add(
6567
new NamedWriteableRegistry.Entry(EvaluationMetric.class, ExpectedReciprocalRank.NAME, ExpectedReciprocalRank::new));
6668
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, PrecisionAtK.NAME, PrecisionAtK.Detail::new));
69+
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanAveragePrecisionAtK.NAME, RecallAtK.Detail::new));
70+
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, RecallAtK.NAME, RecallAtK.Detail::new));
6771
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Detail::new));
6872
namedWriteables.add(
6973
new NamedWriteableRegistry.Entry(MetricDetail.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain.Detail::new));

0 commit comments

Comments
 (0)