MRCpy.pytorch.mgce.loss.mgce_loss¶
- class MRCpy.pytorch.mgce.loss.mgce_loss(num_classes, beta=1.02)[source]¶
Margin Loss for Minimax Generalized Cross-Entropy (MGCE) Classification. See mgce_clf for the full \(\alpha\)-loss definition and the minimax framework.
This class implements the \(\alpha\)-loss (\(\ell_\beta\)) used in the Minimax Generalized Cross-Entropy (MGCE) framework. It provides a parametric family of loss functions indexed by \(\beta \geq 1\), with \(\beta = \alpha / (\alpha - 1)\), that generalizes the standard cross-entropy (log-loss):
- \(\beta = 1\): the convex 0-1 loss (see [2])
- \(\beta \to \infty\): the log-loss (cross-entropy)
The name “generalized cross-entropy” reflects that the \(\alpha\)-loss family extends the standard cross-entropy by allowing the practitioner to tune \(\beta\) to control the trade-off between robustness (low \(\beta\)) and smoothness (high \(\beta\)).
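As an illustration of this interpolation, a common closed form for the \(\alpha\)-loss on the true-class probability \(p\), under the reparametrization \(\beta = \alpha/(\alpha-1)\), is \(\ell_\beta(p) = \beta\,(1 - p^{1/\beta})\). This is an assumption for the sketch below (see [1] for the exact definition used by MGCE), but it reproduces both endpoints claimed above: at \(\beta = 1\) it equals \(1 - p\) (the expected 0-1 loss, bounded and hence robust), and as \(\beta \to \infty\) it converges to \(-\log p\):

```python
import math

def alpha_loss(p, beta):
    """Candidate alpha-loss on the true-class probability p.

    At beta = 1 this is 1 - p (expected 0-1 loss, bounded);
    as beta grows it approaches -log(p), the cross-entropy.
    Illustrative form only; see the MGCE paper for the exact definition.
    """
    return beta * (1.0 - p ** (1.0 / beta))

p = 0.7
print(alpha_loss(p, 1.0))    # 0-1 end: 1 - 0.7 = 0.3
print(alpha_loss(p, 11.0))   # already close to the log-loss
print(-math.log(p))          # log-loss for comparison, ~0.3567
```

Note how \(\beta = 11\) already tracks the log-loss closely, which is consistent with the documented working range \([1, 11]\) for the beta parameter.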
This class computes gradients and loss values that can be directly integrated with PyTorch training loops.
- Parameters:
- num_classesint
Number of classes in the classification problem. Must be >= 2.
- betafloat, default=1.02
The \(\beta\) parameter controlling the loss behavior. In practice, values in the range \([1, 11]\) are sufficient to interpolate across the family:
- beta = 1.0: 0-1 loss (non-smooth).
- beta ≈ 11: approximately the log-loss (cross-entropy).
Intermediate values smoothly trade off between robustness (low \(\beta\)) and smoothness (high \(\beta\)).
References
[1]Bondugula, K., Mazuelas, S., Pérez, A., & Liu, A. (2026). Minimax Generalized Cross-Entropy. In Proceedings of the International Conference on Artificial Intelligence and Statistics (AISTATS).
[2]Mazuelas, S., Shen, Y., & Pérez, A. (2022). Generalized Maximum Entropy for Supervised Classification. IEEE Transactions on Information Theory, 68(4), 2530-2550.
Examples
Basic usage in a training loop:
>>> import torch
>>> from MRCpy.pytorch.mgce.loss import mgce_loss
>>>
>>> # Initialize with beta=1.4
>>> loss_fn = mgce_loss(num_classes=3, beta=1.4)
>>>
>>> # In training loop
>>> logits = model(inputs)  # Shape: (batch_size, num_classes)
>>> labels = targets        # Shape: (batch_size,)
>>>
>>> # Compute gradients for backpropagation
>>> gradients, loss_value = loss_fn.get_gradient(logits, labels)
>>>
>>> # Backpropagate using computed gradients
>>> logits.backward(gradients)
For validation/inference:
>>> # Compute loss and probabilities
>>> loss_val, probabilities = loss_fn.get_loss_value(logits, labels, reg_val=0.0)
>>> predicted_classes = torch.argmax(probabilities, dim=1)
- Attributes:
- num_classesint
Number of classes in the classification problem.
- betafloat
The \(\beta\) parameter for the loss function.
Methods
- get_gradient(logits, labels): Compute gradients for the \(\alpha\)-loss function.
- get_loss_value(logits, labels, reg_val): Compute loss value and probability predictions.
- get_probs(logits): Compute class probabilities using the \(\alpha\)-loss framework.
- __init__(num_classes, beta=1.02)[source]¶
Initialize self. See help(type(self)) for accurate signature.
- get_gradient(logits, labels)[source]¶
Compute gradients for the \(\alpha\)-loss function.
This method computes the gradients with respect to the logits for backpropagation. The gradients are computed using either the 0-1 loss formulation (when beta=1) or the smooth \(\alpha\)-loss approximation.
- Parameters:
- logitstorch.Tensor of shape (batch_size, num_classes)
Raw model outputs (logits) before applying softmax or other activation functions.
- labelstorch.Tensor of shape (batch_size,)
True class labels as integer indices in range [0, num_classes-1].
- Returns:
- gradientstorch.Tensor of shape (batch_size, num_classes)
Gradients with respect to the input logits. Can be used directly with logits.backward(gradients) for backpropagation.
- loss_valuetorch.Tensor (scalar)
The computed loss value for the current batch.
Examples
>>> loss_fn = mgce_loss(num_classes=3, beta=1.4)
>>> logits = torch.randn(32, 3, requires_grad=True)
>>> labels = torch.randint(0, 3, (32,))
>>>
>>> gradients, loss = loss_fn.get_gradient(logits, labels)
>>> logits.backward(gradients)  # Backpropagate
- get_loss_value(logits, labels, reg_val)[source]¶
Compute loss value and probability predictions.
This method computes the \(\alpha\)-loss value and the corresponding probability predictions. It is typically used during validation or when you need both the loss and the predicted probabilities.
- Parameters:
- logitstorch.Tensor of shape (batch_size, num_classes)
Raw model outputs (logits) before applying softmax.
- labelstorch.Tensor of shape (batch_size,)
True class labels as integer indices in range [0, num_classes-1].
- reg_valfloat or torch.Tensor (scalar)
Regularization value to add to the loss. Typically the L1 or L2 regularization term from the model parameters.
- Returns:
- total_losstorch.Tensor (scalar)
The computed loss value including regularization.
- probabilitiestorch.Tensor of shape (batch_size, num_classes)
Predicted class probabilities. Each row sums to 1.0 and represents the probability distribution over classes for the corresponding sample.
Examples
>>> loss_fn = mgce_loss(num_classes=3, beta=1.4)
>>> logits = torch.randn(32, 3)
>>> labels = torch.randint(0, 3, (32,))
>>> reg_val = 0.01
>>>
>>> loss, probs = loss_fn.get_loss_value(logits, labels, reg_val)
>>> predicted_classes = torch.argmax(probs, dim=1)
- get_probs(logits)[source]¶
Compute class probabilities using the \(\alpha\)-loss framework.
This method computes class probabilities for given logits without computing the loss value, making it efficient for inference.
For beta=1 (the 0-1 loss case), the psi function is used with batched computation. For beta>1 (the smooth approximation), a bisection solver finds the optimal \(\nu\) values.
- Parameters:
- logitstorch.Tensor of shape (batch_size, num_classes)
Raw model outputs (logits) before applying softmax or other activation functions.
- Returns:
- probabilitiestorch.Tensor of shape (batch_size, num_classes)
Predicted class probabilities for each input sample. Each row sums to 1.0. Probabilities are computed using the \(\alpha\)-loss framework rather than standard softmax.
See also
get_loss_value : Compute both loss value and probabilities.
get_gradient : Compute gradients for backpropagation.
Examples
>>> loss_fn = mgce_loss(num_classes=3, beta=1.4)
>>> logits = torch.randn(32, 3)
>>>
>>> probabilities = loss_fn.get_probs(logits)
>>> predicted_classes = torch.argmax(probabilities, dim=1)
>>> confidence_scores = torch.max(probabilities, dim=1)[0]
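The bisection step mentioned above can be pictured with a generic sketch: choose a scalar threshold \(\nu\) so that the transformed scores sum to 1. The transform used below, \(p_y = \max(z_y - \nu, 0)\) (a simple sparsemax-style truncation), is a stand-in for illustration, not the actual MGCE mapping:

```python
def probs_by_bisection(scores, tol=1e-10, max_iter=200):
    """Bisect for nu such that sum(max(z - nu, 0)) == 1, then return
    the resulting probabilities.

    Illustrative stand-in transform; the real solver uses the
    alpha-loss-specific mapping.
    """
    lo = max(scores) - 1.0   # at this nu the sum is at least 1
    hi = max(scores)         # at this nu the sum is 0
    for _ in range(max_iter):
        nu = 0.5 * (lo + hi)
        total = sum(max(z - nu, 0.0) for z in scores)
        if abs(total - 1.0) < tol:
            break
        if total > 1.0:
            lo = nu          # sum too large -> raise the threshold
        else:
            hi = nu          # sum too small -> lower the threshold
    return [max(z - nu, 0.0) for z in scores]

p = probs_by_bisection([2.0, 1.2, -0.5])
print(p, sum(p))  # probabilities summing to 1; low scores truncated to 0
```

The monotonicity of the sum in \(\nu\) is what makes bisection applicable; the same principle carries over to the \(\nu\) values solved for in get_probs.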