"""Modular loss system for flexible autoencoders.
This module provides a flexible, extensible system for composing different loss
functions in autoencoder training. It supports standard reconstruction losses
as well as advanced disentanglement objectives like β-VAE, TC-VAE, and Factor-VAE.
"""
from abc import ABC, abstractmethod
from typing import Dict, Type, Optional, Any, Tuple
import logging
import numpy as np
try:
import torch
import torch.nn as nn
import torch.nn.functional as F
except (ImportError, OSError):
# OSError catches DLL loading issues on Windows
torch = None
nn = None
F = None
[docs]
class AELoss(ABC):
    """Base class for all autoencoder loss components.

    Each loss component computes a specific objective (e.g., reconstruction,
    disentanglement, sparsity) and has an associated weight for balancing
    multiple objectives.

    Notes
    -----
    Tensor annotations are written as strings on purpose: the module-level
    import guard may leave ``torch`` bound to ``None``, and an eagerly
    evaluated ``torch.Tensor`` annotation would then raise ``AttributeError``
    while the class body is being executed.
    """

    def __init__(self, weight: float = 1.0, **kwargs):
        """Initialize loss component.

        Parameters
        ----------
        weight : float, default=1.0
            Weight for this loss component when combining multiple losses.
            Stored as given; only the ``weight`` setter validates the sign.
        **kwargs
            Additional parameters specific to each loss type. Kept verbatim
            in ``self.kwargs``.
        """
        self._weight = weight
        self.kwargs = kwargs

    @abstractmethod
    def compute(
        self,
        code: "torch.Tensor",
        recon: "torch.Tensor",
        inputs: "torch.Tensor",
        **kwargs,
    ) -> "torch.Tensor":
        """Compute the loss value.

        Parameters
        ----------
        code : torch.Tensor
            Latent representation, shape (batch_size, code_dim).
        recon : torch.Tensor
            Reconstructed outputs, shape (batch_size, input_dim).
        inputs : torch.Tensor
            Original inputs, shape (batch_size, input_dim).
        **kwargs
            Additional tensors/parameters needed by specific losses
            (e.g., mu and log_var for VAE losses).

        Returns
        -------
        torch.Tensor
            Scalar loss value.
        """

    @property
    def weight(self) -> float:
        """Weight of this component in a multi-objective optimization.

        Returns
        -------
        float
            Current weight value for this loss component.
        """
        return self._weight

    @weight.setter
    def weight(self, value: float):
        """Set the loss weight.

        Parameters
        ----------
        value : float
            New weight value. Must be non-negative. A weight of 0 effectively
            disables this loss component.

        Raises
        ------
        ValueError
            If value is negative.

        Notes
        -----
        Changing the weight during training can be used for curriculum
        learning or adaptive loss balancing strategies.
        """
        if value < 0:
            raise ValueError(f"Weight must be non-negative, got {value}")
        self._weight = value
class ClassificationLoss(AELoss):
    """Classification loss for supervised dimensionality reduction.

    Trains a classifier on the latent code to predict class labels,
    forcing the latent representation to preserve class-relevant information.

    Parameters
    ----------
    num_classes : int
        Number of target classes.
    code_dim : int
        Dimension of the latent code (auto-injected by ModularAutoencoder).
    hidden_dim : int or None, default=None
        If None, uses a linear classifier. If int, uses a 2-layer MLP
        (code_dim -> hidden_dim -> num_classes) with LeakyReLU activation.
    weight : float, default=1.0
        Weight for this loss component.
    **kwargs
        Forwarded to :class:`AELoss`.

    Notes
    -----
    The classifier head is created on the default device and is not moved
    by this class; only the labels are moved to the device of ``code``.
    """

    def __init__(self, num_classes: int, code_dim: int, hidden_dim: Optional[int] = None,
                 weight: float = 1.0, **kwargs):
        super().__init__(weight=weight, **kwargs)
        if hidden_dim is None:
            self.classifier = nn.Linear(code_dim, num_classes)
        else:
            self.classifier = nn.Sequential(
                nn.Linear(code_dim, hidden_dim),
                nn.LeakyReLU(),
                nn.Linear(hidden_dim, num_classes),
            )
        self.criterion = nn.CrossEntropyLoss()

    def compute(self, code, recon, inputs, labels=None, **kwargs):
        """Compute cross-entropy loss between predicted and true labels.

        Parameters
        ----------
        code : torch.Tensor
            Latent representation, shape (batch_size, code_dim).
        recon : torch.Tensor
            Reconstructed outputs (unused).
        inputs : torch.Tensor
            Original inputs (unused).
        labels : torch.Tensor, optional
            Class labels, shape (batch_size,). Integer labels of any width
            are accepted. If None, returns 0.
        **kwargs
            Additional arguments (unused).

        Returns
        -------
        torch.Tensor
            Scalar cross-entropy loss.
        """
        if labels is None:
            return torch.tensor(0.0, device=code.device)

        targets = labels.to(code.device)
        # CrossEntropyLoss only accepts int64 class indices; int32/uint8 labels
        # (e.g. converted from NumPy arrays) would otherwise raise. Floating
        # point targets are left untouched so probability targets keep working.
        if not targets.is_floating_point():
            targets = targets.long()

        logits = self.classifier(code)
        return self.criterion(logits, targets)

    def parameters(self):
        """Return classifier parameters for optimizer."""
        return self.classifier.parameters()
