FEA Add macro-averaged mean absolute error #780


Merged

Conversation

AurelienMassiot
Contributor

Reference Issue

As detailed in issue #18901, which I opened in the main scikit-learn repository, the macro-averaged MAE should be added to the imbalanced-learn repository instead.

For ordinal classification, we can use several metrics, for example MAE or MSE, as we would for regression.
But when the classes are imbalanced, one way to deal with the imbalance is to use the macro-averaged MAE instead, as described on StackExchange and in the original paper.

The macro-averaged MAE is like the "classic" MAE, except that we compute the MAE separately for each class and then average them, giving equal weight to each class. Note that the macro-averaged MAE equals the micro-averaged (i.e. classic) MAE when the classes are balanced.
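In formula form, with $K$ the number of classes present in y_true and $S_k$ the set of samples whose true label is the $k$-th class:

$$\text{MA-MAE} = \frac{1}{K} \sum_{k=1}^{K} \frac{1}{\lvert S_k \rvert} \sum_{i \in S_k} \lvert y_i - \hat{y}_i \rvert$$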

To illustrate this, let's consider:

import numpy as np
from sklearn.metrics import mean_absolute_error
from imblearn.metrics import macro_averaged_mean_absolute_error

y_true_balanced = np.array([1, 1, 1, 2, 2, 2])
y_true_imbalanced = np.array([1, 1, 1, 1, 1, 2])
y_pred = np.array([1, 2, 1, 2, 1, 2])

mean_absolute_error(y_true_balanced, y_pred)
>> 0.33
mean_absolute_error(y_true_imbalanced, y_pred)
>> 0.33
macro_averaged_mean_absolute_error(y_true_balanced, y_pred)
>> 0.33
macro_averaged_mean_absolute_error(y_true_imbalanced, y_pred)
>> 0.2
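
For reference, here is a minimal sketch of how the metric can be computed (the per-class MAEs averaged with equal weight per class); the function name is only for illustration and this is not necessarily the exact code merged in this PR:

import numpy as np
from sklearn.metrics import mean_absolute_error

def macro_averaged_mae_sketch(y_true, y_pred):
    # Compute the MAE separately for each class present in y_true,
    # then average the per-class MAEs, one weight per class.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    per_class_mae = [
        mean_absolute_error(y_true[y_true == label], y_pred[y_true == label])
        for label in np.unique(y_true)
    ]
    return np.mean(per_class_mae)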

Any other comments?

@pep8speaks

pep8speaks commented Nov 24, 2020

Hello @AurelienMassiot! Thanks for updating this PR. We checked the lines you've touched for PEP 8 issues, and found:

Line 777:17: W503 line break before binary operator

Comment last updated at 2021-02-08 23:15:32 UTC

@codecov

codecov bot commented Nov 24, 2020

Codecov Report

Merging #780 (26eeabe) into master (f40e654) will decrease coverage by 0.00%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #780      +/-   ##
==========================================
- Coverage   98.55%   98.55%   -0.01%     
==========================================
  Files          89       89              
  Lines        5681     5680       -1     
  Branches      475      477       +2     
==========================================
- Hits         5599     5598       -1     
  Misses         81       81              
  Partials        1        1              
Impacted Files Coverage Δ
imblearn/metrics/__init__.py 100.00% <100.00%> (ø)
imblearn/metrics/_classification.py 96.29% <100.00%> (+0.19%) ⬆️
imblearn/metrics/tests/test_classification.py 100.00% <100.00%> (ø)
imblearn/utils/estimator_checks.py 95.60% <0.00%> (-0.36%) ⬇️

Continue to review full report at Codecov.

Legend
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update f40e654...8a1d8bf. Read the comment docs.

Member

@glemaitre glemaitre left a comment


I made a quick first pass.

all_mae = []
y_true = np.array(y_true)
y_pred = np.array(y_pred)
for class_to_predict in np.unique(y_true):
Member


I think that we should introduce a labels parameter, to be able either to pass classes that are not present in y_true or to select a subset of labels, as in precision-recall: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_fscore_support.html#sklearn.metrics.precision_recall_fscore_support

Would it make sense?
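
For context, a small sketch of how the labels argument behaves in scikit-learn's precision_recall_fscore_support; a hypothetical labels parameter for macro_averaged_mean_absolute_error could follow the same convention (not part of this PR):

import numpy as np
from sklearn.metrics import precision_recall_fscore_support

y_true = np.array([1, 2, 2, 3])
y_pred = np.array([1, 2, 2, 2])

# labels can include classes absent from y_true (here class 4) or restrict
# the computation to a subset of classes; support reports counts in y_true.
p, r, f, support = precision_recall_fscore_support(
    y_true, y_pred, labels=[1, 2, 3, 4], zero_division=0
)
print(support)  # [1 2 1 0]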

Contributor Author


I think it does not make sense, because all we want is the MAE for each class actually present in y_true.
For example, if I have:
y_true = [1, 2, 2, 3]
I want the average MAE, which will be the MAE over classes 1, 2, 3. And if a class is missing from y_pred, it doesn't matter; for example, with:
y_pred = [1, 2, 2, 2]
my MAEs will be 0 for class 1, 0 for class 2, 1 for class 3, and the MA-MAE will then be 0.33.

WDYT?
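
For what it's worth, a quick numerical check of the worked example above, computing the per-class MAE directly:

import numpy as np
from sklearn.metrics import mean_absolute_error

y_true = np.array([1, 2, 2, 3])
y_pred = np.array([1, 2, 2, 2])

per_class = [
    mean_absolute_error(y_true[y_true == c], y_pred[y_true == c])
    for c in np.unique(y_true)
]
print(per_class)           # [0.0, 0.0, 1.0]
print(np.mean(per_class))  # 0.333...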


],
)
def test_macro_averaged_mean_absolute_error(y_true, y_pred, expected_ma_mae):
Member


If we introduce labels, we will need another test covering a few more corner cases.
Otherwise, I think this is good.

Contributor Author


See comment above for labels.

Contributor Author

@AurelienMassiot AurelienMassiot left a comment


See the answer about labels above.


@AurelienMassiot
Contributor Author

Could anyone make a final review? :-)

@glemaitre glemaitre added this to the 0.7 milestone Nov 26, 2020
@glemaitre
Member

@AurelienMassiot I promise you that it will be merged for the next release, which will follow the scikit-learn 0.24 release.
Right now I have to work on the scikit-learn 0.24 release first, because we intend to release next week.
If you are short on time to complete the PR, do not worry too much; I will push the necessary fixes and merge this PR.

@glemaitre glemaitre modified the milestones: 0.7, 0.8 Nov 26, 2020
@AurelienMassiot
Contributor Author

Thanks @glemaitre! This is not urgent; good luck with the scikit-learn release ;-).

@glemaitre glemaitre self-assigned this Feb 3, 2021
@glemaitre glemaitre self-requested a review February 8, 2021 21:56
@glemaitre glemaitre changed the title [MRG] ENH: macro-averaged mean absolute error FEA Add macro-averaged mean absolute error Feb 8, 2021
@glemaitre glemaitre merged commit 0b48def into scikit-learn-contrib:master Feb 8, 2021
@glemaitre
Member

Thanks @AurelienMassiot, good to go.
