TensorFlow implementation of focal loss [1], a loss function that generalizes binary and multiclass cross-entropy loss and down-weights well-classified examples so that training focuses on hard-to-classify ones.
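In its basic form, [1] defines focal loss for one example as FL(p_t) = -(1 - p_t)^γ log(p_t), where p_t is the predicted probability of the true class and γ ≥ 0 is the focusing parameter. With γ = 0 this is exactly cross-entropy. The pure-Python sketch below is written from that formula, not from this package's source. It covers only the binary case without the optional class-weighting term, and `binary_focal_loss_sketch` is a hypothetical name.

```python
import math
from typing import Sequence


def binary_focal_loss_sketch(
    y_true: Sequence[int], p_pred: Sequence[float], gamma: float = 2.0, eps: float = 1e-7
) -> float:
    """Mean binary focal loss -(1 - p_t)**gamma * log(p_t) over a batch.

    y_true holds 0/1 labels; p_pred holds predicted probabilities of class 1.
    Hypothetical helper for illustration; not part of the focal_loss package.
    """
    losses = []
    for y, p in zip(y_true, p_pred):
        p = min(max(p, eps), 1.0 - eps)      # clip to avoid log(0)
        p_t = p if y == 1 else 1.0 - p       # probability assigned to the true class
        losses.append(-((1.0 - p_t) ** gamma) * math.log(p_t))
    return sum(losses) / len(losses)


labels = [1, 0, 1]
probs = [0.9, 0.2, 0.3]  # two fairly easy examples and one hard one

bce = binary_focal_loss_sketch(labels, probs, gamma=0.0)  # gamma=0 reduces to cross-entropy
fl = binary_focal_loss_sketch(labels, probs, gamma=2.0)
print(f"cross-entropy: {bce:.4f}, focal (gamma=2): {fl:.4f}")
```

Because every factor (1 - p_t)^γ is at most 1, the focal loss never exceeds the cross-entropy on the same batch. The easy examples lose most of their weight.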
The `focal_loss` package provides functions and classes that can be used as off-the-shelf replacements for `tf.keras.losses` functions and classes, respectively.
```python
# Typical tf.keras API usage
import tensorflow as tf
from focal_loss import BinaryFocalLoss

model = tf.keras.Model(...)
model.compile(
    optimizer=...,
    loss=BinaryFocalLoss(gamma=2),  # Used here like a tf.keras loss
    metrics=...,
)
history = model.fit(...)
```
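The `gamma=2` argument in the usage example above is the focusing parameter γ from [1]. Each example's cross-entropy term is scaled by the modulating factor (1 - p_t)^γ, so confident, correct predictions contribute much less than uncertain ones. This short sketch only evaluates that factor for a few values of γ. It does not call the package, and `modulating_factor` is a hypothetical helper.

```python
def modulating_factor(p_t: float, gamma: float) -> float:
    """Weight (1 - p_t)**gamma that focal loss applies to the cross-entropy term."""
    return (1.0 - p_t) ** gamma


# Tabulate the factor for a few true-class probabilities p_t
for gamma in (0.0, 1.0, 2.0):
    row = ", ".join(f"p_t={p:.2f}: {modulating_factor(p, gamma):.4f}" for p in (0.5, 0.9, 0.99))
    print(f"gamma={gamma}: {row}")
```

With γ = 2, an example predicted at p_t = 0.9 keeps about 1% of its cross-entropy loss, while one predicted at p_t = 0.5 keeps 25%.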
The `focal_loss` package includes the functions

- `binary_focal_loss`
- `sparse_categorical_focal_loss`

and the wrapper classes

- `BinaryFocalLoss` (use like `tf.keras.losses.BinaryCrossentropy`)
- `SparseCategoricalFocalLoss` (use like `tf.keras.losses.SparseCategoricalCrossentropy`)
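The function/class split can be illustrated with a pure-Python sketch of the sparse categorical case, where labels are integer class indices and each prediction is a row of class probabilities. In this sketch the function takes γ on every call. The wrapper class stores γ once and is then called with `(y_true, y_pred)`, in the style of a Keras loss object. The names `sparse_categorical_focal_loss_sketch` and `SparseCategoricalFocalLossSketch` are hypothetical. The package's real signatures and options, such as support for logits, may differ.

```python
import math
from typing import Sequence


def sparse_categorical_focal_loss_sketch(
    y_true: Sequence[int], y_pred: Sequence[Sequence[float]], gamma: float, eps: float = 1e-7
) -> float:
    """Mean focal loss for integer labels and rows of class probabilities (hypothetical helper)."""
    total = 0.0
    for label, probs in zip(y_true, y_pred):
        p_y = min(max(probs[label], eps), 1.0)  # true-class probability, kept away from 0
        total += -((1.0 - p_y) ** gamma) * math.log(p_y)
    return total / len(y_true)


class SparseCategoricalFocalLossSketch:
    """Callable wrapper that fixes gamma once, in the style of a Keras loss object (hypothetical)."""

    def __init__(self, gamma: float):
        self.gamma = gamma

    def __call__(self, y_true: Sequence[int], y_pred: Sequence[Sequence[float]]) -> float:
        # Delegate to the function, passing the stored focusing parameter
        return sparse_categorical_focal_loss_sketch(y_true, y_pred, gamma=self.gamma)


labels = [0, 2]
probs = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]]

loss_fn = SparseCategoricalFocalLossSketch(gamma=2.0)
print(loss_fn(labels, probs))
```

Storing γ on the object lets the loss be passed to code that expects a two-argument `loss(y_true, y_pred)` callable.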
Documentation is available at Read the Docs.
The `focal_loss` package can be installed using the pip utility. For the latest version, install directly from the package's GitHub page:

```shell
pip install git+https://github.com/artemmavrin/focal-loss.git
```

Alternatively, install a recent release from the Python Package Index (PyPI):

```shell
pip install focal-loss
```
Note. To install the project for development (e.g., to make changes to the source code), clone the project repository from GitHub and run `make dev`:

```shell
git clone https://github.com/artemmavrin/focal-loss.git
cd focal-loss
# Optional but recommended: create and activate a new environment first
make dev
```
This will additionally install the requirements needed to run tests, check code coverage, and produce documentation.
[1] T. Lin, P. Goyal, R. Girshick, K. He, and P. Dollár. Focal loss for dense object detection. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2018. (DOI) (arXiv preprint)