[docs]
class LossRegistry:
    """Registry for dynamically managing loss components.

    Provides a centralized system for registering and creating loss functions
    for autoencoders. Supports dynamic registration of custom losses and
    maintains a catalog of available loss types.

    The registry pattern allows for easy extension with new loss types without
    modifying existing code. All registered losses must inherit from AELoss.

    Attributes
    ----------
    losses : Dict[str, Type[AELoss]]
        Mapping from loss names to their class types.

    Examples
    --------
    >>> registry = LossRegistry()
    >>> # Create a standard reconstruction loss
    >>> recon_loss = registry.create('reconstruction', loss_type='mse')
    >>>
    >>> # Register a custom loss
    >>> class MyCustomLoss(AELoss):
    ...     def compute(self, code, recon, inputs, **kwargs):
    ...         return torch.tensor(0.0)
    >>> registry.register('custom', MyCustomLoss)
    >>> custom_loss = registry.create('custom', weight=2.0)
    """

    def __init__(self):
        """Initialize the registry and populate it with the default losses.

        The names registered out of the box are 'reconstruction',
        'correlation', 'orthogonality', 'beta_vae', 'tc_vae', 'factor_vae',
        'sparse', 'contractive', 'mmd' and 'classification'. The registry can
        be extended with custom losses after initialization.
        """
        self.losses: Dict[str, Type[AELoss]] = {}
        self._logger = logging.getLogger(self.__class__.__name__)
        self._register_defaults()

    def _register_defaults(self):
        """Register default loss implementations.

        Registers the standard set of loss functions available out-of-the-box:

        - reconstruction: Standard MSE/BCE reconstruction loss
        - correlation: Decorrelation loss for latent features
        - orthogonality: Minimize correlation with external data
        - beta_vae: β-VAE loss for disentanglement
        - tc_vae: Total Correlation VAE loss
        - factor_vae: Factor-VAE with discriminator
        - sparse: Sparsity-inducing loss
        - contractive: Contractive autoencoder loss
        - mmd: Maximum Mean Discrepancy loss
        - classification: Supervised classification head on the latent code
        """
        defaults = {
            "reconstruction": ReconstructionLoss,
            "correlation": CorrelationLoss,
            "orthogonality": OrthogonalityLoss,
            "beta_vae": BetaVAELoss,
            "tc_vae": TCVAELoss,
            "factor_vae": FactorVAELoss,
            "sparse": SparsityLoss,
            "contractive": ContractiveLoss,
            "mmd": MMDLoss,
            "classification": ClassificationLoss,
        }
        for name, loss_class in defaults.items():
            self.register(name, loss_class)

    def register(self, name: str, loss_class: Type[AELoss]):
        """Register a new loss type.

        An existing entry with the same name is overwritten.

        Parameters
        ----------
        name : str
            Name identifier for the loss.
        loss_class : Type[AELoss]
            Loss class (must inherit from AELoss).

        Raises
        ------
        ValueError
            If loss_class is not a class or does not inherit from AELoss.
        """
        # issubclass() raises TypeError for non-class arguments (instances,
        # functions, ...); check that first so the documented ValueError is
        # raised for every invalid argument.
        if not (isinstance(loss_class, type) and issubclass(loss_class, AELoss)):
            raise ValueError(f"Loss class must inherit from AELoss, got {loss_class}")
        self.losses[name] = loss_class
        self._logger.debug("Registered loss '%s' -> %s", name, loss_class.__name__)

    def create(self, name: str, **kwargs) -> AELoss:
        """Create a loss instance by name.

        Parameters
        ----------
        name : str
            Registered name of the loss.
        **kwargs
            Parameters to pass to the loss constructor.

        Returns
        -------
        AELoss
            Instantiated loss component.

        Raises
        ------
        ValueError
            If the loss name is not registered.
        """
        if name not in self.losses:
            raise ValueError(f"Unknown loss '{name}'. Available: {list(self.losses.keys())}")
        return self.losses[name](**kwargs)
# Standard Losses
[docs]
class ReconstructionLoss(AELoss):
    """Reconstruction objective of an autoencoder (MSE or BCE).

    Scores how closely the decoder output matches the original input. It is
    the primary training signal of an autoencoder: a low value means the
    latent code retains enough information to rebuild the data.

    Two criteria are available:

    - ``'mse'`` (mean squared error) for continuous data
    - ``'bce'`` (binary cross-entropy) for binary data or data in [0, 1]

    Parameters
    ----------
    loss_type : {'mse', 'bce'}, default='mse'
        Which criterion to use.
    weight : float, default=1.0
        Weight of this component in a multi-objective optimization.

    Examples
    --------
    >>> # For continuous data
    >>> loss = ReconstructionLoss(loss_type='mse', weight=1.0)
    >>>
    >>> # For binary or probabilistic data
    >>> loss = ReconstructionLoss(loss_type='bce', weight=2.0)
    """

    def __init__(self, loss_type: str = "mse", weight: float = 1.0):
        """Select and build the reconstruction criterion.

        Parameters
        ----------
        loss_type : str, default="mse"
            Either "mse" or "bce".
        weight : float, default=1.0
            Loss weight.

        Raises
        ------
        ValueError
            If loss_type is neither 'mse' nor 'bce'.
        """
        super().__init__(weight=weight)
        self.loss_type = loss_type

        # Reject unsupported names up front, then pick the criterion class.
        if loss_type not in ("mse", "bce"):
            raise ValueError(f"Unknown loss type: {loss_type}")
        criterion_cls = nn.MSELoss if loss_type == "mse" else nn.BCELoss
        self.criterion = criterion_cls(reduction="mean")

    def compute(self, code, recon, inputs, **kwargs):
        """Evaluate the configured criterion on (recon, inputs).

        Parameters
        ----------
        code : torch.Tensor
            Latent representation (not used by this loss).
        recon : torch.Tensor
            Reconstructed data, shape (batch_size, n_features).
        inputs : torch.Tensor
            Original input data, shape (batch_size, n_features).
        **kwargs
            Additional arguments (ignored).

        Returns
        -------
        torch.Tensor
            Scalar loss: mean squared error or mean binary cross-entropy,
            depending on the ``loss_type`` chosen at construction.
        """
        reconstruction_error = self.criterion(recon, inputs)
        return reconstruction_error
