Source code for driada.dim_reduction.flexible_ae

"""Flexible autoencoder architectures with modular loss composition.

This module provides flexible autoencoder and VAE implementations that support
dynamic loss composition for various objectives including reconstruction,
disentanglement, and regularization.
"""

from typing import List, Dict, Optional, Tuple, Any
from abc import ABC, abstractmethod
import logging

import numpy as np

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
except (ImportError, OSError):
    # OSError catches DLL loading issues on Windows
    torch = None
    nn = None
    F = None

from .neural import Encoder, Decoder, VAEEncoder
from .losses import AELoss, LossRegistry, ReconstructionLoss
from ..utils.data import check_positive


class FlexibleAutoencoderBase(nn.Module, ABC):
    """Abstract base class for autoencoders with flexible loss composition.

    Provides the shared infrastructure for loss management, device handling
    and logging. Subclasses must implement ``encode()``, ``forward()`` and
    ``compute_loss()`` with their own signatures.

    Parameters
    ----------
    loss_components : list of dict, optional
        List of loss component configurations. Each dict should have:

        - "name": str, name of the loss type (required)
        - "weight": float, weight for this loss (defaults to 1.0)
        - Additional parameters specific to each loss type
    device : torch.device or str, optional
        Device to run the model on. Defaults to CUDA if available, else CPU.
    logger : logging.Logger, optional
        Logger instance for tracking training progress.

    Attributes
    ----------
    loss_registry : LossRegistry
        Registry used to create loss components by name.
    losses : list
        Plain Python list of loss components (dynamically modifiable).
    device : torch.device
        Device for computations.
    logger : logging.Logger
        Logger instance.

    Notes
    -----
    This base class handles:

    - Loss system initialization and management
    - Device selection
    - Logger setup
    - Weighted aggregation of loss values

    Subclasses define:

    - Encoder/decoder architecture
    - ``encode()`` / ``forward()`` signatures
    - Which tensors are handed to the loss components
    """

    def __init__(
        self,
        loss_components: Optional[List[Dict]] = None,
        device: Optional[torch.device] = None,
        logger: Optional[logging.Logger] = None,
        _latent_dim: Optional[int] = None,
    ):
        """Initialize the base autoencoder infrastructure.

        Parameters
        ----------
        loss_components : list of dict, optional
            Loss component configurations. If None, no losses are created
            here and subclasses should provide defaults.
        device : torch.device or str, optional
            Device for computations. If None, auto-selects CUDA if available.
            Any value accepted by ``torch.device(...)`` is normalized to a
            ``torch.device`` instance.
        logger : logging.Logger, optional
            Logger instance. If None, a logger named after the concrete
            class is used.
        _latent_dim : int, optional
            Latent dimension, auto-injected as ``code_dim`` into loss configs
            whose constructor accepts it.
        """
        super().__init__()
        self._latent_dim_for_losses = _latent_dim

        self.device = self._resolve_device(device)
        self.logger = logger or logging.getLogger(self.__class__.__name__)

        # Loss system: a registry to build components, and the list of
        # components actually in use.
        self.loss_registry = LossRegistry()
        self.losses = []
        if loss_components is not None:
            self._initialize_losses(loss_components)

        # Register nn.Module attributes found on loss components so they are
        # included in model.to(), state_dict(), and parameters().
        for i, loss_component in enumerate(self.losses):
            for attr_name in dir(loss_component):
                attr = getattr(loss_component, attr_name, None)
                if isinstance(attr, nn.Module):
                    self.add_module(f"loss_{i}_{attr_name}", attr)

    @staticmethod
    def _resolve_device(device):
        """Return a ``torch.device`` for *device*, auto-selecting when None.

        Accepts an existing ``torch.device`` or any spec understood by the
        ``torch.device`` constructor (e.g. ``"cpu"``, ``"cuda:0"``), so that
        ``self.device`` always has a consistent type.
        """
        if device is None:
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.device(device)

    def _initialize_losses(self, loss_components: List[Dict]):
        """Initialize loss components from configuration.

        Each configuration is copied before ``"name"`` and ``"weight"`` are
        popped, so the caller's dicts are not modified at the top level.

        Parameters
        ----------
        loss_components : list of dict
            Loss configurations to initialize.

        Raises
        ------
        ValueError
            If a loss component is not a dict or lacks the ``"name"`` field.
            Errors raised by ``loss_registry.create`` propagate unchanged.
        """
        # Local import, done once rather than on every loop iteration.
        import inspect

        for i, loss_config in enumerate(loss_components):
            if not isinstance(loss_config, dict):
                raise ValueError(f"Loss component {i} must be a dict, got {type(loss_config)}")
            if "name" not in loss_config:
                raise ValueError(f"Loss component {i} missing required 'name' field")

            loss_params = loss_config.copy()
            name = loss_params.pop("name")
            weight = loss_params.pop("weight", 1.0)

            # Auto-inject code_dim for losses whose constructor accepts it,
            # unless the caller already supplied one explicitly.
            if (
                self._latent_dim_for_losses is not None
                and "code_dim" not in loss_params
                and name in self.loss_registry.losses
            ):
                sig = inspect.signature(self.loss_registry.losses[name].__init__)
                if "code_dim" in sig.parameters:
                    loss_params["code_dim"] = self._latent_dim_for_losses

            loss = self.loss_registry.create(name, weight=weight, **loss_params)
            self.losses.append(loss)

        if self.logger:
            self.logger.debug(f"Initialized {len(self.losses)} loss components")

    @abstractmethod
    def encode(self, x: torch.Tensor):
        """Encode input to latent representation.

        Parameters
        ----------
        x : torch.Tensor
            Input data.

        Returns
        -------
        Latent representation (type depends on subclass).
        """
        pass

    @abstractmethod
    def forward(self, x: torch.Tensor):
        """Forward pass through the autoencoder.

        Parameters
        ----------
        x : torch.Tensor
            Input data.

        Returns
        -------
        Output (type depends on subclass).
        """
        pass

    @abstractmethod
    def compute_loss(
        self, inputs: torch.Tensor, indices: Optional[torch.Tensor] = None, **extra_args
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Compute total loss from all components.

        Parameters
        ----------
        inputs : torch.Tensor
            Input data.
        indices : torch.Tensor, optional
            Batch indices for data-dependent losses.
        **extra_args
            Additional arguments for specific losses.

        Returns
        -------
        total_loss : torch.Tensor
            Weighted sum of all loss components.
        loss_dict : dict
            Individual loss values for logging.
        """
        pass

    def _aggregate_losses(
        self, loss_inputs: Dict[str, Any]
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Common loss aggregation logic.

        Calls ``compute(**loss_inputs)`` on every component in ``self.losses``
        and sums the results scaled by each component's ``weight``.

        Parameters
        ----------
        loss_inputs : dict
            Keyword arguments passed to every loss component. Should include
            keys like 'code', 'recon', 'inputs', etc.

        Returns
        -------
        total_loss : torch.Tensor
            Weighted sum of all loss components, started from a zero scalar
            created on ``self.device``.
        loss_dict : dict
            Python scalars for logging, keyed ``"{ClassName}_{index}"``
            (raw value) and ``"{ClassName}_{index}_weighted"``, plus
            ``"total_loss"``.
        """
        total_loss = torch.tensor(0.0, device=self.device)
        loss_dict = {}

        for i, loss_component in enumerate(self.losses):
            loss_value = loss_component.compute(**loss_inputs)
            weighted_loss = loss_component.weight * loss_value
            total_loss = total_loss + weighted_loss

            # The index suffix keeps keys unique when a loss type is repeated.
            loss_name = loss_component.__class__.__name__
            loss_dict[f"{loss_name}_{i}"] = loss_value.item()
            loss_dict[f"{loss_name}_{i}_weighted"] = weighted_loss.item()

        loss_dict["total_loss"] = total_loss.item()
        return total_loss, loss_dict
class ModularAutoencoder(FlexibleAutoencoderBase):
    """Flexible autoencoder with modular loss composition.

    This autoencoder supports configurable loss components for various
    training objectives, enabling techniques like disentanglement and
    regularization on top of plain reconstruction.

    Parameters
    ----------
    input_dim : int
        Dimension of input data.
    latent_dim : int
        Dimension of latent representation.
    hidden_dim : int, default=100
        Dimension of hidden layers.
    encoder_config : dict, optional
        Configuration for encoder (e.g., dropout rate).
    decoder_config : dict, optional
        Configuration for decoder.
    loss_components : list of dict, optional
        List of loss component configurations. Each dict should have:

        - "name": str, name of the loss type
        - "weight": float, weight for this loss
        - Additional parameters specific to each loss type

        If not provided (or empty), defaults to
        [{"name": "reconstruction", "weight": 1.0}].
    device : torch.device, optional
        Device to run the model on. Defaults to CUDA if available, else CPU.
    logger : logging.Logger, optional
        Logger instance for tracking training progress.

    Attributes
    ----------
    encoder : Encoder
        Neural network encoder module.
    decoder : Decoder
        Neural network decoder module.
    loss_registry : LossRegistry
        Registry for creating loss components.
    losses : list
        List of loss components (set up by the base class).
    input_dim : int
        Input data dimension.
    latent_dim : int
        Latent representation dimension.
    hidden_dim : int
        Hidden layer dimension.
    device : torch.device
        Device for computations.
    logger : logging.Logger
        Logger instance.

    Examples
    --------
    >>> # Standard autoencoder with correlation loss
    >>> ae = ModularAutoencoder(
    ...     input_dim=100, latent_dim=10,
    ...     loss_components=[
    ...         {"name": "reconstruction", "weight": 1.0},
    ...         {"name": "correlation", "weight": 0.1}
    ...     ]
    ... )

    >>> # Autoencoder with sparsity and orthogonality constraints
    >>> import numpy as np
    >>> data = np.random.randn(100, 1000)  # 100 features, 1000 samples
    >>> ae = ModularAutoencoder(
    ...     input_dim=100, latent_dim=10,
    ...     loss_components=[
    ...         {"name": "reconstruction", "weight": 1.0},
    ...         {"name": "sparse", "weight": 0.1, "sparsity_target": 0.05},
    ...         {"name": "orthogonality", "weight": 0.05, "external_data": data}
    ...     ]
    ... )

    Notes
    -----
    - Loss components are kept in a plain Python list; the base class
      registers any ``nn.Module`` attributes they expose as submodules
    - Default reconstruction loss is automatically added if no components
      are specified
    - ``get_latent_representation`` transposes its output for DRIADA
      compatibility

    See Also
    --------
    FlexibleVAE : Variational version with probabilistic encoding.
    ~driada.dim_reduction.losses.LossRegistry : Available loss components.
    """

    def __init__(
        self,
        input_dim: int,
        latent_dim: int,
        hidden_dim: int = 100,
        encoder_config: Optional[Dict] = None,
        decoder_config: Optional[Dict] = None,
        loss_components: Optional[List[Dict]] = None,
        device: Optional[torch.device] = None,
        logger: Optional[logging.Logger] = None,
    ):
        """Initialize the modular autoencoder.

        Parameters
        ----------
        input_dim : int
            Dimension of input data. Must be positive.
        latent_dim : int
            Dimension of latent representation. Must be positive.
        hidden_dim : int, default=100
            Dimension of hidden layers. Must be positive.
        encoder_config : dict, optional
            Configuration for encoder (e.g., {"dropout": 0.2}).
        decoder_config : dict, optional
            Configuration for decoder.
        loss_components : list of dict, optional
            Loss component configurations. If None or empty, uses the
            default reconstruction loss.
        device : torch.device, optional
            Device for computations. If None, auto-selects CUDA if available.
        logger : logging.Logger, optional
            Logger instance. If None, creates default logger.

        Raises
        ------
        ValueError
            If any dimension is not positive (checked by ``check_positive``).
            If loss component name is not registered.
            If loss component dict is malformed.
        """
        # Validate dimensions up front: the base initializer injects
        # latent_dim into loss constructors, so an invalid value must be
        # rejected before any loss component is built.
        check_positive(input_dim=input_dim, latent_dim=latent_dim, hidden_dim=hidden_dim)

        # Add default reconstruction loss if no components specified
        if not loss_components:
            loss_components = [{"name": "reconstruction", "weight": 1.0}]

        # Initialize base class with loss system (pass latent_dim for losses
        # that need it)
        super().__init__(
            loss_components=loss_components,
            device=device,
            logger=logger,
            _latent_dim=latent_dim,
        )

        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim

        # Build encoder/decoder on the device resolved by the base class
        enc_config = encoder_config or {}
        dec_config = decoder_config or {}

        self.encoder = Encoder(
            orig_dim=input_dim,
            inter_dim=hidden_dim,
            code_dim=latent_dim,
            kwargs=enc_config,
            device=self.device,
        )
        self.decoder = Decoder(
            code_dim=latent_dim,
            inter_dim=hidden_dim,
            orig_dim=input_dim,
            kwargs=dec_config,
            device=self.device,
        )

        if self.logger:
            self.logger.info(
                f"Initialized {self.__class__.__name__} with {len(self.losses)} loss components on {self.device}"
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the autoencoder.

        Parameters
        ----------
        x : torch.Tensor
            Input data, shape (batch_size, input_dim).

        Returns
        -------
        torch.Tensor
            Reconstructed data, shape (batch_size, input_dim).

        Notes
        -----
        Applies the encoder then the decoder. No input validation performed.
        """
        code = self.encoder(x)
        recon = self.decoder(code)
        return recon

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode input to latent representation.

        Parameters
        ----------
        x : torch.Tensor
            Input data, shape (batch_size, input_dim).

        Returns
        -------
        torch.Tensor
            Latent code, shape (batch_size, latent_dim).

        Notes
        -----
        Direct passthrough to the encoder network; the layer stack itself is
        defined by ``Encoder`` in the ``neural`` module.
        """
        return self.encoder(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent representation to output.

        Parameters
        ----------
        z : torch.Tensor
            Latent code, shape (batch_size, latent_dim).

        Returns
        -------
        torch.Tensor
            Reconstructed data, shape (batch_size, input_dim).

        Notes
        -----
        Direct passthrough to the decoder network; any output activation is
        determined by ``Decoder`` in the ``neural`` module.
        """
        return self.decoder(z)

    def compute_loss(
        self, inputs: torch.Tensor, indices: Optional[torch.Tensor] = None, **extra_args
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Compute total loss from all components.

        Parameters
        ----------
        inputs : torch.Tensor
            Input data, shape (batch_size, input_dim).
        indices : torch.Tensor, optional
            Batch indices for data-dependent losses.
        **extra_args
            Additional arguments passed to all loss components. Because they
            are merged last, they override the standard keys on a clash.

        Returns
        -------
        total_loss : torch.Tensor
            Weighted sum of all loss components.
        loss_dict : dict
            Individual loss values with keys "{ClassName}_{index}" and
            "{ClassName}_{index}_weighted". Includes "total_loss".

        Notes
        -----
        Performs an encode->decode pass and delegates the weighted sum to the
        base class helper. Every loss component receives the same inputs.
        """
        # Forward pass
        code = self.encode(inputs)
        recon = self.decode(code)

        # Everything a loss component might need, passed as keyword arguments
        loss_inputs = {
            "code": code,
            "recon": recon,
            "inputs": inputs,
            "indices": indices,
            "encoder": self.encoder,
            **extra_args,
        }

        return self._aggregate_losses(loss_inputs)

    def get_latent_representation(self, x: torch.Tensor) -> np.ndarray:
        """Get latent representation for data.

        Parameters
        ----------
        x : torch.Tensor
            Input data, shape (batch_size, input_dim).

        Returns
        -------
        np.ndarray
            Latent representation, shape (latent_dim, batch_size).
            Note: Transposed for DRIADA compatibility.

        Notes
        -----
        Runs in no_grad mode. Returns detached numpy array on CPU. The
        module's train/eval mode is left as the caller set it.
        """
        with torch.no_grad():
            code = self.encode(x)
        return code.detach().cpu().numpy().T
class FlexibleVAE(FlexibleAutoencoderBase):
    """Flexible Variational Autoencoder with modular loss composition.

    Supports variational inference with flexible loss components for
    advanced disentanglement techniques. Uses a VAEEncoder whose output is
    split into mean and log variance parameters.

    Parameters
    ----------
    input_dim : int
        Dimension of input data.
    latent_dim : int
        Dimension of latent representation.
    hidden_dim : int, default=100
        Dimension of hidden layers.
    encoder_config : dict, optional
        Configuration for encoder.
    decoder_config : dict, optional
        Configuration for decoder.
    loss_components : list of dict, optional
        List of loss component configurations. Defaults to reconstruction +
        beta_vae losses (each with weight 1.0, and beta 1.0 for beta_vae).
    device : torch.device, optional
        Device to run the model on.
    logger : logging.Logger, optional
        Logger instance.

    Notes
    -----
    - The encoder is built with output dimension 2*latent_dim (mean and
      log_var)
    - Supports deterministic mode via ``use_mean`` in
      ``get_latent_representation``
    - ``forward()`` returns (recon, mu, log_var), unlike
      ``ModularAutoencoder.forward`` which returns a single tensor
    - Default includes standard VAE losses if none specified

    Examples
    --------
    >>> # β-VAE for disentanglement
    >>> vae = FlexibleVAE(
    ...     input_dim=100, latent_dim=10,
    ...     loss_components=[
    ...         {"name": "reconstruction", "weight": 1.0},
    ...         {"name": "beta_vae", "weight": 1.0, "beta": 4.0}
    ...     ]
    ... )

    >>> # TC-VAE with decomposed ELBO
    >>> vae = FlexibleVAE(
    ...     input_dim=100, latent_dim=10,
    ...     loss_components=[
    ...         {"name": "reconstruction", "weight": 1.0},
    ...         {"name": "tc_vae", "weight": 1.0, "alpha": 1.0, "beta": 5.0, "gamma": 1.0}
    ...     ]
    ... )
    """

    def __init__(
        self,
        input_dim: int,
        latent_dim: int,
        hidden_dim: int = 100,
        encoder_config: Optional[Dict] = None,
        decoder_config: Optional[Dict] = None,
        loss_components: Optional[List[Dict]] = None,
        device: Optional[torch.device] = None,
        logger: Optional[logging.Logger] = None,
    ):
        """Initialize FlexibleVAE.

        Parameters
        ----------
        input_dim : int
            Dimension of input data. Must be positive.
        latent_dim : int
            Dimension of latent representation. Must be positive.
        hidden_dim : int, default=100
            Dimension of hidden layers. Must be positive.
        encoder_config : dict, optional
            Configuration for VAE encoder.
        decoder_config : dict, optional
            Configuration for decoder.
        loss_components : list of dict, optional
            Loss configurations. If None or empty, uses reconstruction +
            beta_vae.
        device : torch.device, optional
            Device for computations. Auto-selects CUDA if available.
        logger : logging.Logger, optional
            Logger instance.

        Raises
        ------
        ValueError
            If any dimension is not positive (checked by ``check_positive``),
            or if a loss component configuration is malformed.

        Notes
        -----
        The encoder is created with 2*latent_dim outputs for mean and log
        variance. The default loss configuration includes both
        reconstruction and KL divergence terms.
        """
        # Validate dimensions up front: the base initializer injects
        # latent_dim into loss constructors, so an invalid value must be
        # rejected before any loss component is built.
        check_positive(input_dim=input_dim, latent_dim=latent_dim, hidden_dim=hidden_dim)

        # Fall back to the standard VAE objective when nothing is configured
        if not loss_components:
            loss_components = [
                {"name": "reconstruction", "weight": 1.0},
                {"name": "beta_vae", "weight": 1.0, "beta": 1.0},  # Standard VAE
            ]

        # Initialize base class with loss system (pass latent_dim for losses
        # that need it)
        super().__init__(
            loss_components=loss_components,
            device=device,
            logger=logger,
            _latent_dim=latent_dim,
        )

        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim

        enc_config = encoder_config or {}
        dec_config = decoder_config or {}

        # VAE encoder outputs 2*latent_dim (mean and log_var)
        self.encoder = VAEEncoder(
            orig_dim=input_dim,
            inter_dim=hidden_dim,
            code_dim=2 * latent_dim,
            kwargs=enc_config,
            device=self.device,
        )
        self.decoder = Decoder(
            code_dim=latent_dim,
            inter_dim=hidden_dim,
            orig_dim=input_dim,
            kwargs=dec_config,
            device=self.device,
        )

        if self.logger:
            self.logger.info(
                f"Initialized {self.__class__.__name__} with {len(self.losses)} loss components on {self.device}"
            )

    def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Reparameterization trick for VAE.

        Parameters
        ----------
        mu : torch.Tensor
            Mean of latent distribution, shape (batch_size, latent_dim).
        log_var : torch.Tensor
            Log variance of latent distribution, shape (batch_size, latent_dim).

        Returns
        -------
        torch.Tensor
            Sampled latent code ``mu + eps * exp(0.5 * log_var)`` with
            ``eps ~ N(0, I)``, same shape as ``mu``.

        Notes
        -----
        Always samples stochastically. No numerical stability checks for
        extreme log_var values. Use get_latent_representation(use_mean=True)
        for deterministic behavior.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent representation to output.

        Parameters
        ----------
        z : torch.Tensor
            Latent code, shape (batch_size, latent_dim).

        Returns
        -------
        torch.Tensor
            Reconstructed data, shape (batch_size, input_dim).

        Notes
        -----
        Direct passthrough to the decoder network; any output activation is
        determined by ``Decoder`` in the ``neural`` module.
        """
        return self.decoder(z)

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode input to latent distribution parameters.

        Parameters
        ----------
        x : torch.Tensor
            Input data, shape (batch_size, input_dim).

        Returns
        -------
        z : torch.Tensor
            Sampled latent code, shape (batch_size, latent_dim).
        mu : torch.Tensor
            Mean of latent distribution, shape (batch_size, latent_dim).
        log_var : torch.Tensor
            Log variance of latent distribution, shape (batch_size, latent_dim).

        Notes
        -----
        The encoder output must hold exactly 2*latent_dim features per
        sample. It is viewed as (batch, 2, latent_dim): the first half of the
        features is the mean, the second half the log variance. Always
        samples via reparameterization.
        """
        # Split the encoder output into the two distribution parameters
        h = self.encoder(x)
        h = h.view(-1, 2, self.latent_dim)
        mu = h[:, 0, :]
        log_var = h[:, 1, :]

        # Sample latent code
        z = self.reparameterize(mu, log_var)
        return z, mu, log_var

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass through the VAE.

        Parameters
        ----------
        x : torch.Tensor
            Input data, shape (batch_size, input_dim).

        Returns
        -------
        recon : torch.Tensor
            Reconstructed data, shape (batch_size, input_dim).
        mu : torch.Tensor
            Mean of latent distribution, shape (batch_size, latent_dim).
        log_var : torch.Tensor
            Log variance of latent distribution, shape (batch_size, latent_dim).

        Notes
        -----
        Returns a tuple rather than a single tensor. Always uses the sampled
        z for reconstruction, not the mean.
        """
        z, mu, log_var = self.encode(x)
        recon = self.decode(z)
        return recon, mu, log_var

    def compute_loss(
        self, inputs: torch.Tensor, indices: Optional[torch.Tensor] = None, **extra_args
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Compute total loss from all components.

        Parameters
        ----------
        inputs : torch.Tensor
            Input data, shape (batch_size, input_dim).
        indices : torch.Tensor, optional
            Batch indices for data-dependent losses.
        **extra_args
            Additional arguments for specific losses. Because they are merged
            last, they override the standard keys on a clash.

        Returns
        -------
        total_loss : torch.Tensor
            Weighted sum of all loss components.
        loss_dict : dict
            Individual loss values for logging.

        Notes
        -----
        Performs a single encode/decode pass and passes mu and log_var to
        all loss components. Uses the base class aggregation helper.
        """
        # Forward pass - single encoding
        z, mu, log_var = self.encode(inputs)
        recon = self.decode(z)

        # Loss inputs include the VAE-specific distribution parameters
        loss_inputs = {
            "code": z,
            "recon": recon,
            "inputs": inputs,
            "mu": mu,
            "log_var": log_var,
            "indices": indices,
            "encoder": self.encoder,
            **extra_args,
        }

        return self._aggregate_losses(loss_inputs)

    def get_latent_representation(self, x: torch.Tensor, use_mean: bool = True) -> np.ndarray:
        """Get latent representation for data.

        Parameters
        ----------
        x : torch.Tensor
            Input data, shape (batch_size, input_dim).
        use_mean : bool, default=True
            If True, return mean of latent distribution (deterministic).
            If False, return sampled latent code (stochastic).

        Returns
        -------
        np.ndarray
            Latent representation, shape (latent_dim, batch_size).
            Transposed for DRIADA compatibility.

        Notes
        -----
        Default behavior is deterministic (use_mean=True) for reproducible
        embeddings. Set use_mean=False to capture uncertainty via sampling.
        A sample is drawn in both cases because ``encode`` always samples.
        """
        with torch.no_grad():
            z, mu, log_var = self.encode(x)
        if use_mean:
            return mu.detach().cpu().numpy().T
        else:
            return z.detach().cpu().numpy().T