Focal Loss for Binary and Multi-class Problems: An In-Depth Explanation and Implementation
Table of Contents
- 1. Application Scenarios of Focal Loss
- 2. Explanation of Focal Loss in Binary Classification
- 2.1 How Focal Loss Adjusts Positive and Negative Sample Weights
- 2.2 How Focal Loss Adjusts Easy and Hard Sample Weights
- 2.3 Integrating the Two Functions to Complete Focal Loss
- 2.4 PyTorch Implementation of Focal Loss
- 3. Focal Loss in Multi-class Classification Tasks
- 3.1 Adjusting Class Weights by Focal Loss
- 3.2 Adjusting Weights for Easy and Hard Samples in Multi-class Problems
- 3.3 Integrating Them to Complete the Focal Loss in Multi-class Problems
- 3.4 Implementing Focal Loss for Multi-class Problems
1. Application Scenarios of Focal Loss
Before learning something, it’s essential to understand its purpose.
Focal Loss primarily serves two purposes, which also determine its application scenarios:
- Focal Loss can adjust the loss weights of positive and negative samples. This means that when there is an imbalance in the number of positive and negative samples, Focal Loss can be considered.
- Focal Loss can adjust the loss weights of easy and hard samples. This means that when the difficulty of training samples is imbalanced, Focal Loss can also be considered.
This is exactly what the name "Focal Loss" expresses: it focuses on those "few and difficult" samples.
Although most blogs discuss Focal Loss in the context of Object Detection, it can actually be applied in other scenarios as well. For example, in NLP tasks:
- When we perform sentiment classification (positive/negative), if 99% of the reviews are positive and only 1% are negative, we can use Focal Loss to adjust for the data imbalance.
- In sentiment classification, some samples can be very hard to predict, such as the sarcastic review "After my dog ate your dish, it cooked me four dishes and a soup overnight." Conversely, some samples are straightforward, like "Your dish is awful." In such cases, Focal Loss can help adjust the loss weights of hard and easy samples, allowing us to better learn the features of difficult samples.
2. Explanation of Focal Loss in Binary Classification
This section will discuss how Focal Loss achieves its two functions.
2.1 How Focal Loss Adjusts Positive and Negative Sample Weights
In binary classification, we usually calculate the loss using cross-entropy, defined as follows:
$$ \begin{aligned} \mathrm{CE}(p, y)= \begin{cases}-\log (p) & \text { if } y=1 \\ -\log (1-p) & \text { if } y=0\end{cases} \end{aligned} $$
Here, CE stands for Cross Entropy, $p$ is the predicted outcome (for example, 0.8), and $y$ is the true label.
Suppose 90% of our samples are negative; the total loss will then be dominated by the negative samples. To correct this, we can simply multiply each term by a weight, for example:
$$ \begin{aligned} \mathrm{CE}(p, y)= \begin{cases}-\log (p) * 0.9 & \text { if } y=1 \\ -\log (1-p) * 0.1 & \text { if } y=0\end{cases} \end{aligned} $$
This gives positive and negative samples a weight ratio of 9:1. Treating 0.9 as a variable $\alpha$, we have:
$$ \begin{aligned} \mathrm{CE}(p, y, \alpha)= \begin{cases}-\log (p) * \alpha& \text { if } y=1 \\ -\log (1-p) * (1-\alpha) & \text { if } y=0\end{cases} \end{aligned} $$
Where $\alpha \in (0,1)$ is a hyperparameter. This is how Focal Loss adjusts the weights of positive and negative samples, quite straightforward.
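To make the weighting concrete, here is a minimal PyTorch sketch of the $\alpha$-weighted cross-entropy above (the function name and the small epsilon for numerical stability are my own additions, not part of the original formula):
```python
import torch

def alpha_weighted_bce(p, y, alpha=0.9, eps=1e-9):
    """Cross-entropy weighted by alpha for positives and (1 - alpha) for negatives.

    p: predicted probabilities in (0, 1), shape [batch_size]
    y: ground-truth labels in {0, 1}, shape [batch_size]
    """
    pos_term = -alpha * y * torch.log(p + eps)                   # y = 1 branch
    neg_term = -(1 - alpha) * (1 - y) * torch.log(1 - p + eps)   # y = 0 branch
    return (pos_term + neg_term).mean()

p = torch.tensor([0.8, 0.3, 0.6])
y = torch.tensor([1.0, 0.0, 1.0])
print(alpha_weighted_bce(p, y))   # alpha = 0.9 matches the 90%-negative example above
```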
2.2 How Focal Loss Adjusts Easy and Hard Sample Weights
When training a binary classification model, applying a sigmoid to the model's output yields a probability between 0 and 1, indicating the likelihood that the sample is positive.
For samples labeled as 1:
- If the prediction is 0.95, it indicates the sample is relatively easy to predict.
- If the prediction is 0.65, it suggests the sample is somewhat challenging to predict.
- If the prediction is 0.28, it indicates the sample is very difficult to predict.
The same reasoning applies to negative samples. Thus, the further the predicted value is from the true value, the harder the sample is to predict.
To encourage more learning from hard samples, we assign a larger weight to their loss, while easy samples receive a smaller weight. We can base the weight on their difficulty level, for example:
For samples labeled as 1:
- If the prediction output is 0.95, the weight is (1-0.95) = 0.05.
- If the prediction output is 0.65, the weight is (1-0.65) = 0.35.
- If the prediction output is 0.28, the weight is (1-0.28) = 0.72.
Following this logic, we can derive the following loss function:
$$ \begin{aligned} \mathrm{CE}(p, y)= \begin{cases}-\log (p) * (1-p) & \text { if } y=1 \\ -\log (1-p) * p & \text { if } y=0\end{cases} \end{aligned} $$
Furthermore, if you want easy samples to be suppressed even more strongly relative to hard ones, you can square the weight. Since all weights lie in (0, 1), squaring shrinks every weight, but it shrinks easy samples' weights (e.g., 0.05 → 0.0025) far more than hard samples' (e.g., 0.72 → 0.5184), which further emphasizes the hard samples. This leads to the following formula:
$$ \begin{aligned} \mathrm{CE}(p, y)= \begin{cases}-\log (p) * (1-p)^2 & \text { if } y=1 \\ -\log (1-p) * p^2 & \text { if } y=0\end{cases} \end{aligned} $$
If squaring feels too extreme (or not extreme enough), you can replace the fixed exponent with a hyperparameter $\gamma$, resulting in:
$$ \begin{aligned} \mathrm{CE}(p, y)= \begin{cases}-\log (p) * (1-p)^\gamma & \text { if } y=1 \\ -\log (1-p) * p^\gamma & \text { if } y=0\end{cases} \end{aligned} $$
This completes the adjustment of easy and hard sample weights. Let’s summarize the parameter $\gamma$:
- When $\gamma=0$, the loss function reduces to the original Cross Entropy.
- The larger $\gamma$ is, the more aggressive the weight adjustment; the smaller $\gamma$ is, the milder it becomes. Here "aggressive" means that easy samples' weights shrink further while hard samples' weights stay relatively large.
- $\gamma$ is usually set to $2$.
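To get a feel for how $\gamma$ reshapes the weights, here is a quick sketch (plain Python, values rounded) that prints the modulating factor $(1-p)^\gamma$ for the three positive samples above under several choices of $\gamma$:
```python
# Modulating factor (1 - p)^gamma for positive samples with predictions 0.95, 0.65, 0.28.
# gamma = 0 recovers plain cross-entropy; larger gamma suppresses easy samples more strongly.
for gamma in (0, 1, 2, 5):
    factors = {p: round((1 - p) ** gamma, 4) for p in (0.95, 0.65, 0.28)}
    print(f"gamma={gamma}: {factors}")
```
At $\gamma=2$, the easy sample's weight (0.0025) is already two orders of magnitude below the hard sample's (0.5184).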
2.3 Integrating the Two Functions to Complete Focal Loss
The integration process is simple: just combine $\alpha$ and $\gamma$ in one formula, resulting in the Focal Loss formula:
$$ \begin{aligned} \mathrm{FL}(p, y, \alpha, \gamma)= \begin{cases}-\log (p) *\alpha * (1-p)^\gamma & \text { if } y=1 \\ -\log (1-p) *(1-\alpha)* p^\gamma & \text { if } y=0\end{cases} \end{aligned} $$
To make this formula more readable, we can define two new variables, $\alpha_t$ and $p_t$, where:
$$ \begin{aligned} \alpha_t= \begin{cases} \alpha & \text { if } y=1 \\ 1-\alpha & \text { if } y=0\end{cases}, ~~~~~~p_t= \begin{cases} p & \text { if } y=1 \\ 1-p & \text { if } y=0\end{cases} \end{aligned} $$
Then Focal Loss can be expressed in its final form:
$$ \begin{aligned} \mathrm{FL}(p_t) = -\alpha_t(1-p_t)^\gamma \log (p_t) \end{aligned} $$
This is the formula for Focal Loss.
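A quick numeric check (with $\alpha=0.25$ and $\gamma=2$, values chosen only for illustration) shows how the two factors interact for a positive sample:
$$ \begin{aligned} p=0.9 \text{ (easy)}: ~~\mathrm{FL} = -0.25 \cdot (1-0.9)^2 \log(0.9) \approx 0.00026, ~~~~ \mathrm{CE} = -\log(0.9) \approx 0.105 \\ p=0.3 \text{ (hard)}: ~~\mathrm{FL} = -0.25 \cdot (1-0.3)^2 \log(0.3) \approx 0.147, ~~~~~~~~~ \mathrm{CE} = -\log(0.3) \approx 1.204 \end{aligned} $$
Under plain cross-entropy the hard sample's loss is roughly 11 times the easy sample's; under Focal Loss it is roughly 560 times, which is exactly the "focusing" effect.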
2.4 PyTorch Implementation of Focal Loss
```python
import torch
from torch import nn


class BinaryFocalLoss(nn.Module):
    """
    Reference: https://github.com/lonePatient/TorchBlocks
    """

    def __init__(self, gamma=2.0, alpha=0.25, epsilon=1.e-9):
        super(BinaryFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.epsilon = epsilon

    def forward(self, input, target):
        """
        Args:
            input: predicted probabilities (after sigmoid), shape [batch_size]
            target: ground-truth labels in {0, 1}, shape [batch_size]
        Returns:
            scalar loss (mean over the batch)
        """
        multi_hot_key = target
        # Note: `logits` here actually holds probabilities (sigmoid applied by the caller).
        logits = input
        # If the model hasn't applied sigmoid to its output, add it here:
        # logits = torch.sigmoid(logits)
        zero_hot_key = 1 - multi_hot_key
        # Positive samples: -alpha * (1 - p)^gamma * log(p)
        loss = -self.alpha * multi_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()
        # Negative samples: -(1 - alpha) * p^gamma * log(1 - p)
        loss += -(1 - self.alpha) * zero_hot_key * torch.pow(logits, self.gamma) * (1 - logits + self.epsilon).log()
        return loss.mean()


if __name__ == '__main__':
    m = nn.Sigmoid()
    loss = BinaryFocalLoss()
    input = torch.randn(3, requires_grad=True)
    target = torch.empty(3).random_(2)
    output = loss(m(input), target)
    print("loss:", output)
    output.backward()
```
3. Focal Loss in Multi-class Classification Tasks
3.1 Adjusting Class Weights by Focal Loss
Let's assume a three-class classification problem, with class labels $y=(1,2,3)$, where the sample sizes are 100, 2000, and 10000 respectively.
In this case, the parameter $\alpha_t$ can no longer be a single value but must be a list, such as:
$$ \begin{aligned} \alpha_t= \begin{cases} 0.7 & \text { if } y=1 \\ 0.25 & \text { if } y= 2 \\ 0.05 & \text { if } y= 3 \end{cases} \end{aligned} $$
In multi-class problems, the hyperparameter $\alpha_t$ changes from a single value to an array, with each class having its own weight. Classes with more samples are assigned lower weights, while classes with fewer samples get higher weights. The formula can be generalized as:
$$ \begin{aligned} \alpha_t= \begin{cases} \alpha_1 & \text { if } y=1 \\ \alpha_2 & \text { if } y= 2 \\ ... & \text { if } y= ... \\ \alpha_n & \text { if } y= n \end{cases} \end{aligned} $$
Where $n$ is the number of classes.
In most blogs and open-source projects, $\alpha_t$ is still treated as a single scalar even for multi-class problems. Theoretically, however, this does not address data imbalance, because it applies the same weight to every class.
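In the example above, the weights 0.7/0.25/0.05 are picked by hand. If you want a data-driven starting point, one possible heuristic (my suggestion, not something the formula prescribes) is to use normalized inverse class frequencies:
```python
import torch

# Class sizes from the three-class example above: 100, 2000, 10000 samples.
counts = torch.tensor([100.0, 2000.0, 10000.0])
alpha = (1.0 / counts) / (1.0 / counts).sum()   # inverse frequency, normalized to sum to 1
print(alpha)   # approximately [0.943, 0.047, 0.009]; rarer classes get larger weights
```
In practice, such weights are usually treated only as a starting point and tuned on a validation set.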
3.2 Adjusting Weights for Easy and Hard Samples in Multi-class Problems
Now, consider a three-class classification problem, $y=(1,2,3)$, with a predicted probability for a particular sample:
$$ \begin{aligned} p= [0.85, 0.1, 0.05]^T \end{aligned} $$
- If the sample's label is 1, its loss weight can be set to $(1-0.85)=0.15$, indicating this is an easy sample.
- If the sample's label is 2, its loss weight can be set to $(1-0.1)=0.9$, indicating this is a difficult sample.
- The same logic applies to label 3.
To implement this, we use a one-hot vector to mask out the irrelevant class probabilities:
$$ \begin{bmatrix} 1 \\ 0 \\ 0 \end{bmatrix} \times(1- \begin{bmatrix} 0.85 \\ 0.1 \\ 0.05 \end{bmatrix})^\gamma = \begin{bmatrix} 0.15^\gamma \\ 0 \\ 0 \end{bmatrix} $$
Here, we denote the one-hot vector as $h$. So, we can formulate the above process as:
$$ \begin{aligned} h *(1-p_t)^\gamma \end{aligned} $$
where the parameter $\gamma$ adjusts the strength of the weighting, just as described in Section 2.2.
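A minimal sketch of this masking step in PyTorch, using the three-class prediction from above (tensor names are mine):
```python
import torch

gamma = 2
p = torch.tensor([0.85, 0.10, 0.05])   # predicted distribution for one sample
h = torch.tensor([1.0, 0.0, 0.0])      # one-hot vector for label y = 1
modulator = h * (1 - p) ** gamma       # keeps only the true class: [0.15^2, 0, 0]
print(modulator)                        # tensor([0.0225, 0.0000, 0.0000])
```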
3.3 Integrating Them to Complete the Focal Loss in Multi-class Problems
Putting everything together, the Focal Loss for multi-class problems can be formulated as:
$$ \begin{aligned} \mathrm{FL}(p_t, y, \alpha_t, \gamma)= \text{sum}\left(-\alpha_t * h * (1-p_t)^\gamma * \log (p_t)\right) \end{aligned} $$
In this formula, $\alpha_t$ and $p_t$ have different meanings compared to binary classification: $\alpha_t$ is a vector containing the weight of each class, $p_t$ is the predicted probability distribution, and $h$ is the one-hot vector of the current sample's label; all operations are element-wise over the classes.
Let's work through a concrete example:
Suppose we have a three-class classification task and set $\alpha_t=[0.7, 0.25, 0.05]^T$ and $\gamma=2$. For a sample with label $y=1$, and an output probability distribution $p=[0.85,0.1,0.05]^T$, the loss will be:
$$ \begin{aligned} \mathrm{FL_{loss}} = & \text{sum} (-\begin{bmatrix} 0.7 \\ 0.25 \\ 0.05 \end{bmatrix}\times \begin{bmatrix} 1 \\ 0 \\ 0 \end{bmatrix}\times (1- \begin{bmatrix} 0.85 \\ 0.1 \\ 0.05 \end{bmatrix})^2 \times \log(\begin{bmatrix} 0.85 \\ 0.1 \\ 0.05 \end{bmatrix})) \\ \\ = & \text{sum} (-\begin{bmatrix} 0.7 \\ 0.25 \\ 0.05 \end{bmatrix}\times \begin{bmatrix} 0.15^2 \\ 0 \\ 0 \end{bmatrix} \times \log(\begin{bmatrix} 0.85 \\ 0.1 \\ 0.05 \end{bmatrix})) \\ \\ = & \text{sum}(\begin{bmatrix} 0.0026 \\ 0 \\ 0 \end{bmatrix}) \\ \\ =&0.0026 \end{aligned} $$
As we can see from this example, only the row corresponding to the sample's class actually contributes to the loss due to the one-hot vector.
Thus, we can simplify the Focal Loss formula as:
$$ \begin{aligned} \mathrm{FL} = -\alpha_c(1-p_c)^\gamma \log(p_c) \end{aligned} $$
Where $c$ is the class of the current sample, $\alpha_c$ is the weight for class $c$, and $p_c$ is the predicted probability for class $c$.
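As a sanity check, the worked example above can be reproduced in a few lines (plain Python, variable names are mine), computing both the full vector form and the simplified per-class form:
```python
import math

alpha = [0.7, 0.25, 0.05]   # per-class weights
p = [0.85, 0.10, 0.05]      # predicted distribution
h = [1, 0, 0]               # one-hot vector for label y = 1
gamma = 2

# Full form: sum over classes of -alpha * h * (1 - p)^gamma * log(p)
fl_full = sum(-a * hi * (1 - pi) ** gamma * math.log(pi) for a, hi, pi in zip(alpha, h, p))

# Simplified form: only the true class c contributes
c = 0
fl_simple = -alpha[c] * (1 - p[c]) ** gamma * math.log(p[c])

print(round(fl_full, 4), round(fl_simple, 4))   # both are approximately 0.0026
```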
3.4 Implementing Focal Loss for Multi-class Problems
```python
import torch
from torch import nn


class FocalLoss(nn.Module):
    """
    Reference: https://github.com/lonePatient/TorchBlocks
    """

    def __init__(self, gamma=2.0, alpha=1, epsilon=1.e-9, device=None):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        if isinstance(alpha, list):
            # Per-class weights: one entry per class.
            self.alpha = torch.tensor(alpha, dtype=torch.float32, device=device)
        else:
            # A single scalar applies the same weight to every class.
            self.alpha = alpha
        self.epsilon = epsilon

    def forward(self, input, target):
        """
        Args:
            input: model's output (raw logits), shape [batch_size, num_cls]
            target: ground-truth class indices, shape [batch_size]
        Returns:
            scalar loss (mean over the batch)
        """
        num_labels = input.size(-1)
        idx = target.view(-1, 1).long()
        # Build the one-hot matrix h, shape [batch_size, num_cls].
        one_hot_key = torch.zeros(idx.size(0), num_labels, dtype=torch.float32, device=idx.device)
        one_hot_key = one_hot_key.scatter_(1, idx, 1)
        # Ignore class 0 (e.g., a padding label); remove this line if class 0 is a real class.
        one_hot_key[:, 0] = 0
        logits = torch.softmax(input, dim=-1)
        # -alpha_c * (1 - p_c)^gamma * log(p_c), masked by the one-hot vector.
        loss = -self.alpha * one_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()
        loss = loss.sum(1)
        return loss.mean()


if __name__ == '__main__':
    loss = FocalLoss(alpha=[0.1, 0.2, 0.3, 0.15, 0.25])
    input = torch.randn(3, 5, requires_grad=True)
    target = torch.empty(3, dtype=torch.long).random_(5)
    output = loss(input, target)
    print(output)
    output.backward()
```