Skip to content

Commit a0d6c19

Browse files
author
Christoph Büscher
authored
Add details section for dcg ranking metric (#31177)
While the other two ranking evaluation metrics (precicion and reciprocal rank) already provide a more detailed output for how their score is calculated, the discounted cumulative gain metric (dcg) and its normalized variant are lacking this until now. Its not really clear which level of detail might be useful for debugging and understanding the final metric calculation, but this change adds a `metric_details` section to REST output that contains some information about the evaluation details.
1 parent ca00deb commit a0d6c19

File tree

6 files changed

+174
-21
lines changed

6 files changed

+174
-21
lines changed

client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package org.elasticsearch.client;
2121

2222
import com.fasterxml.jackson.core.JsonParseException;
23+
2324
import org.apache.http.Header;
2425
import org.apache.http.HttpEntity;
2526
import org.apache.http.HttpHost;
@@ -607,7 +608,7 @@ public void testDefaultNamedXContents() {
607608

608609
public void testProvidedNamedXContents() {
609610
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
610-
assertEquals(7, namedXContents.size());
611+
assertEquals(8, namedXContents.size());
611612
Map<Class<?>, Integer> categories = new HashMap<>();
612613
List<String> names = new ArrayList<>();
613614
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@@ -625,9 +626,10 @@ public void testProvidedNamedXContents() {
625626
assertTrue(names.contains(PrecisionAtK.NAME));
626627
assertTrue(names.contains(DiscountedCumulativeGain.NAME));
627628
assertTrue(names.contains(MeanReciprocalRank.NAME));
628-
assertEquals(Integer.valueOf(2), categories.get(MetricDetail.class));
629+
assertEquals(Integer.valueOf(3), categories.get(MetricDetail.class));
629630
assertTrue(names.contains(PrecisionAtK.NAME));
630631
assertTrue(names.contains(MeanReciprocalRank.NAME));
632+
assertTrue(names.contains(DiscountedCumulativeGain.NAME));
631633
}
632634

633635
private static class TrackingActionListener implements ActionListener<Integer> {

modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java

Lines changed: 131 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import java.util.Optional;
3737
import java.util.stream.Collectors;
3838

39+
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
3940
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
4041
import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings;
4142

@@ -129,26 +130,31 @@ public EvalQueryQuality evaluate(String taskId, SearchHit[] hits,
129130
.collect(Collectors.toList());
130131
List<RatedSearchHit> ratedHits = joinHitsWithRatings(hits, ratedDocs);
131132
List<Integer> ratingsInSearchHits = new ArrayList<>(ratedHits.size());
133+
int unratedResults = 0;
132134
for (RatedSearchHit hit : ratedHits) {
133-
// unknownDocRating might be null, which means it will be unrated docs are
134-
// ignored in the dcg calculation
135-
// we still need to add them as a placeholder so the rank of the subsequent
136-
// ratings is correct
135+
// unknownDocRating might be null, in which case unrated docs will be ignored in the dcg calculation.
136+
// we still need to add them as a placeholder so the rank of the subsequent ratings is correct
137137
ratingsInSearchHits.add(hit.getRating().orElse(unknownDocRating));
138+
if (hit.getRating().isPresent() == false) {
139+
unratedResults++;
140+
}
138141
}
139-
double dcg = computeDCG(ratingsInSearchHits);
142+
final double dcg = computeDCG(ratingsInSearchHits);
143+
double result = dcg;
144+
double idcg = 0;
140145

141146
if (normalize) {
142147
Collections.sort(allRatings, Comparator.nullsLast(Collections.reverseOrder()));
143-
double idcg = computeDCG(allRatings.subList(0, Math.min(ratingsInSearchHits.size(), allRatings.size())));
144-
if (idcg > 0) {
145-
dcg = dcg / idcg;
148+
idcg = computeDCG(allRatings.subList(0, Math.min(ratingsInSearchHits.size(), allRatings.size())));
149+
if (idcg != 0) {
150+
result = dcg / idcg;
146151
} else {
147-
dcg = 0;
152+
result = 0;
148153
}
149154
}
150-
EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, dcg);
155+
EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, result);
151156
evalQueryQuality.addHitsAndRatings(ratedHits);
157+
evalQueryQuality.setMetricDetails(new Detail(dcg, idcg, unratedResults));
152158
return evalQueryQuality;
153159
}
154160

@@ -167,7 +173,7 @@ private static double computeDCG(List<Integer> ratings) {
167173
private static final ParseField K_FIELD = new ParseField("k");
168174
private static final ParseField NORMALIZE_FIELD = new ParseField("normalize");
169175
private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating");
170-
private static final ConstructingObjectParser<DiscountedCumulativeGain, Void> PARSER = new ConstructingObjectParser<>("dcg_at", false,
176+
private static final ConstructingObjectParser<DiscountedCumulativeGain, Void> PARSER = new ConstructingObjectParser<>("dcg", false,
171177
args -> {
172178
Boolean normalized = (Boolean) args[0];
173179
Integer optK = (Integer) args[2];
@@ -217,4 +223,118 @@ public final boolean equals(Object obj) {
217223
public final int hashCode() {
218224
return Objects.hash(normalize, unknownDocRating, k);
219225
}
226+
227+
public static final class Detail implements MetricDetail {
228+
229+
private static ParseField DCG_FIELD = new ParseField("dcg");
230+
private static ParseField IDCG_FIELD = new ParseField("ideal_dcg");
231+
private static ParseField NDCG_FIELD = new ParseField("normalized_dcg");
232+
private static ParseField UNRATED_FIELD = new ParseField("unrated_docs");
233+
private final double dcg;
234+
private final double idcg;
235+
private final int unratedDocs;
236+
237+
Detail(double dcg, double idcg, int unratedDocs) {
238+
this.dcg = dcg;
239+
this.idcg = idcg;
240+
this.unratedDocs = unratedDocs;
241+
}
242+
243+
Detail(StreamInput in) throws IOException {
244+
this.dcg = in.readDouble();
245+
this.idcg = in.readDouble();
246+
this.unratedDocs = in.readVInt();
247+
}
248+
249+
@Override
250+
public
251+
String getMetricName() {
252+
return NAME;
253+
}
254+
255+
@Override
256+
public XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException {
257+
builder.field(DCG_FIELD.getPreferredName(), this.dcg);
258+
if (this.idcg != 0) {
259+
builder.field(IDCG_FIELD.getPreferredName(), this.idcg);
260+
builder.field(NDCG_FIELD.getPreferredName(), this.dcg / this.idcg);
261+
}
262+
builder.field(UNRATED_FIELD.getPreferredName(), this.unratedDocs);
263+
return builder;
264+
}
265+
266+
private static final ConstructingObjectParser<Detail, Void> PARSER = new ConstructingObjectParser<>(NAME, true, args -> {
267+
return new Detail((Double) args[0], (Double) args[1] != null ? (Double) args[1] : 0.0d, (Integer) args[2]);
268+
});
269+
270+
static {
271+
PARSER.declareDouble(constructorArg(), DCG_FIELD);
272+
PARSER.declareDouble(optionalConstructorArg(), IDCG_FIELD);
273+
PARSER.declareInt(constructorArg(), UNRATED_FIELD);
274+
}
275+
276+
public static Detail fromXContent(XContentParser parser) {
277+
return PARSER.apply(parser, null);
278+
}
279+
280+
@Override
281+
public void writeTo(StreamOutput out) throws IOException {
282+
out.writeDouble(this.dcg);
283+
out.writeDouble(this.idcg);
284+
out.writeVInt(this.unratedDocs);
285+
}
286+
287+
@Override
288+
public String getWriteableName() {
289+
return NAME;
290+
}
291+
292+
/**
293+
* @return the discounted cumulative gain
294+
*/
295+
public double getDCG() {
296+
return this.dcg;
297+
}
298+
299+
/**
300+
* @return the ideal discounted cumulative gain, can be 0 if nothing was computed, e.g. because no normalization was required
301+
*/
302+
public double getIDCG() {
303+
return this.idcg;
304+
}
305+
306+
/**
307+
* @return the normalized discounted cumulative gain, can be 0 if nothing was computed, e.g. because no normalization was required
308+
*/
309+
public double getNDCG() {
310+
return (this.idcg != 0) ? this.dcg / this.idcg : 0;
311+
}
312+
313+
/**
314+
* @return the number of unrated documents in the search results
315+
*/
316+
public Object getUnratedDocs() {
317+
return this.unratedDocs;
318+
}
319+
320+
@Override
321+
public boolean equals(Object obj) {
322+
if (this == obj) {
323+
return true;
324+
}
325+
if (obj == null || getClass() != obj.getClass()) {
326+
return false;
327+
}
328+
DiscountedCumulativeGain.Detail other = (DiscountedCumulativeGain.Detail) obj;
329+
return (this.dcg == other.dcg &&
330+
this.idcg == other.idcg &&
331+
this.unratedDocs == other.unratedDocs);
332+
}
333+
334+
@Override
335+
public int hashCode() {
336+
return Objects.hash(this.dcg, this.idcg, this.unratedDocs);
337+
}
338+
}
220339
}
340+

modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
4141
PrecisionAtK.Detail::fromXContent));
4242
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(MeanReciprocalRank.NAME),
4343
MeanReciprocalRank.Detail::fromXContent));
44+
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(DiscountedCumulativeGain.NAME),
45+
DiscountedCumulativeGain.Detail::fromXContent));
4446
return namedXContent;
4547
}
4648
}

modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,9 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
6161
namedWriteables.add(
6262
new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new));
6363
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, PrecisionAtK.NAME, PrecisionAtK.Detail::new));
64-
namedWriteables
65-
.add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Detail::new));
64+
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Detail::new));
65+
namedWriteables.add(
66+
new NamedWriteableRegistry.Entry(MetricDetail.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain.Detail::new));
6667
return namedWriteables;
6768
}
6869

modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
package org.elasticsearch.index.rankeval;
2121

22+
import org.elasticsearch.common.Strings;
2223
import org.elasticsearch.common.bytes.BytesReference;
2324
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
2425
import org.elasticsearch.common.text.Text;
@@ -254,9 +255,8 @@ private void assertParsedCorrect(String xContent, Integer expectedUnknownDocRati
254255

255256
public static DiscountedCumulativeGain createTestItem() {
256257
boolean normalize = randomBoolean();
257-
Integer unknownDocRating = Integer.valueOf(randomIntBetween(0, 1000));
258-
259-
return new DiscountedCumulativeGain(normalize, unknownDocRating, 10);
258+
Integer unknownDocRating = frequently() ? Integer.valueOf(randomIntBetween(0, 1000)) : null;
259+
return new DiscountedCumulativeGain(normalize, unknownDocRating, randomIntBetween(1, 10));
260260
}
261261

262262
public void testXContentRoundtrip() throws IOException {
@@ -283,7 +283,25 @@ public void testXContentParsingIsNotLenient() throws IOException {
283283
parser.nextToken();
284284
XContentParseException exception = expectThrows(XContentParseException.class,
285285
() -> DiscountedCumulativeGain.fromXContent(parser));
286-
assertThat(exception.getMessage(), containsString("[dcg_at] unknown field"));
286+
assertThat(exception.getMessage(), containsString("[dcg] unknown field"));
287+
}
288+
}
289+
290+
public void testMetricDetails() {
291+
double dcg = randomDoubleBetween(0, 1, true);
292+
double idcg = randomBoolean() ? 0.0 : randomDoubleBetween(0, 1, true);
293+
double expectedNdcg = idcg != 0 ? dcg / idcg : 0.0;
294+
int unratedDocs = randomIntBetween(0, 100);
295+
DiscountedCumulativeGain.Detail detail = new DiscountedCumulativeGain.Detail(dcg, idcg, unratedDocs);
296+
assertEquals(dcg, detail.getDCG(), 0.0);
297+
assertEquals(idcg, detail.getIDCG(), 0.0);
298+
assertEquals(expectedNdcg, detail.getNDCG(), 0.0);
299+
assertEquals(unratedDocs, detail.getUnratedDocs());
300+
if (idcg != 0) {
301+
assertEquals("{\"dcg\":{\"dcg\":" + dcg + ",\"ideal_dcg\":" + idcg + ",\"normalized_dcg\":" + expectedNdcg
302+
+ ",\"unrated_docs\":" + unratedDocs + "}}", Strings.toString(detail));
303+
} else {
304+
assertEquals("{\"dcg\":{\"dcg\":" + dcg + ",\"unrated_docs\":" + unratedDocs + "}}", Strings.toString(detail));
287305
}
288306
}
289307

modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/EvalQueryQualityTests.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,20 @@ public static EvalQueryQuality randomEvalQueryQuality() {
6868
EvalQueryQuality evalQueryQuality = new EvalQueryQuality(randomAlphaOfLength(10),
6969
randomDoubleBetween(0.0, 1.0, true));
7070
if (randomBoolean()) {
71-
if (randomBoolean()) {
71+
int metricDetail = randomIntBetween(0, 2);
72+
switch (metricDetail) {
73+
case 0:
7274
evalQueryQuality.setMetricDetails(new PrecisionAtK.Detail(randomIntBetween(0, 1000), randomIntBetween(0, 1000)));
73-
} else {
75+
break;
76+
case 1:
7477
evalQueryQuality.setMetricDetails(new MeanReciprocalRank.Detail(randomIntBetween(0, 1000)));
78+
break;
79+
case 2:
80+
evalQueryQuality.setMetricDetails(new DiscountedCumulativeGain.Detail(randomDoubleBetween(0, 1, true),
81+
randomBoolean() ? randomDoubleBetween(0, 1, true) : 0, randomInt()));
82+
break;
83+
default:
84+
throw new IllegalArgumentException("illegal randomized value in test");
7585
}
7686
}
7787
evalQueryQuality.addHitsAndRatings(ratedHits);

0 commit comments

Comments
 (0)