[docs]
class CorrelationLoss(AELoss):
    """Correlation loss to encourage decorrelated latent features.

    This loss minimizes correlations between different dimensions of the
    latent code, encouraging the autoencoder to learn a representation where
    each latent dimension captures a linearly independent factor of variation.

    The loss is the mean absolute value of the off-diagonal elements of the
    Pearson correlation matrix of the latent codes:

    ``L_corr = (1/P) * sum_{i≠j} abs(corr(z_i, z_j))``

    where z_i is the i-th dimension of the latent code across the batch,
    and P is the number of off-diagonal pairs.

    Parameters
    ----------
    weight : float, default=1.0
        Weight for this loss component. Higher values enforce stronger
        decorrelation but may harm reconstruction quality.

    Notes
    -----
    The correlation is computed across the batch dimension, so larger batch
    sizes provide more accurate correlation estimates. Requires
    batch_size >= 2.

    The correlation matrix is computed explicitly with a variance floor
    instead of ``torch.corrcoef``: a latent unit that is constant over the
    batch (e.g. a dead ReLU unit) has zero variance, for which
    ``torch.corrcoef`` divides 0 by 0 and produces NaN values and NaN
    gradients. With the floor, such a unit simply contributes zero
    correlation.
    """

    # Lower bound applied to per-dimension variances before normalisation.
    _VAR_EPS = 1e-12

    def __init__(self, weight: float = 1.0):
        """Initialize correlation loss.

        Parameters
        ----------
        weight : float, default=1.0
            Loss weight.
        """
        super().__init__(weight=weight)

    def compute(self, code, recon, inputs, **kwargs):
        """Compute average pairwise correlation amplitude in latent code.

        Parameters
        ----------
        code : torch.Tensor
            Latent representation, shape (batch_size, code_dim).
        recon : torch.Tensor
            Reconstructed data (unused).
        inputs : torch.Tensor
            Original input data (unused).
        **kwargs
            Additional arguments (unused).

        Returns
        -------
        torch.Tensor
            Scalar tensor with the average absolute correlation across all
            ordered pairs of distinct latent dimensions. A zero tensor is
            returned if code_dim < 2 or batch_size < 2.
        """
        n_samples, n_features = code.shape[0], code.shape[1]

        # Correlation is undefined for a single sample and there are no
        # pairs to compare with fewer than two features.
        if n_features < 2 or n_samples < 2:
            return torch.tensor(0.0, device=code.device)

        # Unbiased covariance between latent dimensions.
        centered = code - code.mean(dim=0, keepdim=True)
        cov = centered.T @ centered / (n_samples - 1)

        # Flooring the variance keeps both the value and the gradient finite
        # for constant dimensions.
        std = torch.sqrt(torch.diagonal(cov).clamp_min(self._VAR_EPS))
        corr = (cov / torch.outer(std, std)).clamp(-1.0, 1.0)

        # Average |corr| over the off-diagonal entries only.
        off_diagonal = ~torch.eye(n_features, dtype=torch.bool, device=code.device)
        return corr[off_diagonal].abs().mean()
[docs]
class OrthogonalityLoss(AELoss):
    """Orthogonality loss to minimize correlation with external data (MI proxy).

    FUTURE: Replace correlation-based approach with proper mutual information
    estimation (e.g., using GCMI or KSG estimators from information module).
    Current implementation uses correlation as a crude proxy for MI.
    """

    # Lower bound applied to per-column variances before normalisation.
    _VAR_EPS = 1e-12

    def __init__(self, external_data: Optional[np.ndarray] = None, weight: float = 1.0):
        """Initialize orthogonality loss.

        Parameters
        ----------
        external_data : np.ndarray, optional
            External data to minimize correlation with,
            shape (n_features, n_samples).
        weight : float, default=1.0
            Loss weight.
        """
        super().__init__(weight=weight)
        self.external_data = external_data
        # Tensor copy of external_data, created lazily on the first compute().
        self._external_tensor = None

    @staticmethod
    def _cross_correlation(a, b, eps):
        """Return the Pearson correlation between columns of ``a`` and ``b``.

        Both inputs are (n_samples, n_columns) tensors; the result has shape
        (a_columns, b_columns). Variances are floored at ``eps`` so that
        constant columns yield zero correlation instead of NaN.
        """
        dof = a.shape[0] - 1
        a_centered = a - a.mean(dim=0, keepdim=True)
        b_centered = b - b.mean(dim=0, keepdim=True)
        cov = a_centered.T @ b_centered / dof
        a_std = torch.sqrt((a_centered.pow(2).sum(dim=0) / dof).clamp_min(eps))
        b_std = torch.sqrt((b_centered.pow(2).sum(dim=0) / dof).clamp_min(eps))
        return (cov / torch.outer(a_std, b_std)).clamp(-1.0, 1.0)

    def compute(self, code, recon, inputs, indices=None, **kwargs):
        """Compute correlation between latent code and external data.

        Used as a proxy for mutual information minimization. Encourages the
        latent representation to be uncorrelated with the provided external
        variables.

        Parameters
        ----------
        code : torch.Tensor
            Latent representation, shape (batch_size, code_dim).
        recon : torch.Tensor
            Reconstructed data (unused).
        inputs : torch.Tensor
            Original input data (unused).
        indices : torch.Tensor, optional
            Batch indices to select corresponding external data columns.
            If None, the first batch_size columns of external_data are used.
        **kwargs
            Additional arguments (unused).

        Returns
        -------
        torch.Tensor
            Average absolute correlation between latent code and external
            data. A zero tensor is returned if no external data was provided
            or if batch_size < 2 (correlation is undefined for one sample).

        Notes
        -----
        External data should have shape (n_features, n_samples) where
        n_samples should be >= batch_size if indices is None.

        Only the code-vs-external block of the correlation matrix is
        computed, with a variance floor so that constant columns (on either
        side) give zero correlation rather than NaN.
        """
        if self.external_data is None:
            return torch.tensor(0.0, device=code.device)

        batch_size = code.shape[0]
        if batch_size < 2:
            return torch.tensor(0.0, device=code.device)

        if indices is None:
            indices = torch.arange(batch_size, device=code.device)
        elif torch.is_tensor(indices):
            # Keep index and indexed tensor on the same device.
            indices = indices.to(code.device)

        # Convert external data once and cache it on the device in use.
        if self._external_tensor is None:
            self._external_tensor = torch.tensor(self.external_data, dtype=torch.float32)
        if self._external_tensor.device != code.device:
            self._external_tensor = self._external_tensor.to(code.device)

        # (n_features, n_selected) -> (n_selected, n_features), same dtype as code.
        ext_batch = self._external_tensor[:, indices].T.to(code.dtype)

        cross_corr = self._cross_correlation(code, ext_batch, self._VAR_EPS)
        return cross_corr.abs().mean()
