Skip to content

Commit d9835f7

Browse files
authored
[ML] Fix r_squared eval when variance is 0 (#49439) (#49445)
1 parent 138d16a commit d9835f7

File tree

2 files changed

+20
-1
lines changed
  • x-pack/plugin/core/src
    • main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression
    • test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression

2 files changed

+20
-1
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,11 @@ public void process(Aggregations aggs) {
8181
NumericMetricsAggregation.SingleValue residualSumOfSquares = aggs.get(SS_RES);
8282
ExtendedStats extendedStats = aggs.get(ExtendedStatsAggregationBuilder.NAME + "_actual");
8383
// extendedStats.getVariance() is the statistical sumOfSquares divided by count
84-
result = residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0 ?
84+
final boolean validResult = residualSumOfSquares == null
85+
|| extendedStats == null
86+
|| extendedStats.getCount() == 0
87+
|| extendedStats.getVariance() == 0;
88+
result = validResult ?
8589
new Result(0.0) :
8690
new Result(1 - (residualSumOfSquares.value() / (extendedStats.getVariance() * extendedStats.getCount())));
8791
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquaredTests.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,21 @@ public void testEvaluateWithZeroCount() {
7474
assertThat(result, equalTo(new RSquared.Result(0.0)));
7575
}
7676

77+
public void testEvaluateWithSingleCountZeroVariance() {
78+
Aggregations aggs = new Aggregations(Arrays.asList(
79+
createSingleMetricAgg("residual_sum_of_squares", 1),
80+
createExtendedStatsAgg("extended_stats_actual", 0.0, 1),
81+
createExtendedStatsAgg("some_other_extended_stats",99.1, 10_000),
82+
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
83+
));
84+
85+
RSquared rSquared = new RSquared();
86+
rSquared.process(aggs);
87+
88+
EvaluationMetricResult result = rSquared.getResult().get();
89+
assertThat(result, equalTo(new RSquared.Result(0.0)));
90+
}
91+
7792
public void testEvaluate_GivenMissingAggs() {
7893
Aggregations aggs = new Aggregations(Collections.singletonList(
7994
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)

0 commit comments

Comments
 (0)