Focal Loss for Binary and Multi-class Problems: An In-Depth Explanation and Implementation
Table of Contents
- 1. Application Scenarios of Focal Loss
- 2. Explanation of Focal Loss in Binary Classification
- 2.1 How Focal Loss Adjusts Positive and Negative Sample Weights
- 2.2 How Focal Loss Adjusts Easy and Hard Sample Weights
- 2.3 Integrating the Two Functions to Complete Focal Loss
- 2.4 PyTorch Implementation of Focal Loss
- 3. Focal Loss in Multi-class Classification Task
1. Application Scenarios of Focal Loss
Before learning something, it’s essential to understand its purpose.
Focal Loss primarily serves two purposes, which also determine its application scenarios:
- Focal Loss can adjust the loss weights of positive and negative samples. This means that when there is an imbalance in the number of positive and negative samples, Focal Loss can be considered.
- Focal Loss can adjust the loss weights of easy and hard samples. This means that when the difficulty of training samples is imbalanced, Focal Loss can also be considered.
This is where the name "Focal Loss" comes from: it focuses on the "few and difficult" samples.
Although most blogs discuss Focal Loss in the context of Object Detection, it can actually be applied in other scenarios as well. For example, in NLP tasks:
- When we perform sentiment classification (positive/negative), if 99% of the reviews are positive and only 1% are negative, we can use Focal Loss to adjust for the data imbalance.
- In sentiment classification, some samples can be very hard to predict, such as the sarcastic review "After my dog ate your dish, it made me four dishes and a soup overnight." Conversely, some samples are straightforward, like "Your dish is awful." In such cases, Focal Loss can help adjust the loss weights of hard and easy samples, allowing the model to better learn the features of difficult samples.
2. Explanation of Focal Loss in Binary Classification
This section will discuss how Focal Loss achieves its two functions.
2.1 How Focal Loss Adjusts Positive and Negative Sample Weights
In binary classification, we usually calculate the loss using cross-entropy, defined as follows:

$$CE(p, y) = -y\log(p) - (1-y)\log(1-p)$$

Here, CE stands for Cross Entropy, $p$ is the predicted probability of the sample being positive, and $y \in \{0, 1\}$ is the ground-truth label.
Assume that 90% of our samples are negative; the calculated loss will then be dominated by negative samples. To adjust for this, we can simply multiply each term by a weight, such as:

$$CE(p, y) = -0.9 \cdot y\log(p) - 0.1 \cdot (1-y)\log(1-p)$$

This gives positive and negative samples a loss-weight ratio of 9:1. By defining 0.9 as a variable $\alpha$, we get:

$$CE(p, y) = -\alpha \cdot y\log(p) - (1-\alpha) \cdot (1-y)\log(1-p)$$

where $\alpha \in (0, 1)$ is a hyperparameter that controls the relative weight of positive and negative samples: the fewer positive samples there are, the larger $\alpha$ should be.
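As a minimal illustration (my own sketch, not from the original post), the $\alpha$-weighted cross-entropy above can be written directly in PyTorch; the helper name `weighted_bce` and the example tensors are assumptions made just for this snippet:

```python
import torch

def weighted_bce(p, y, alpha=0.9):
    """Alpha-weighted binary cross-entropy: positive samples get weight alpha,
    negative samples get weight (1 - alpha)."""
    return -(alpha * y * torch.log(p) + (1 - alpha) * (1 - y) * torch.log(1 - p)).mean()

p = torch.tensor([0.8, 0.3, 0.6])   # predicted probabilities (after sigmoid)
y = torch.tensor([1.0, 0.0, 1.0])   # ground-truth labels
print(weighted_bce(p, y))
```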
2.2 How Focal Loss Adjusts Easy and Hard Sample Weights
When training a binary classifier, the model's output after the sigmoid function is a probability between 0 and 1, indicating the likelihood that the sample is positive.
For samples labeled as 1:
- If the prediction is 0.95, it indicates the sample is relatively easy to predict.
- If the prediction is 0.65, it suggests the sample is somewhat challenging to predict.
- If the prediction is 0.28, it indicates the sample is very difficult to predict.
The same reasoning applies to negative samples. Thus, the further the predicted value is from the true value, the harder the sample is to predict.
To encourage more learning from hard samples, we assign a larger weight to their loss, while easy samples receive a smaller weight. We can base the weight on their difficulty level, for example:
For samples labeled as 1:
- If the prediction output is 0.95, the weight is (1-0.95) = 0.05.
- If the prediction output is 0.65, the weight is (1-0.65) = 0.35.
- If the prediction output is 0.28, the weight is (1-0.28) = 0.72.
Following this logic, we can derive the following loss function:

$$FL(p, y) = -(1-p) \cdot y\log(p) - p \cdot (1-y)\log(1-p)$$
Furthermore, if you want the weights of easy samples to be even lower relative to hard samples, you can square the weight. Since the weights lie between 0 and 1, squaring shrinks the small weights of easy samples much more than the larger weights of hard samples, leading to the following formula:

$$FL(p, y) = -(1-p)^2 \cdot y\log(p) - p^2 \cdot (1-y)\log(1-p)$$
If you find squaring too extreme (or not extreme enough), you can replace the fixed exponent 2 with a hyperparameter $\gamma$:

$$FL(p, y) = -(1-p)^\gamma \cdot y\log(p) - p^\gamma \cdot (1-y)\log(1-p)$$
This completes the adjustment of easy and hard sample weights. Let's summarize the parameter $\gamma$:
- When $\gamma = 0$, the loss function reduces to the original cross-entropy.
- The larger $\gamma$ is, the more aggressive the weight adjustment; the smaller it is, the milder the adjustment. "Aggressive" means that easy samples' weights become lower, while hard samples' weights become relatively higher. $\gamma$ is usually set to 2.
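To make the effect of $\gamma$ concrete, here is a small sketch (not from the original post) that prints the modulating factor $(1-p)^\gamma$ for the three example predictions above:

```python
import torch

p = torch.tensor([0.95, 0.65, 0.28])  # predictions for samples labeled 1
for gamma in [0.0, 1.0, 2.0, 5.0]:
    print(gamma, ((1 - p) ** gamma).tolist())
# gamma = 0 leaves all weights at 1 (plain cross-entropy); larger gamma
# shrinks the easy sample's weight (p = 0.95) far more than the hard one's (p = 0.28).
```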
2.3 Integrating the Two Functions to Complete Focal Loss
The integration is simple: just combine the $\alpha$ weighting from Section 2.1 with the $(1-p)^\gamma$ weighting from Section 2.2:

$$FL(p, y) = -\alpha \, (1-p)^\gamma \cdot y\log(p) - (1-\alpha) \, p^\gamma \cdot (1-y)\log(1-p)$$

To make this formula more readable, we can define two new variables, $p_t$ and $\alpha_t$:

$$p_t = \begin{cases} p, & y = 1 \\ 1 - p, & y = 0 \end{cases} \qquad \alpha_t = \begin{cases} \alpha, & y = 1 \\ 1 - \alpha, & y = 0 \end{cases}$$

Then Focal Loss can be expressed in its final form:

$$FL(p_t) = -\alpha_t \, (1 - p_t)^\gamma \log(p_t)$$
This is the formula for Focal Loss.
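For intuition, here is a worked example (my own, with the commonly cited defaults $\alpha = 0.25$ and $\gamma = 2$, the same defaults used by the implementation below). Take an easy positive sample with $p_t = 0.9$ and a hard one with $p_t = 0.3$:

$$FL(0.9) = -0.25 \times (1 - 0.9)^2 \log(0.9) \approx 2.6 \times 10^{-4}, \qquad FL(0.3) = -0.25 \times (1 - 0.3)^2 \log(0.3) \approx 0.147$$

The hard sample contributes several hundred times more loss than the easy one, which is exactly the focusing behaviour described above.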
2.4 PyTorch Implementation of Focal Loss
```python
import torch
from torch import nn


class BinaryFocalLoss(nn.Module):
    """
    Reference: https://github.com/lonePatient/TorchBlocks
    """

    def __init__(self, gamma=2.0, alpha=0.25, epsilon=1.e-9):
        super(BinaryFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.epsilon = epsilon

    def forward(self, input, target):
        """
        Args:
            input: predicted probabilities (after sigmoid), shape of [batch_size]
            target: ground truth labels (0 or 1), shape of [batch_size]
        Returns:
            a scalar loss averaged over the batch
        """
        multi_hot_key = target
        logits = input
        # If the model hasn't applied sigmoid to its output, add it here.
        # logits = torch.sigmoid(logits)
        zero_hot_key = 1 - multi_hot_key
        # Positive-sample term: -alpha * y * (1 - p)^gamma * log(p)
        loss = -self.alpha * multi_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()
        # Negative-sample term: -(1 - alpha) * (1 - y) * p^gamma * log(1 - p)
        loss += -(1 - self.alpha) * zero_hot_key * torch.pow(logits, self.gamma) * (1 - logits + self.epsilon).log()
        return loss.mean()


if __name__ == '__main__':
    m = nn.Sigmoid()
    loss = BinaryFocalLoss()
    input = torch.randn(3, requires_grad=True)
    target = torch.empty(3).random_(2)
    output = loss(m(input), target)
    print("loss:", output)
    output.backward()
```
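As a quick sanity check (my own addition, assuming the `BinaryFocalLoss` class above): with `gamma=0` the modulating factor disappears, and with `alpha=0.5` both classes get equal weight, so the result should be roughly half of `nn.BCELoss` (up to the small epsilon):

```python
import torch
from torch import nn

probs = torch.tensor([0.9, 0.2, 0.6])
target = torch.tensor([1.0, 0.0, 1.0])

focal = BinaryFocalLoss(gamma=0.0, alpha=0.5)
bce = nn.BCELoss()

print(focal(probs, target))      # approximately 0.5 * BCE
print(0.5 * bce(probs, target))
```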
3. Focal Loss in Multi-class Classification Task
3.1 Adjusting Class Weights by Focal Loss
Let's assume a three-class classification problem with class labels 1, 2, and 3, where the classes are imbalanced.
In this case, a single scalar $\alpha$ can no longer describe how each class should be weighted.
In multi-class problems, the hyperparameter $\alpha$ therefore becomes a per-class vector, for example:

$$\alpha = [\alpha_1, \alpha_2, \alpha_3]$$

where $\alpha_i$ is the loss weight of class $i$: the fewer samples a class has, the larger its weight should be.
In most blogs and open-source projects, $\alpha$ is still treated as a single scalar even for multi-class problems. Theoretically, however, this does not solve the problem of imbalanced data, because it just applies a uniform weight to all classes.
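One common heuristic for choosing the per-class $\alpha$ (an assumption of this sketch, not something prescribed by Focal Loss itself) is to weight each class by its inverse frequency and normalize; the class counts below are made up for illustration:

```python
import torch

# Hypothetical sample counts for a three-class problem.
class_counts = torch.tensor([900.0, 80.0, 20.0])
weights = 1.0 / class_counts            # rarer classes get larger weights
alpha = (weights / weights.sum()).tolist()
print(alpha)  # roughly [0.02, 0.20, 0.79]; this list can later be passed to the
              # FocalLoss implementation in Section 3.4 as FocalLoss(alpha=alpha)
```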
3.2 Adjusting Weights for Easy and Hard Samples in Multi-Class Problems
Now, consider a three-class classification problem and suppose the model's softmax output for a sample is, say, $p = [0.8, 0.1, 0.1]$ (the numbers are only for illustration):
- If the label of the sample is 1, its loss weight can be set to $(1 - 0.8) = 0.2$, indicating this is an easy sample.
- If the label of the sample is 2, its loss weight can be set to $(1 - 0.1) = 0.9$, indicating this is a difficult sample.
- Label 3 works the same way as label 2.
To implement this, we use a one-hot vector to mask out the irrelevant class probabilities:

$$FL(p, y) = -\sum_{i=1}^{C} h_i \, (1 - p_i)^\gamma \log(p_i)$$

Here, we denote the one-hot vector of the label as $h$ (so $h_i = 1$ only for the true class and 0 elsewhere), $p_i$ is the softmax probability of class $i$, and $C$ is the number of classes, where the parameter $\gamma$ plays the same role as in the binary case: larger values focus the loss more strongly on hard samples.
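Here is a tiny sketch (illustrative only) of this one-hot masking for a single sample, reusing the assumed softmax output $[0.8, 0.1, 0.1]$ from above:

```python
import torch

p = torch.tensor([0.8, 0.1, 0.1])   # softmax output for one sample
h = torch.tensor([0.0, 1.0, 0.0])   # one-hot vector for label 2
gamma = 2.0

per_class = -h * (1 - p) ** gamma * torch.log(p)
print(per_class)        # only the entry for class 2 is non-zero
print(per_class.sum())  # the focal loss for this sample
```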
3.3 Integrating Them to Complete the Focal Loss in Multi-class Problems
Putting everything together, the Focal Loss for a multi-class problem can be formulated as:

$$FL(p, y) = -\sum_{i=1}^{C} \alpha_i \, h_i \, (1 - p_i)^\gamma \log(p_i)$$

In this formula, $\alpha_i$ is the weight of class $i$, $h_i$ is the $i$-th entry of the one-hot label vector, $p_i$ is the softmax probability of class $i$, and $\gamma$ is the focusing parameter.
Let's work through a concrete example:
Suppose we have a three-class classification task and set, say, $\gamma = 2$ and $\alpha = [0.3, 0.5, 0.2]$ (the specific numbers are only for illustration). For a sample with label 2, the one-hot vector is $h = [0, 1, 0]$; suppose the softmax output is $p = [0.2, 0.7, 0.1]$. The three per-class terms are:
- Class 1: $-\alpha_1 h_1 (1-p_1)^\gamma \log(p_1) = -0.3 \times 0 \times 0.8^2 \times \log(0.2) = 0$
- Class 2: $-\alpha_2 h_2 (1-p_2)^\gamma \log(p_2) = -0.5 \times 1 \times 0.3^2 \times \log(0.7) \approx 0.016$
- Class 3: $-\alpha_3 h_3 (1-p_3)^\gamma \log(p_3) = -0.2 \times 0 \times 0.9^2 \times \log(0.1) = 0$

so the total loss for this sample is about $0.016$.
As we can see from this example, only the row corresponding to the sample's class actually contributes to the loss due to the one-hot vector.
Thus, we can simplify the Focal Loss formula as:

$$FL(p, y) = -\alpha_t \, (1 - p_t)^\gamma \log(p_t)$$

where $t$ is the true class of the sample, $p_t$ is its softmax probability, and $\alpha_t$ is the weight assigned to that class.
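This simplified form also suggests an equivalent implementation that skips the one-hot mask and simply gathers the true-class probability per sample. The sketch below is my own (the helper name `focal_loss_simple` is hypothetical), not the reference implementation shown in the next subsection:

```python
import torch

def focal_loss_simple(logits, target, alpha, gamma=2.0, eps=1e-9):
    """Multi-class focal loss via gather: -alpha_t * (1 - p_t)^gamma * log(p_t)."""
    probs = torch.softmax(logits, dim=-1)
    p_t = probs.gather(1, target.view(-1, 1)).squeeze(1)  # probability of the true class
    alpha_t = torch.as_tensor(alpha)[target]               # per-class weight of the true class
    return (-alpha_t * (1 - p_t) ** gamma * (p_t + eps).log()).mean()

logits = torch.randn(4, 3)
target = torch.tensor([0, 2, 1, 2])
print(focal_loss_simple(logits, target, alpha=[0.2, 0.3, 0.5]))
```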
3.4 Implementing Focal Loss for Multi-class Problems
```python
import torch
from torch import nn


class FocalLoss(nn.Module):
    """
    Reference: https://github.com/lonePatient/TorchBlocks
    """

    def __init__(self, gamma=2.0, alpha=1, epsilon=1.e-9, device=None):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        if isinstance(alpha, list):
            # Per-class weights, one entry per class.
            self.alpha = torch.tensor(alpha, device=device)
        else:
            # A single scalar weight applied uniformly to all classes.
            self.alpha = alpha
        self.epsilon = epsilon

    def forward(self, input, target):
        """
        Args:
            input: model's output (raw logits), shape of [batch_size, num_cls]
            target: ground truth labels, shape of [batch_size]
        Returns:
            a scalar loss averaged over the batch
        """
        num_labels = input.size(-1)
        idx = target.view(-1, 1).long()
        one_hot_key = torch.zeros(idx.size(0), num_labels, dtype=torch.float32, device=idx.device)
        one_hot_key = one_hot_key.scatter_(1, idx, 1)
        # Ignore index 0 (e.g. a padding class): samples labeled 0 contribute no loss.
        # Remove this line if class 0 is a real class in your task.
        one_hot_key[:, 0] = 0
        logits = torch.softmax(input, dim=-1)
        # Per-class term: -alpha_i * h_i * (1 - p_i)^gamma * log(p_i)
        loss = -self.alpha * one_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()
        loss = loss.sum(1)
        return loss.mean()


if __name__ == '__main__':
    loss = FocalLoss(alpha=[0.1, 0.2, 0.3, 0.15, 0.25])
    input = torch.randn(3, 5, requires_grad=True)
    target = torch.empty(3, dtype=torch.long).random_(5)
    output = loss(input, target)
    print(output)
    output.backward()
```
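As a quick usage check (my own addition, assuming the `FocalLoss` class above): with `gamma=0` and a scalar `alpha=1`, the modulating factor vanishes and the loss should roughly match `nn.CrossEntropyLoss`, as long as no sample is labeled 0 (this implementation zeroes out class 0 via `one_hot_key[:, 0] = 0`):

```python
import torch
from torch import nn

focal = FocalLoss(gamma=0.0, alpha=1)
ce = nn.CrossEntropyLoss()

input = torch.randn(4, 5)
target = torch.tensor([1, 2, 3, 4])  # deliberately avoid class 0 here

print(focal(input, target))  # approximately equal to ...
print(ce(input, target))
```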