# VAE-based Disentanglement Losses
[docs]
class BetaVAELoss(AELoss):
    """KL term of the β-VAE objective.

    β-VAE changes the standard VAE loss by multiplying the KL divergence
    between the approximate posterior and the prior by a factor β. Combined
    with a reconstruction term, the full objective reads:

        L = Reconstruction + β * KL(q(z|x)||p(z))

    Notes
    -----
    - β = 1 recovers the standard VAE
    - β > 1 encourages disentanglement
    - Too high β can hurt reconstruction quality
    - Typical values: β ∈ [4, 10] for disentanglement tasks

    References
    ----------
    Higgins, I., et al. (2017). β-VAE: Learning Basic Visual Concepts with
    a Constrained Variational Framework. ICLR 2017.
    """

    def __init__(self, beta: float = 4.0, weight: float = 1.0):
        """Store the KL multiplier.

        Parameters
        ----------
        beta : float, default=4.0
            Multiplier applied to the KL divergence. Must be positive.
        weight : float, default=1.0
            Loss weight.

        Raises
        ------
        ValueError
            If beta <= 0.
        """
        super().__init__(weight=weight)
        if not beta > 0:
            raise ValueError(f"Beta must be positive for disentanglement, got {beta}")
        self.beta = beta

    def compute(self, code, recon, inputs, mu=None, log_var=None, **kwargs):
        """Return β times the batch-averaged KL(q(z|x) || N(0, I)).

        Parameters
        ----------
        code : torch.Tensor
            Latent representation (unused, kept for interface consistency).
        recon : torch.Tensor
            Reconstructed outputs (unused).
        inputs : torch.Tensor
            Original inputs (unused).
        mu : torch.Tensor
            Posterior means, shape (batch_size, latent_dim).
        log_var : torch.Tensor
            Posterior log variances, shape (batch_size, latent_dim).
        **kwargs
            Additional arguments (unused).

        Returns
        -------
        torch.Tensor
            β-weighted KL divergence, summed over latent dimensions and
            averaged over the batch.

        Raises
        ------
        ValueError
            If mu or log_var are not provided.

        Notes
        -----
        For a diagonal Gaussian posterior the closed form is:

            KL(q(z|x)||p(z)) = -0.5 * Σ(1 + log(σ²) - μ² - σ²)
        """
        if any(stat is None for stat in (mu, log_var)):
            raise ValueError("β-VAE loss requires mu and log_var")

        # Closed-form KL summed over every element of the batch ...
        elementwise = 1 + log_var - mu.pow(2) - log_var.exp()
        kl_total = -0.5 * torch.sum(elementwise)

        # ... then normalised by the number of samples.
        n_samples = mu.shape[0]
        kl_per_sample = kl_total / n_samples

        return self.beta * kl_per_sample
