
Cross Entropy from Scratch!

Reproduce PyTorch's implementation of Cross Entropy loss

Introduction

This blog covers a very simple exercise: coding Cross Entropy loss from scratch so that the loss value for a random (input, target) pair from our implementation matches the loss from PyTorch's Cross Entropy implementation.


Pseudo code

  • We will have input logits of shape (batch_size, num_classes). For our sample we have a batch size of 3 and 4 classes.
  • The targets will be integers of shape (batch_size,), each in the range [0, num_classes - 1].
  • Logits can be negative and on any scale! They do not need to lie between 0 and 1.
  • Step 1: Apply softmax to each row (i.e. to each sample in the batch).
  • Step 2: Pick the probability corresponding to the target class and apply -log(probability_of_correct_class) to it.
  • Step 3: Take the mean of these negative log probabilities over the batch (a small worked sketch follows this list).
 
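To make the three steps concrete, here is a minimal NumPy sketch on a tiny made-up batch (2 samples, 3 classes, arbitrary values), printing the intermediate result of each step:

import numpy as np

# Toy batch: 2 samples, 3 classes (values are made up purely for illustration)
logits = np.array([[2.0, -1.0, 0.5],
                   [0.1, 3.0, -0.7]])
targets = np.array([0, 2])  # index of the correct class for each sample

# Step 1: softmax per row (numerical stability is handled in the next section)
exp_logits = np.exp(logits)
probs = exp_logits / exp_logits.sum(axis=1, keepdims=True)
print("Step 1, softmax probabilities:\n", probs)

# Step 2: probability of the correct class, then take -log
correct_probs = probs[np.arange(len(targets)), targets]
neg_log_likelihood = -np.log(correct_probs)
print("Step 2, -log(p_correct):", neg_log_likelihood)

# Step 3: mean over the batch
print("Step 3, mean loss:", neg_log_likelihood.mean())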

Avoid numerical overflow!

Note that to calculate softmax we have to take the exponent of the logits. For a logit as big as 1000, e^1000 is far beyond the largest value a float64 can represent (about 1.8 × 10^308), so the exponentiation overflows to infinity. To avoid this we use a neat trick.
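Here is what that overflow looks like in NumPy (the logit value 1000 is just an example):

import numpy as np

x = np.array([1000.0, 2.0, -5.0])

# Naive softmax: e^1000 overflows float64 to inf, so the division yields nan (inf / inf)
naive = np.exp(x) / np.exp(x).sum()
print(naive)  # [nan  0.  0.], plus overflow / invalid-value RuntimeWarnings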

In simple words: subtracting the same constant from all the logits does not change their softmax probabilities.

softmax(x_i - c) = e^{x_i - c} / Σ_j e^{x_j - c}
                 = (e^{x_i} / e^c) / ((Σ_j e^{x_j}) / e^c)
                 = e^{x_i} / Σ_j e^{x_j}
                 = softmax(x_i)

As you can see, the e^c terms cancel out in both numerator and denominator, proving that subtracting any constant (in our case, the maximum value) from all logits gives us the same softmax probabilities, but in a numerically stable way.

The key step is that e^{x-c} = e^x/e^c, which lets the common factor e^c cancel out in the final division.
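A minimal NumPy sketch of the trick, checking that shifting by the row maximum leaves the probabilities unchanged while keeping every exponent at or below zero:

import numpy as np

def stable_softmax(x):
    # Subtract the row maximum before exponentiating so every exponent is <= 0
    shifted = x - x.max(axis=-1, keepdims=True)
    exp_x = np.exp(shifted)
    return exp_x / exp_x.sum(axis=-1, keepdims=True)

# For moderate logits the stable version matches the naive formula
x = np.array([2.0, -1.0, 0.5])
naive = np.exp(x) / np.exp(x).sum()
print(np.allclose(naive, stable_softmax(x)))  # True

# For huge logits the naive version breaks, the stable one does not
big = np.array([1000.0, 2.0, -5.0])
print(stable_softmax(big))  # approximately [1. 0. 0.], no overflow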

Code

import torch
import numpy as np

def cross_entropy_loss(predictions, targets):
    """
    Custom implementation of cross entropy loss
    
    Args:
        predictions: Raw model output logits of shape (batch_size, num_classes)
        targets: Ground truth labels of shape (batch_size,)
        
    Returns:
        loss: Mean cross entropy loss across the batch
    """
    # Get batch size
    batch_size = predictions.shape[0]
    
    # Apply softmax to get probabilities
    exp_preds = np.exp(predictions - np.max(predictions, axis=1, keepdims=True))
    softmax_preds = exp_preds / np.sum(exp_preds, axis=1, keepdims=True)
    
    # Get predicted probability for the correct class
    correct_class_probs = softmax_preds[range(batch_size), targets]
    
    # Calculate negative log likelihood
    loss = -np.log(correct_class_probs + 1e-7)  # Small epsilon guards against log(0); it introduces only a tiny deviation from PyTorch's exact value
    
    # Return mean loss
    return np.mean(loss)

# Test the implementation
if __name__ == "__main__":
    # Generate sample data
    np.random.seed(42)
    batch_size = 3
    num_classes = 4
    
    # Create random logits and targets
    logits = np.random.randn(batch_size, num_classes)
    targets = np.random.randint(0, num_classes, size=batch_size)
    
    # Calculate loss using our implementation
    custom_loss = cross_entropy_loss(logits, targets)
    
    # Calculate loss using PyTorch
    torch_logits = torch.FloatTensor(logits)
    torch_targets = torch.LongTensor(targets)
    torch_loss = torch.nn.CrossEntropyLoss()(torch_logits, torch_targets)
    
    print("Input logits:")
    print(logits)
    print("\nTarget labels:", targets)
    print("\nCustom implementation loss:", custom_loss)
    print("PyTorch implementation loss:", torch_loss.item())
    print("\nDifference:", abs(custom_loss - torch_loss.item()))
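If you want the custom loss to track PyTorch even more closely, here is an illustrative variant (my own sketch, not part of the original implementation) that computes the same quantity through the log-sum-exp identity -log softmax(x)_t = log(Σ_j e^{x_j - m}) - (x_t - m), where m is the row maximum. Because it never takes the log of a probability, it needs no epsilon:

def cross_entropy_loss_logsumexp(predictions, targets):
    """Cross entropy via log-sum-exp (illustrative alternative to the version above)."""
    batch_size = predictions.shape[0]
    # Shift by the row maximum so the exponentials cannot overflow
    shifted = predictions - np.max(predictions, axis=1, keepdims=True)
    # -log softmax(x)_t = log(sum_j e^{x_j - m}) - (x_t - m)
    log_sum_exp = np.log(np.sum(np.exp(shifted), axis=1))
    nll = log_sum_exp - shifted[range(batch_size), targets]
    return np.mean(nll)

Run on the same logits and targets as above, any remaining difference from PyTorch should come down to float32 versus float64 rounding.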