MRCpy.pytorch.mgce.loss.mgce_loss

class MRCpy.pytorch.mgce.loss.mgce_loss(num_classes, beta=1.02)[source]

Margin Loss for Minimax Generalized Cross-Entropy (MGCE) Classification. See mgce_clf for the full \(\alpha\)-loss definition and the minimax framework.

This class implements the \(\alpha\)-loss (\(\ell_\beta\)) used in the Minimax Generalized Cross-Entropy (MGCE) framework. The class provides a parametric family of loss functions indexed by \(\beta \geq 1\), where \(\beta = \alpha / (\alpha-1)\), that generalizes the standard cross-entropy (log-loss):

  • \(\beta = 1\): the convex 0-1 loss (see [2])

  • \(\beta \to \infty\): the log-loss (cross-entropy)

The name “generalized cross-entropy” reflects that the \(\alpha\)-loss family extends the standard cross-entropy by allowing the practitioner to tune \(\beta\) to control the trade-off between robustness (low \(\beta\)) and smoothness (high \(\beta\)).
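As a rough numerical illustration, one common closed form for the \(\alpha\)-loss on the probability \(p\) assigned to the true class is \(\ell_\beta(p) = \beta(1 - p^{1/\beta})\). This exact parameterization is an assumption here (see mgce_clf for the definition the library actually uses); under it, the sketch below shows the interpolation between a linear 0-1-style penalty at \(\beta = 1\) and the log-loss as \(\beta\) grows:

```python
import math

def alpha_loss(p, beta):
    """Candidate alpha-loss on the true-class probability p,
    with beta = alpha / (alpha - 1). Assumed closed form, not
    necessarily the library's exact implementation."""
    return beta * (1.0 - p ** (1.0 / beta))

p = 0.5
print(alpha_loss(p, 1.0))     # beta = 1: linear penalty 1 - p = 0.5
print(alpha_loss(p, 1000.0))  # large beta: approaches -log(p)
print(-math.log(p))           # log-loss reference value
```

This makes the documented endpoints concrete: the \(\beta = 1\) value is exactly \(1 - p\), while the large-\(\beta\) value is already within about 1e-3 of the cross-entropy.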

This class computes gradients and loss values that can be directly integrated with PyTorch training loops.

Parameters:
num_classes : int

Number of classes in the classification problem. Must be >= 2.

beta : float, default=1.02

The \(\beta\) parameter controlling the loss behavior. Values in the range \([1, 11]\) are sufficient to interpolate across the whole family:

  • beta = 1.0: 0-1 loss (non-smooth).

  • beta = 11: approximately the log-loss (cross-entropy).

  • Intermediate values smoothly trade off between robustness (low \(\beta\)) and smoothness (high \(\beta\)).

References

[1]

Bondugula, K., Mazuelas, S., Pérez, A., & Liu, A. (2026). Minimax Generalized Cross-Entropy. In Proceedings of the International Conference on Artificial Intelligence and Statistics (AISTATS).

[2]

Mazuelas, S., Shen, Y., & Pérez, A. (2022). Generalized Maximum Entropy for Supervised Classification. IEEE Transactions on Information Theory, 68(4), 2530-2550.

Examples

Basic usage in a training loop:

>>> import torch
>>> from MRCpy.pytorch.mgce.loss import mgce_loss
>>> 
>>> # Initialize with beta=1.4
>>> loss_fn = mgce_loss(num_classes=3, beta=1.4)
>>> 
>>> # In training loop
>>> logits = model(inputs)  # Shape: (batch_size, num_classes)
>>> labels = targets       # Shape: (batch_size,)
>>> 
>>> # Compute gradients for backpropagation
>>> gradients, loss_value = loss_fn.get_gradient(logits, labels)
>>> 
>>> # Backpropagate using computed gradients
>>> logits.backward(gradients)

For validation/inference:

>>> # Compute loss and probabilities
>>> loss_val, probabilities = loss_fn.get_loss_value(logits, labels, reg_val=0.0)
>>> predicted_classes = torch.argmax(probabilities, dim=1)

Attributes:
num_classes : int

Number of classes in the classification problem.

beta : float

The \(\beta\) parameter for the loss function.

Methods

get_gradient(logits, labels)

Compute gradients for the \(\alpha\)-loss function.

get_loss_value(logits, labels, reg_val)

Compute loss value and probability predictions.

get_probs(logits)

Compute class probabilities using the \(\alpha\)-loss framework.

__init__(num_classes, beta=1.02)[source]

Initialize self. See help(type(self)) for accurate signature.

get_gradient(logits, labels)[source]

Compute gradients for the \(\alpha\)-loss function.

This method computes the gradients with respect to the logits for backpropagation. The gradients are computed using either the 0-1 loss formulation (when beta=1) or the smooth \(\alpha\)-loss approximation.
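As a sanity check on any gradient formula, an analytic derivative can be compared against central finite differences. The sketch below does this for the assumed scalar form \(\ell_\beta(p) = \beta(1 - p^{1/\beta})\), whose derivative in \(p\) is \(-p^{1/\beta - 1}\); the gradients returned by this method additionally chain through the mapping from logits to probabilities, so this is only an illustration of the checking technique:

```python
def alpha_loss(p, beta):
    # Assumed scalar alpha-loss on the true-class probability p.
    return beta * (1.0 - p ** (1.0 / beta))

def alpha_loss_grad(p, beta):
    # d/dp [beta * (1 - p**(1/beta))] = -p**(1/beta - 1)
    return -(p ** (1.0 / beta - 1.0))

p, beta, h = 0.7, 1.4, 1e-6
finite_diff = (alpha_loss(p + h, beta) - alpha_loss(p - h, beta)) / (2.0 * h)
print(abs(finite_diff - alpha_loss_grad(p, beta)))  # small: the two agree
```

The same finite-difference comparison can be applied to the tensors returned by get_gradient to validate a custom training loop.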

Parameters:
logits : torch.Tensor of shape (batch_size, num_classes)

Raw model outputs (logits) before applying softmax or other activation functions.

labels : torch.Tensor of shape (batch_size,)

True class labels as integer indices in range [0, num_classes-1].

Returns:
gradients : torch.Tensor of shape (batch_size, num_classes)

Gradients with respect to the input logits. Can be used directly with logits.backward(gradients) for backpropagation.

loss_value : torch.Tensor (scalar)

The computed loss value for the current batch.

Examples

>>> loss_fn = mgce_loss(num_classes=3, beta=1.4)
>>> logits = torch.randn(32, 3, requires_grad=True)
>>> labels = torch.randint(0, 3, (32,))
>>> 
>>> gradients, loss = loss_fn.get_gradient(logits, labels)
>>> logits.backward(gradients)  # Backpropagate
get_loss_value(logits, labels, reg_val)[source]

Compute loss value and probability predictions.

This method computes the \(\alpha\)-loss value and the corresponding probability predictions. It is typically used during validation or when you need both the loss and the predicted probabilities.

Parameters:
logits : torch.Tensor of shape (batch_size, num_classes)

Raw model outputs (logits) before applying softmax.

labels : torch.Tensor of shape (batch_size,)

True class labels as integer indices in range [0, num_classes-1].

reg_val : float or torch.Tensor (scalar)

Regularization value to add to the loss. Typically the L1 or L2 regularization term from the model parameters.

Returns:
total_loss : torch.Tensor (scalar)

The computed loss value including regularization.

probabilities : torch.Tensor of shape (batch_size, num_classes)

Predicted class probabilities. Each row sums to 1.0 and represents the probability distribution over classes for the corresponding sample.

Examples

>>> loss_fn = mgce_loss(num_classes=3, beta=1.4)
>>> logits = torch.randn(32, 3)
>>> labels = torch.randint(0, 3, (32,))
>>> reg_val = 0.01
>>> 
>>> loss, probs = loss_fn.get_loss_value(logits, labels, reg_val)
>>> predicted_classes = torch.argmax(probs, dim=1)
get_probs(logits)[source]

Compute class probabilities using the \(\alpha\)-loss framework.

This method computes class probabilities for given logits without computing the loss value, making it efficient for inference.

For beta=1 (0-1 loss case), uses the psi function with batched computation. For beta>1 (smooth approximation), uses the bisection solver to find optimal nu values.
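The bisection idea can be illustrated generically. The helper and the normalization equation below are hypothetical stand-ins, not the library's internal psi/nu routines: they solve a scalar equation of the form \(\sum_c \max(0, s_c - \nu) = 1\) for \(\nu\) and read off a probability vector:

```python
def bisect(f, lo, hi, tol=1e-10, max_iter=200):
    """Root of f on [lo, hi]; assumes f(lo) and f(hi) have opposite signs."""
    f_lo = f(lo)
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        f_mid = f(mid)
        if abs(f_mid) < tol or hi - lo < tol:
            return mid
        if f_lo * f_mid <= 0.0:
            hi = mid                  # root lies in [lo, mid]
        else:
            lo, f_lo = mid, f_mid     # root lies in [mid, hi]
    return 0.5 * (lo + hi)

# Hypothetical per-sample scores; nu normalizes the clipped scores to sum to 1.
scores = [0.9, 0.4, -0.2]
nu = bisect(lambda v: sum(max(0.0, s - v) for s in scores) - 1.0, -2.0, 2.0)
probs = [max(0.0, s - nu) for s in scores]  # nu ~ 0.15; probs ~ [0.75, 0.25, 0.0]
```

Because the left-hand side is monotone in \(\nu\), bisection converges reliably, which is the property that makes this kind of solver a natural fit for batched per-sample root-finding.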

Parameters:
logits : torch.Tensor of shape (batch_size, num_classes)

Raw model outputs (logits) before applying softmax or other activation functions.

Returns:
probabilities : torch.Tensor of shape (batch_size, num_classes)

Predicted class probabilities for each input sample. Each row sums to 1.0. Probabilities are computed using the \(\alpha\)-loss framework rather than standard softmax.

See also

get_loss_value

Compute both loss value and probabilities.

get_gradient

Compute gradients for backpropagation.

Examples

>>> loss_fn = mgce_loss(num_classes=3, beta=1.4)
>>> logits = torch.randn(32, 3)
>>> 
>>> probabilities = loss_fn.get_probs(logits)
>>> predicted_classes = torch.argmax(probabilities, dim=1)
>>> confidence_scores = torch.max(probabilities, dim=1)[0]