class TCVAELoss(AELoss):
    """TC-VAE loss with decomposed ELBO for better disentanglement.

    Total Correlation VAE (TC-VAE) decomposes the KL divergence term of the
    ELBO into three meaningful components:

    1. Mutual Information between data and latent: I(x;z)
    2. Total Correlation (TC) measuring dependence between latents: TC(z)
    3. Dimension-wise KL divergence: Σ KL(q(z_i)||p(z_i))

    By separately weighting these terms, TC-VAE can specifically target the
    total correlation term to encourage factorial (independent) latent codes
    while maintaining good reconstruction.

    Notes
    -----
    The decomposition allows fine-grained control over different aspects:

    - α controls information preserved about data
    - β controls disentanglement (via total correlation)
    - γ controls deviation from the prior

    References
    ----------
    Chen, T. Q., et al. (2018). Isolating Sources of Disentanglement in
    Variational Autoencoders. NeurIPS 2018.
    """

    def __init__(
        self, alpha: float = 1.0, beta: float = 1.0, gamma: float = 1.0, weight: float = 1.0
    ):
        """Initialize TC-VAE loss.

        Parameters
        ----------
        alpha : float, default=1.0
            Weight for mutual information term.
        beta : float, default=1.0
            Weight for total correlation term.
        gamma : float, default=1.0
            Weight for dimension-wise KL term.
        weight : float, default=1.0
            Overall loss weight.

        Notes
        -----
        The decomposition is:

            KL[q(z|x)||p(z)] = I(x;z) + TC(z) + Σ KL[q(z_i)||p(z_i)]

        where I(x;z) is mutual information, TC(z) is total correlation.
        """
        super().__init__(weight=weight)
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

    def compute(self, code, recon, inputs, mu=None, log_var=None, **kwargs):
        """Compute TC-VAE loss with ELBO decomposition.

        The aggregate posterior q(z) and its marginals q(z_d) are estimated
        from the minibatch as the average of the per-sample posteriors
        q(z|x_j), j = 1..batch_size.

        Parameters
        ----------
        code : torch.Tensor
            Sampled latent representation z ~ q(z|x),
            shape (batch_size, latent_dim).
        recon : torch.Tensor
            Reconstructed outputs (unused).
        inputs : torch.Tensor
            Original inputs (unused).
        mu : torch.Tensor
            Mean of approximate posterior, shape (batch_size, latent_dim).
        log_var : torch.Tensor
            Log variance of approximate posterior,
            shape (batch_size, latent_dim).
        **kwargs
            Additional arguments (unused).

        Returns
        -------
        torch.Tensor
            Weighted sum of MI, TC, and dimension-wise KL components.

        Raises
        ------
        ValueError
            If mu or log_var are not provided.

        Notes
        -----
        Per sample the three terms are:

        - ``mi   = log q(z|x) - log q(z)``
        - ``tc   = log q(z) - Σ_d log q(z_d)``
        - ``dwkl = KL[q(z|x)||p(z)] - mi - tc``

        The total correlation compares the aggregate posterior with the
        product of its own marginals (not with the prior); comparing with
        the prior would fold the dimension-wise KL into the β-weighted term.
        The dimension-wise KL is obtained as the remainder of the analytic
        KL, so with alpha = beta = gamma = 1 the loss equals the standard
        batch-averaged VAE KL term.
        """
        if mu is None or log_var is None:
            raise ValueError("TC-VAE loss requires mu and log_var")

        batch_size = mu.shape[0]
        log_batch = float(np.log(batch_size))

        # Analytic KL[q(z|x) || N(0, I)] per sample.
        kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)

        # log q(z_i | x_i): each sample under its own posterior.
        log_qz_given_x = self._log_density_gaussian(code, mu, log_var).sum(dim=1)

        # pairwise[i, j, d] = log q(z_{i,d} | x_j)
        pairwise = self._pairwise_log_density(code, mu, log_var)

        # log q(z_i) ≈ log (1/N) Σ_j Π_d q(z_{i,d} | x_j)
        log_qz = pairwise.sum(dim=2).logsumexp(dim=1) - log_batch
        # Σ_d log q(z_{i,d}) ≈ Σ_d log (1/N) Σ_j q(z_{i,d} | x_j)
        log_qz_product = (pairwise.logsumexp(dim=1) - log_batch).sum(dim=1)

        mi = log_qz_given_x - log_qz  # Mutual information
        tc = log_qz - log_qz_product  # Total correlation
        dwkl = kl - mi - tc  # Dimension-wise KL (remainder of the KL)

        return self.alpha * mi.mean() + self.beta * tc.mean() + self.gamma * dwkl.mean()

    def _log_density_gaussian(self, x, mu, log_var):
        """Compute element-wise log probability under a diagonal Gaussian.

        Parameters
        ----------
        x : torch.Tensor
            Points to evaluate, shape (..., latent_dim).
        mu : torch.Tensor
            Mean parameters, broadcastable to the shape of ``x``.
        log_var : torch.Tensor
            Log variance parameters, broadcastable to the shape of ``x``.

        Returns
        -------
        torch.Tensor
            Log probabilities per dimension (not summed over latent_dim).

        Notes
        -----
        log p(x|μ,σ²) = -0.5 * [log(2π) + log(σ²) + (x-μ)²/σ²]
        """
        normalization = -0.5 * (np.log(2 * np.pi) + log_var)
        inv_var = torch.exp(-log_var)
        return normalization - 0.5 * ((x - mu).pow(2) * inv_var)

    def _pairwise_log_density(self, samples, mu, log_var):
        """Return log q(z_{i,d} | x_j) for all pairs, shape (batch, batch, latent)."""
        return self._log_density_gaussian(
            samples.unsqueeze(1),  # (batch, 1, latent)
            mu.unsqueeze(0),  # (1, batch, latent)
            log_var.unsqueeze(0),  # (1, batch, latent)
        )

    def _log_importance_weight_matrix(self, batch, mu, log_var):
        """Compute the matrix of log q(z_i | x_j) for a minibatch.

        Parameters
        ----------
        batch : torch.Tensor
            Latent samples, shape (batch_size, latent_dim).
        mu : torch.Tensor
            Posterior means for all samples, shape (batch_size, latent_dim).
        log_var : torch.Tensor
            Posterior log variances, shape (batch_size, latent_dim).

        Returns
        -------
        torch.Tensor
            Matrix of shape (batch_size, batch_size); element [i, j] contains
            log q(z_i|x_j), i.e. the per-dimension log densities summed over
            the latent dimensions.
        """
        return self._pairwise_log_density(batch, mu, log_var).sum(dim=2)
