**Class-balanced loss** is a technique used in machine learning to handle the challenge of **class imbalance**, where certain classes are significantly underrepresented compared to others. Class imbalance is a common issue in real-world datasets and can severely hurt model performance, particularly in imbalanced classification tasks such as fraud detection, medical diagnosis, and rare event prediction. To understand class-balanced loss, let's go through why it's needed and how it works.

### 1. The Problem of Class Imbalance

In an imbalanced dataset, certain classes dominate, and models trained on such data tend to be biased toward predicting the majority class, since it supplies most of the examples seen during training. This bias leads to poor performance on the minority classes, which are often the classes of interest (e.g., rare disease detection).

Standard loss functions like **cross-entropy loss** treat every sample equally, so the model can drive the loss down simply by getting the majority class right, effectively ignoring the minority classes.

### 2. Techniques for Handling Class Imbalance

There are several ways to deal with class imbalance:

- **Data-level techniques**: Oversampling the minority class, undersampling the majority class.
- **Algorithm-level techniques**: Adjusting the decision threshold, using cost-sensitive learning, and **class-balanced loss**.

### 3. What is Class-Balanced Loss?

**Class-balanced loss** modifies the loss function to account for the imbalance in the dataset. Instead of treating all classes equally, it scales the loss by factors that depend on the class frequencies, so that rare classes contribute more to the final loss and the model cannot ignore them.

There are multiple versions of class-balanced loss, but one common approach is to weight each class by the **inverse effective number of samples**. The concepts behind this are described below.

#### Effective Number of Samples

The **effective number of samples** accounts for the diminishing marginal value of additional samples in a class. It is computed as:

$$
E_n = \frac{1 - \beta^n}{1 - \beta}
$$

Where:

- $n$: The number of samples for a given class.
- $\beta$: A parameter close to 1 (e.g., 0.99) that controls the rate of diminishing returns.

The idea is that adding new samples to a minority class adds more information, whereas adding new samples to a majority class adds relatively less value. Using this effective number, the **class weight** is defined as:

$$
w_c = \frac{1}{E_{n_c}}
$$

Where $n_c$ is the number of samples in class $c$.

Here is a plot of the effective number of samples.

![[Pasted image 20241025113021.png]]

- **X-axis**: The number of samples in each class (log scale for better visualization).
- **Y-axis**: The effective number of samples for that class.
- For small classes, the effective number is close to the actual sample count; as the count grows, it saturates toward $1/(1 - \beta)$, reflecting the diminishing returns of adding more samples to an already large class.

This is why **minority classes** receive higher importance weights: each of their samples carries proportionally more effective information, and the inverse effective number turns that into a larger contribution to the loss.
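As a quick numeric illustration of the two formulas above, here is a minimal sketch that computes the effective numbers and the resulting normalized class weights for a hypothetical three-class dataset with counts `[1000, 100, 10]` and $\beta = 0.99$ (the same example values used in the PyTorch implementation later on). Normalizing the weights to sum to the number of classes is one common convention, not the only possible choice.

```python
import torch

beta = 0.99
class_counts = torch.tensor([1000.0, 100.0, 10.0])

# Effective number per class: E_n = (1 - beta^n) / (1 - beta)
effective_num = (1.0 - torch.pow(beta, class_counts)) / (1.0 - beta)

# Raw weights are the inverse effective numbers; normalize them so they
# sum to the number of classes (same convention as the implementation below).
weights = 1.0 / effective_num
weights = weights / weights.sum() * len(class_counts)

print(effective_num)  # roughly [100.0, 63.4, 9.6]
print(weights)        # roughly [0.23, 0.36, 2.41] -- the rarest class dominates
```

Note how the majority class (1000 samples) is already close to the saturation value $1/(1 - \beta) = 100$, while the weight of the 10-sample class ends up roughly ten times the majority class's weight.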
#### Class-Balanced Cross-Entropy Loss

We can integrate this weighting scheme into the standard cross-entropy loss. The **class-balanced cross-entropy loss** can be defined as:

$$
L = - \frac{1}{N} \sum_{i=1}^{N} w_{y_i} \cdot \log(p_{y_i})
$$

Where:

- $N$ is the total number of samples.
- $w_{y_i}$ is the weight for the class of sample $i$, calculated from the effective number of samples.
- $p_{y_i}$ is the predicted probability of the true class for sample $i$.

The class weight $w_c$ balances the contribution of each class during training, encouraging the model to pay more attention to minority classes.

### 4. Code Example: Implementing Class-Balanced Loss in PyTorch

Here is an example of how to implement class-balanced loss in PyTorch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ClassBalancedLoss(nn.Module):
    def __init__(self, beta, class_counts):
        super(ClassBalancedLoss, self).__init__()
        self.beta = beta
        self.class_counts = class_counts
        self.class_weights = self.compute_class_weights()

    def compute_class_weights(self):
        # Inverse effective number of samples: (1 - beta) / (1 - beta^n)
        effective_num = 1.0 - torch.pow(self.beta, self.class_counts)
        effective_num = (1.0 - self.beta) / effective_num
        # Normalize so the weights sum to the number of classes
        return effective_num / torch.sum(effective_num) * len(self.class_counts)

    def forward(self, logits, targets):
        # Move class weights to the same device as the logits
        class_weights = self.class_weights.to(logits.device)
        weights = class_weights[targets]
        # Per-sample cross-entropy loss (no reduction yet)
        ce_loss = F.cross_entropy(logits, targets, reduction='none')
        # Apply the per-sample class weights, then average
        weighted_loss = weights * ce_loss
        return torch.mean(weighted_loss)

# Example usage
# Assuming 3 classes with respective counts: [1000, 100, 10]
class_counts = torch.tensor([1000, 100, 10], dtype=torch.float32)
beta = 0.99

# Create the loss function
cb_loss = ClassBalancedLoss(beta, class_counts)

# Fake logits and targets for demonstration
logits = torch.randn(5, 3)               # 5 samples, 3 classes
targets = torch.tensor([0, 1, 2, 1, 0])  # Ground-truth labels

# Calculate the class-balanced loss
loss = cb_loss(logits, targets)
print(f"Class-Balanced Loss: {loss.item()}")
```

### Explanation of Code

- **ClassBalancedLoss Module**: The `ClassBalancedLoss` class extends `nn.Module` to define the custom loss function. It takes the parameter `beta` and the per-class sample counts, from which it computes the class weights.
- **compute_class_weights()**: Calculates the inverse effective number of samples for each class and normalizes the result into class weights.
- **forward()**: Computes the per-sample cross-entropy loss and multiplies it by the weight of each sample's class before averaging, balancing the loss across classes.
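As a quick sanity check on the implementation above (an illustrative check, reusing the `ClassBalancedLoss` class just defined): when every class has the same count, all normalized weights equal 1, so the class-balanced loss should coincide with plain unweighted cross-entropy.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# With equal class counts, every normalized weight is 1.0, so the
# class-balanced loss reduces to the standard (mean) cross-entropy.
equal_counts = torch.tensor([500.0, 500.0, 500.0])
cb_loss_uniform = ClassBalancedLoss(beta=0.99, class_counts=equal_counts)

logits = torch.randn(8, 3)
targets = torch.randint(0, 3, (8,))

print(cb_loss_uniform(logits, targets).item())  # class-balanced loss
print(F.cross_entropy(logits, targets).item())  # plain cross-entropy: same value
```

Conversely, with the imbalanced counts `[1000, 100, 10]` from the example above, mistakes on the 10-sample class are penalized far more heavily than mistakes on the majority class.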
### 5. Summary of Class-Balanced Loss

- **Why Use It**: To address the problem of class imbalance in datasets, which can lead to biased models that perform poorly on minority classes.
- **How It Works**: It uses a weighting scheme based on the effective number of samples in each class, giving more weight to the minority classes so that their contribution to the overall loss is amplified.
- **Mathematics Behind It**: It relies on the notion of diminishing returns from additional samples in a class, quantified by the effective-number formula.
- **Use Cases**: Useful when the dataset has a long-tail distribution or when certain classes are far less frequent, such as medical diagnosis, fraud detection, or rare event prediction.

The advantage of class-balanced loss is that it explicitly addresses the imbalance without modifying the data distribution itself (as oversampling and undersampling do), and without the noise or overfitting risks that sometimes come with data-level techniques.
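For contrast, here is a minimal sketch of the data-level alternative mentioned above, using PyTorch's `WeightedRandomSampler` to oversample minority classes at the data-loading stage (the toy dataset and the `[1000, 100, 10]` counts are purely illustrative):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy imbalanced dataset: 1000 / 100 / 10 samples for classes 0 / 1 / 2.
features = torch.randn(1110, 8)
labels = torch.cat([
    torch.full((1000,), 0, dtype=torch.long),
    torch.full((100,), 1, dtype=torch.long),
    torch.full((10,), 2, dtype=torch.long),
])
dataset = TensorDataset(features, labels)

# Each sample is drawn with probability inversely proportional to its
# class frequency, so minority-class examples are repeated more often.
class_counts = torch.bincount(labels).float()   # tensor([1000., 100., 10.])
sample_weights = (1.0 / class_counts)[labels]   # one weight per sample
sampler = WeightedRandomSampler(sample_weights, num_samples=len(dataset), replacement=True)

loader = DataLoader(dataset, batch_size=32, sampler=sampler)
```

The sampler changes what the model sees (minority examples are duplicated across epochs), whereas class-balanced loss leaves the data pipeline untouched and instead rescales each example's contribution to the gradient.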