diff --git a/keras/src/losses/losses.py b/keras/src/losses/losses.py index 8b29a4af2a0e..cea6d197c066 100644 --- a/keras/src/losses/losses.py +++ b/keras/src/losses/losses.py @@ -2143,10 +2143,21 @@ def categorical_focal_crossentropy( >>> y_true = [[0, 1, 0], [0, 0, 1]] >>> y_pred = [[0.05, 0.9, 0.05], [0.1, 0.85, 0.05]] - >>> loss = keras.losses.categorical_focal_crossentropy(y_true, y_pred) - >>> assert loss.shape == (2,) - >>> loss + >>> # In this instance, the second example is the 'harder' example. + >>> focal_loss = keras.losses.categorical_focal_crossentropy(y_true, y_pred) + >>> assert focal_loss.shape == (2,) + >>> focal_loss array([2.63401289e-04, 6.75912094e-01], dtype=float32) + >>> # Compare with categorical_crossentropy + >>> cce_loss = keras.losses.categorical_crossentropy( + ... y_true, y_pred) + >>> cce_loss + array([0.10536054, 2.9957323], dtype=float32) + >>> # Categorical focal crossentropy loss attributes more importance to the + >>> # harder example which results in a higher loss for the second example + >>> # when normalized by categorical cross entropy loss + >>> focal_loss/cce_loss + array([0.0025 , 0.225625], dtype=float32) """ if isinstance(axis, bool): raise ValueError( @@ -2367,11 +2378,11 @@ def binary_focal_crossentropy( >>> # 'easier' example. >>> focal_loss = keras.losses.binary_focal_crossentropy( ... y_true, y_pred, gamma=2) - >>> assert loss.shape == (2,) + >>> assert focal_loss.shape == (2,) >>> focal_loss array([0.330, 0.206], dtype=float32) >>> # Compare with binary_crossentropy - >>> bce_loss = keras.losses.binary_focal_crossentropy( + >>> bce_loss = keras.losses.binary_crossentropy( ... y_true, y_pred) >>> bce_loss array([0.916, 0.714], dtype=float32) @@ -2379,7 +2390,7 @@ def binary_focal_crossentropy( >>> # harder example which results in a higher loss for the first batch >>> # when normalized by binary cross entropy loss >>> focal_loss/bce_loss - array([0.360, 0.289] + array([0.360, 0.289], dtype=float32) """ y_pred = ops.convert_to_tensor(y_pred) y_true = ops.cast(y_true, y_pred.dtype)