class FactorVAELoss(AELoss):
    """Factor-VAE loss using a discriminator for disentanglement.

    Factor-VAE encourages disentanglement by using an adversarial
    discriminator to estimate and minimize the total correlation (TC) in the
    latent space. The discriminator is trained to distinguish between samples
    from q(z) (the aggregate posterior) and the product of marginals Π q(z_i).

    Notes
    -----
    IMPORTANT: The discriminator requires separate optimization in the
    training loop. :meth:`compute` only returns the penalty for the VAE;
    :meth:`discriminator_loss` returns the objective for the discriminator.
    :meth:`compute` does not detach the discriminator, so the optimizer that
    minimises the VAE loss should not own the discriminator's parameters.

    FUTURE: Implement a complete FactorVAE training system that handles both
    VAE and discriminator optimization, possibly as a separate Trainer class.

    References
    ----------
    Kim, H., & Mnih, A. (2018). Disentangling by Factorising. ICML 2018.
    """

    def __init__(
        self,
        gamma: float = 10.0,
        discriminator_dims: Optional[list] = None,
        weight: float = 1.0,
        device: "Optional[torch.device]" = None,
    ):
        """Initialize Factor-VAE loss.

        Parameters
        ----------
        gamma : float, default=10.0
            Weight for the total correlation penalty.
        discriminator_dims : list, optional
            Hidden dimensions for discriminator network. Defaults to
            [256, 256] when None or empty.
        weight : float, default=1.0
            Loss weight.
        device : torch.device, optional
            Stored in ``self.device`` (CPU when omitted). The discriminator
            itself is placed on the device of the code it is applied to.

        Notes
        -----
        The ``device`` annotation is a string so that the class can still be
        defined when the module-level import guard leaves ``torch`` as None.
        """
        super().__init__(weight=weight)
        self.gamma = gamma
        self.discriminator_dims = discriminator_dims or [256, 256]
        self.device = device or torch.device("cpu")
        # Built lazily on first use, once the latent dimension is known.
        self.discriminator = None
        self._latent_dim = None

    def _build_discriminator(self, latent_dim: int):
        """Build the discriminator MLP.

        Architecture: latent code -> [Linear, LeakyReLU(0.2), Dropout(0.2)]
        for every entry of ``discriminator_dims`` -> Linear with 2 outputs
        (class 1 = real code, class 0 = permuted code).

        Parameters
        ----------
        latent_dim : int
            Dimension of the latent space.
        """
        layers = []
        in_dim = latent_dim
        for hidden_dim in self.discriminator_dims:
            layers.extend([nn.Linear(in_dim, hidden_dim), nn.LeakyReLU(0.2), nn.Dropout(0.2)])
            in_dim = hidden_dim
        layers.append(nn.Linear(in_dim, 2))  # Binary classification
        self.discriminator = nn.Sequential(*layers)
        self._latent_dim = latent_dim

    def _ensure_discriminator(self, code):
        """Return the discriminator, built if needed and on ``code``'s device."""
        if self.discriminator is None:
            self._build_discriminator(code.shape[1])
        # nn.Module has no ``device`` attribute; inspect a parameter instead.
        current_device = next(self.discriminator.parameters()).device
        if current_device != code.device:
            self.discriminator = self.discriminator.to(code.device)
        return self.discriminator

    @staticmethod
    def _permute_dims(code):
        """Shuffle every latent dimension independently across the batch.

        Sorting i.i.d. uniform noise column-wise yields an independent random
        permutation of the batch indices for each latent dimension. Shuffling
        whole rows instead would leave the joint distribution unchanged and
        would not approximate the product of marginals.
        """
        order = torch.rand(code.shape, device=code.device).argsort(dim=0)
        return code.gather(0, order)

    def discriminator_loss(self, code):
        """Compute the classification loss used to train the discriminator.

        Parameters
        ----------
        code : torch.Tensor
            Latent representation, shape (batch_size, latent_dim). It is
            detached, so no gradient reaches the encoder.

        Returns
        -------
        torch.Tensor
            Sum of the cross-entropy of real codes against class 1 and of
            dimension-wise permuted codes against class 0.
        """
        discriminator = self._ensure_discriminator(code)
        batch_size = code.shape[0]
        real_code = code.detach()
        fake_code = self._permute_dims(real_code)
        real_labels = torch.ones(batch_size, dtype=torch.long, device=code.device)
        fake_labels = torch.zeros(batch_size, dtype=torch.long, device=code.device)
        real_loss = F.cross_entropy(discriminator(real_code), real_labels)
        fake_loss = F.cross_entropy(discriminator(fake_code), fake_labels)
        return real_loss + fake_loss

    def compute(self, code, recon, inputs, **kwargs):
        """Compute the discriminator-based penalty for the VAE.

        Parameters
        ----------
        code : torch.Tensor
            Latent representation, shape (batch_size, latent_dim).
        recon : torch.Tensor
            Reconstructed outputs (unused).
        inputs : torch.Tensor
            Original inputs (unused).
        **kwargs
            Additional arguments (unused).

        Returns
        -------
        torch.Tensor
            ``gamma`` times the cross-entropy between the discriminator's
            prediction on the real codes and the "permuted" class (0). The
            penalty is small when the discriminator takes real codes for
            samples of the product of marginals.

        Notes
        -----
        The discriminator is created on the first call and moved to the
        device of ``code`` when necessary. This method does not train the
        discriminator; see :meth:`discriminator_loss`.
        """
        discriminator = self._ensure_discriminator(code)
        real_pred = discriminator(code)

        # "Fool the discriminator": push real codes towards the permuted class.
        permuted_class = torch.zeros(code.shape[0], dtype=torch.long, device=code.device)
        tc_loss = F.cross_entropy(real_pred, permuted_class)
        return self.gamma * tc_loss
# Regularization Losses
[docs]
class SparsityLoss(AELoss):
    """KL-based sparsity penalty on the latent activations.

    Penalizes the KL divergence between a target activation level and the
    mean activation of every latent unit, which pushes the model towards
    using only a few latent dimensions per input.

    The two quantities compared are:

    - ρ̂: average activation of each latent unit (across the batch)
    - ρ: target sparsity level (e.g., 0.05 for 5% average activation)

    References
    ----------
    Ng, A. (2011). Sparse autoencoder. CS294A Lecture notes, Stanford University.
    """

    def __init__(self, sparsity_target: float = 0.05, weight: float = 1.0):
        """Store the target activation level.

        Parameters
        ----------
        sparsity_target : float, default=0.05
            Desired average activation of a latent unit.
        weight : float, default=1.0
            Loss weight.
        """
        super().__init__(weight=weight)
        self.sparsity_target = sparsity_target

    def compute(self, code, recon, inputs, **kwargs):
        """Sum, over latent units, the KL between target and mean activation.

        Parameters
        ----------
        code : torch.Tensor
            Latent representation after activation (e.g., sigmoid),
            shape (batch_size, latent_dim). Values should be in [0, 1].
        recon : torch.Tensor
            Reconstructed outputs (unused).
        inputs : torch.Tensor
            Original inputs (unused).
        **kwargs
            Additional arguments (unused).

        Returns
        -------
        torch.Tensor
            Scalar: the Bernoulli KL divergence summed over all latent
            dimensions.

        Notes
        -----
        For each latent unit j:

            KL(ρ||ρ̂_j) = ρ log(ρ/ρ̂_j) + (1-ρ) log((1-ρ)/(1-ρ̂_j))

        Both ρ and ρ̂_j are clamped to [1e-6, 1 - 1e-6] so that no logarithm
        is evaluated at zero.
        """
        eps = 1e-6

        # ρ̂: batch mean of every latent unit, kept strictly inside (0, 1).
        rho_hat = torch.clamp(torch.mean(code, dim=0), eps, 1 - eps)
        # ρ: the target, clamped the same way.
        rho = torch.clamp(torch.tensor(self.sparsity_target), eps, 1 - eps)

        # The two halves of the Bernoulli KL divergence.
        active_part = rho * torch.log(rho / rho_hat)
        inactive_part = (1 - rho) * torch.log((1 - rho) / (1 - rho_hat))

        return torch.sum(active_part + inactive_part)
class ContractiveLoss(AELoss):
    """Contractive loss for robust representations via Jacobian penalty.

    Contractive autoencoders learn robust representations by penalizing the
    squared Frobenius norm of the encoder's Jacobian. This encourages the
    encoder to be locally constant, making the learned features insensitive
    to small input perturbations.

    Notes
    -----
    The penalty is computed with automatic differentiation, one backward
    pass per latent dimension.

    WARNING: Computational cost scales linearly with latent dimension.
    Consider using stochastic approximations for latent_dim > 100.

    References
    ----------
    Rifai, S., et al. (2011). Contractive Auto-Encoders: Explicit Invariance
    During Feature Extraction. ICML 2011.
    """

    def __init__(self, lambda_c: float = 0.01, weight: float = 1.0):
        """Initialize contractive loss.

        Parameters
        ----------
        lambda_c : float, default=0.01
            Contraction strength parameter.
        weight : float, default=1.0
            Loss weight.
        """
        super().__init__(weight=weight)
        self.lambda_c = lambda_c

    def compute(self, code, recon, inputs, encoder=None, **kwargs):
        """Compute the squared Frobenius norm of the encoder Jacobian.

        Parameters
        ----------
        code : torch.Tensor
            Latent representation (unused; the encoder is re-evaluated so
            that the latent code depends on a differentiable input).
        recon : torch.Tensor
            Reconstructed outputs (unused).
        inputs : torch.Tensor
            Original inputs x, shape (batch_size, input_dim).
        encoder : callable, optional
            Encoder function/module mapping inputs to a
            (batch_size, latent_dim) tensor. Required for this loss.
        **kwargs
            Additional arguments (unused).

        Returns
        -------
        torch.Tensor
            ``lambda_c`` times the squared Frobenius norm ||∂h/∂x||²_F,
            averaged over the batch.

        Raises
        ------
        ValueError
            If encoder is not provided.

        Notes
        -----
        The caller's ``inputs`` tensor is not modified: when it does not
        already require gradients, a detached alias that shares its storage
        is used for differentiation.

        The penalty is accumulated as Σ_i ||∂h_i/∂x||² over the latent
        dimensions i instead of forming the full Jacobian matrix.
        """
        if encoder is None:
            raise ValueError("Contractive loss requires encoder module")

        # Differentiate w.r.t. a private alias rather than flipping
        # requires_grad on the caller's tensor (which would persist after
        # this call). An input that already requires grad is used as is, so
        # any upstream graph stays connected.
        x = inputs if inputs.requires_grad else inputs.detach().requires_grad_(True)

        # Fresh forward pass so that h is a function of x in the graph.
        h = encoder(x)

        batch_size = x.shape[0]
        jacobian_sq_norm = h.new_zeros(())

        # WARNING: one backward pass per latent dimension.
        for i in range(h.shape[1]):
            # Gradient of the i-th latent unit w.r.t. the inputs;
            # create_graph=True keeps the penalty itself differentiable.
            (grad,) = torch.autograd.grad(
                h[:, i].sum(), x, retain_graph=True, create_graph=True
            )
            jacobian_sq_norm = jacobian_sq_norm + grad.pow(2).sum() / batch_size

        return self.lambda_c * jacobian_sq_norm
class MMDLoss(AELoss):
    """Maximum Mean Discrepancy (MMD) loss for latent distribution matching.

    MMD measures the distance between two distributions by comparing their
    embeddings in a reproducing kernel Hilbert space (RKHS). It's used as an
    alternative to KL divergence for matching distributions without requiring
    explicit density calculations.

    MMD is used to match the latent distribution q(z) to a prior p(z) by
    minimizing the distance between their expected feature maps:

        MMD²(p,q) = E[k(x,x')] + E[k(y,y')] - 2E[k(x,y)]

    where k is a kernel function (an RBF/Gaussian kernel here).

    Notes
    -----
    - The population MMD with a Gaussian kernel is 0 if and only if the
      distributions are identical
    - The unbiased sample estimate used here can be slightly negative
    - Choice of kernel bandwidth affects sensitivity to different scales
    - More computationally efficient than adversarial approaches
    - Particularly useful for implicit generative models

    For neural data:
    - Useful when you want latent codes to follow a specific distribution
    - Can encourage more interpretable latent spaces for neural dynamics

    The kernel bandwidth is fixed per instance (default 1.0) and is not
    adapted to the data. For better results, consider choosing it with the
    median heuristic or cross-validation and passing it as ``bandwidth``.

    References
    ----------
    Gretton, A., et al. (2012). A kernel two-sample test.
    Journal of Machine Learning Research, 13(1), 723-773.
    Tolstikhin, I., et al. (2018). Wasserstein Auto-Encoders. ICLR 2018.
    (Note: the WAE paper uses MMD as one of its latent-prior penalties)"""

    def __init__(self, mmd_weight: float = 1.0, weight: float = 1.0, bandwidth: float = 1.0):
        """Initialize MMD loss.

        Parameters
        ----------
        mmd_weight : float, default=1.0
            Weight for Maximum Mean Discrepancy term.
        weight : float, default=1.0
            Loss weight.
        bandwidth : float, default=1.0
            RBF kernel bandwidth σ used by :meth:`compute`. Must be positive.

        Raises
        ------
        ValueError
            If ``bandwidth`` is not strictly positive."""
        # A zero bandwidth would divide by zero inside the kernel.
        if bandwidth <= 0:
            raise ValueError(f"bandwidth must be positive, got {bandwidth}")
        super().__init__(weight=weight)
        self.mmd_weight = mmd_weight
        self.bandwidth = bandwidth

    def compute(self, code, recon, inputs, **kwargs):
        """Compute MMD between latent distribution and prior.

        Estimates the Maximum Mean Discrepancy between the encoded latent
        distribution and a standard normal prior using an RBF kernel.

        Parameters
        ----------
        code : torch.Tensor
            Latent representation, shape (batch_size, latent_dim).
        recon : torch.Tensor
            Reconstructed outputs (unused).
        inputs : torch.Tensor
            Original inputs (unused).
        **kwargs
            Additional arguments (unused).

        Returns
        -------
        torch.Tensor
            MMD² estimate, scaled by mmd_weight.

        Notes
        -----
        Uses empirical estimates:
        - q(z): empirical distribution from encoder outputs
        - p(z): fresh samples from N(0, I), drawn on every call with the
          same shape, dtype and device as ``code``

        The kernel bandwidth is the instance's ``bandwidth`` attribute."""
        # One prior sample per encoded sample, matching code's dtype/device.
        prior_sample = torch.randn_like(code)
        mmd = self._compute_mmd(code, prior_sample, bandwidth=self.bandwidth)
        return self.mmd_weight * mmd

    def _compute_mmd(self, x, y, kernel="rbf", bandwidth=1.0):
        """Compute Maximum Mean Discrepancy between two sample sets.

        Unbiased estimator of MMD² using U-statistics.

        Parameters
        ----------
        x : torch.Tensor
            Samples from first distribution, shape (n, dim).
        y : torch.Tensor
            Samples from second distribution, shape (m, dim).
        kernel : str, default='rbf'
            Kernel type (only 'rbf' supported).
        bandwidth : float, default=1.0
            RBF kernel bandwidth parameter.

        Returns
        -------
        torch.Tensor
            Scalar MMD² estimate.

        Raises
        ------
        ValueError
            If kernel type is not 'rbf'.

        Notes
        -----
        Uses the unbiased U-statistic estimator that excludes diagonal terms:

            MMD² = (1/(n(n-1))) Σᵢ≠ⱼ k(xᵢ,xⱼ) + (1/(m(m-1))) Σᵢ≠ⱼ k(yᵢ,yⱼ)
                   - (2/(nm)) Σᵢⱼ k(xᵢ,yⱼ)

        A within-sample term is taken as 0 when that sample set has fewer
        than two points (no off-diagonal pairs exist)."""
        n = x.shape[0]
        m = y.shape[0]

        # Cross term first: it is always needed and also validates `kernel`.
        xy_term = torch.sum(self._kernel(x, y, kernel, bandwidth)) / (n * m)

        if n > 1:
            xx_term = self._off_diagonal_mean(self._kernel(x, x, kernel, bandwidth))
        else:
            xx_term = x.new_zeros(())

        if m > 1:
            yy_term = self._off_diagonal_mean(self._kernel(y, y, kernel, bandwidth))
        else:
            yy_term = y.new_zeros(())

        return xx_term + yy_term - 2 * xy_term

    @staticmethod
    def _off_diagonal_mean(gram):
        """Return the mean of the off-diagonal entries of a square matrix."""
        size = gram.shape[0]
        # Work on a copy: zeroing the diagonal in place on the kernel output
        # would invalidate values autograd saved for the backward pass.
        off_diag = gram.clone()
        off_diag.fill_diagonal_(0)
        return torch.sum(off_diag) / (size * (size - 1))

    def _kernel(self, x, y, kernel="rbf", bandwidth=1.0):
        """Compute kernel matrix.

        Calculates the Gram matrix K[i,j] = k(x[i], y[j]) for the specified
        kernel.

        Parameters
        ----------
        x : torch.Tensor
            First set of points, shape (n, dim).
        y : torch.Tensor
            Second set of points, shape (m, dim).
        kernel : str, default='rbf'
            Kernel type. Currently only 'rbf' (Gaussian) is supported.
        bandwidth : float, default=1.0
            RBF kernel bandwidth σ. Larger values create smoother kernels.

        Returns
        -------
        torch.Tensor
            Kernel matrix, shape (n, m).

        Raises
        ------
        ValueError
            If kernel type is not 'rbf'.

        Notes
        -----
        RBF kernel: k(x,y) = exp(-||x-y||² / (2σ²))

        The pairwise differences are formed by broadcasting, which
        materializes an intermediate of shape (n, m, dim).

        Bandwidth selection is important: too small may overfit, too large
        may underfit. Common heuristics use median pairwise distance."""
        if kernel != "rbf":
            raise ValueError(f"Unknown kernel: {kernel}")

        # (n, 1, dim) - (1, m, dim) broadcasts to all pairwise differences.
        diff = x.unsqueeze(1) - y.unsqueeze(0)
        distances = torch.sum(diff**2, dim=2)
        return torch.exp(-distances / (2 * bandwidth**2))