# Source code for driada.utils.visual

"""
Visualization utilities for DRIADA
==================================

This module provides reusable visualization functions for embedding comparisons,
trajectory plots, and component interpretation in dimensionality reduction analyses.
"""

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from typing import Dict, List, Tuple, Optional, Union
from scipy.stats import gaussian_kde

# Import validation utilities
from driada.utils.data import check_positive, check_nonnegative

# Default DPI for all plots
DEFAULT_DPI = 150


def plot_embedding_comparison(
    embeddings: Dict[str, np.ndarray],
    features: Optional[Dict[str, np.ndarray]] = None,
    feature_names: Optional[Dict[str, str]] = None,
    methods: Optional[List[str]] = None,
    with_trajectory: bool = True,
    compute_metrics: bool = True,
    trajectory_kwargs: Optional[Dict] = None,
    figsize: Optional[Tuple[float, float]] = None,
    scatter_size: float = 2,
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI,
) -> plt.Figure:
    """
    Create comprehensive embedding comparison figure with behavioral features and trajectories.

    One column is drawn per method. Rows 0 and 1 show the embedding (first two
    components) colored by the first two entries of ``features``; rows without a
    feature get a plain scatter plot. An optional third row shows the temporal
    trajectory through the embedding.

    Parameters
    ----------
    embeddings : dict
        Dictionary mapping method names to embedding arrays (n_samples, n_components).
        Arrays must be 2D, non-empty, and have at least 2 components.
    features : dict, optional
        Dictionary mapping feature names to feature arrays. Arrays must have the
        same length as the embeddings. Features whose name contains 'angle',
        'direction', 'heading' or 'phase' (case-insensitive) are treated as
        circular: they are min-max normalized, plotted with the 'hsv' colormap
        and, if ``compute_metrics``, overlaid with KDE density contours. All
        other features use 'viridis' and, if ``compute_metrics``, get
        25/50/75th percentile marks on the colorbar. Only the first two features
        (in dict order) are shown, as rows 1 and 2.
    feature_names : dict, optional
        Dictionary mapping feature keys to display names. Defaults to
        ``{"angle": "Head Direction", "speed": "Speed"}``; keys without an entry
        are displayed under their own name.
    methods : list of str, optional
        List of methods to plot (if None, uses all keys in embeddings)
    with_trajectory : bool, default True
        Whether to include trajectory visualization as a third row
    compute_metrics : bool, default True
        Whether to compute and display metrics (density contours, percentiles)
    trajectory_kwargs : dict, optional
        Overrides for trajectory styling. Recognized keys are ``linewidth``,
        ``alpha`` and ``color`` (trajectory line), ``arrow_spacing`` (the step
        between direction arrows is ``n_samples // arrow_spacing``),
        ``arrow_scale`` (fraction of the step vector drawn as the arrow),
        ``start_marker``, ``end_marker`` and ``marker_size``.
    figsize : tuple, optional
        Figure size (width, height). If None, computed based on number of methods
    scatter_size : float, default 2
        Marker size for scatter points
    save_path : str, optional
        Path to save the figure
    dpi : int, default DEFAULT_DPI
        DPI resolution for saved figure

    Returns
    -------
    fig : matplotlib.figure.Figure
        The generated figure

    Raises
    ------
    ValueError
        If embeddings are not 2D arrays with at least 2 components, are empty,
        or if feature arrays have mismatched lengths

    Notes
    -----
    Methods not found in embeddings dict are silently skipped (their column is
    left blank). KDE computation failures are caught and contours are omitted.
    Arrow heads are sized at 1% of the embedding's data extent so they stay
    visible independent of the embedding's scale.
    """
    # Validate embeddings
    for method, embedding in embeddings.items():
        if embedding.ndim != 2:
            raise ValueError(f"Embedding for {method} must be 2D, got shape {embedding.shape}")
        if embedding.shape[1] < 2:
            raise ValueError(
                f"Embedding for {method} must have at least 2 components, got {embedding.shape[1]}"
            )
        if len(embedding) == 0:
            # Start/end markers and feature normalization need at least one sample
            raise ValueError(f"Embedding for {method} is empty")

    if methods is None:
        methods = list(embeddings.keys())

    n_methods = len(methods)
    n_rows = 3 if with_trajectory else 2

    if figsize is None:
        figsize = (6 * n_methods, 5 * n_rows)

    fig = plt.figure(figsize=figsize)
    gs = gridspec.GridSpec(n_rows, n_methods, hspace=0.3, wspace=0.3)

    if feature_names is None:
        feature_names = {"angle": "Head Direction", "speed": "Speed"}

    traj_opts = {
        "linewidth": 0.8,
        "alpha": 0.3,
        "color": "k",
        "arrow_spacing": 20,
        "arrow_scale": 0.3,
        "start_marker": "o",
        "end_marker": "s",
        "marker_size": 100,
    }
    if trajectory_kwargs:
        traj_opts.update(trajectory_kwargs)

    # Only the first two features (dict order) get their own rows
    feat_keys = list(features.keys()) if features else []
    # Substrings marking a feature as circular (hsv colormap + density contours)
    circular_hints = ("angle", "direction", "heading", "phase")

    for i, method in enumerate(methods):
        if method not in embeddings:
            continue

        embedding = embeddings[method]

        # Validate features match embedding length
        if features is not None:
            for feat_name, feat_array in features.items():
                if len(feat_array) != len(embedding):
                    raise ValueError(
                        f"Feature '{feat_name}' length ({len(feat_array)}) "
                        f"doesn't match embedding length ({len(embedding)})"
                    )

        # Rows 0-1: embedding colored by a behavioral feature
        for row_idx, feat_key in enumerate(feat_keys[:2]):
            ax = fig.add_subplot(gs[row_idx, i])
            feat_vals = np.asarray(features[feat_key])
            display_name = feature_names.get(feat_key, feat_key) if feature_names else feat_key
            is_circular = any(hint in feat_key.lower() for hint in circular_hints)

            if is_circular:
                # Map onto [0, 1] so the cyclic hsv colormap wraps min onto max
                vals_norm = (feat_vals - feat_vals.min()) / (
                    feat_vals.max() - feat_vals.min() + 1e-12
                )
                scatter = ax.scatter(
                    embedding[:, 0],
                    embedding[:, 1],
                    c=vals_norm,
                    cmap="hsv",
                    alpha=0.7,
                    s=scatter_size,
                    vmin=0,
                    vmax=1,
                    edgecolors="none",
                )
                plt.colorbar(scatter, ax=ax, label=display_name)

                if compute_metrics:
                    # Best-effort density contours; degenerate data makes the KDE fail
                    try:
                        kde = gaussian_kde(embedding[:, :2].T)
                        x_min, x_max = ax.get_xlim()
                        y_min, y_max = ax.get_ylim()
                        X, Y = np.mgrid[x_min:x_max:50j, y_min:y_max:50j]
                        positions_grid = np.vstack([X.ravel(), Y.ravel()])
                        Z = np.reshape(kde(positions_grid).T, X.shape)
                        ax.contour(X, Y, Z, colors="gray", alpha=0.3, linewidths=0.5)
                    except (ValueError, RuntimeError, np.linalg.LinAlgError):
                        pass
            else:
                scatter = ax.scatter(
                    embedding[:, 0],
                    embedding[:, 1],
                    c=feat_vals,
                    cmap="viridis",
                    alpha=0.7,
                    s=scatter_size,
                    edgecolors="none",
                )
                cbar = plt.colorbar(scatter, ax=ax, label=display_name)

                if compute_metrics:
                    # Mark quartiles of the feature on the colorbar
                    percentiles = np.percentile(feat_vals, [25, 50, 75])
                    for p, val in zip([25, 50, 75], percentiles):
                        cbar.ax.axhline(y=val, color="red", alpha=0.3, linewidth=0.5)
                        cbar.ax.text(
                            1.05,
                            val,
                            f"{p}%",
                            transform=cbar.ax.get_yaxis_transform(),
                            fontsize=8,
                            va="center",
                        )

            ax.set_xlabel("Component 0")
            ax.set_ylabel("Component 1")
            ax.set_title(f"{method.upper()} - {display_name}")
            ax.grid(True, alpha=0.3)

        # Fill remaining rows if fewer than 2 features
        for row_idx in range(len(feat_keys), 2):
            ax = fig.add_subplot(gs[row_idx, i])
            ax.scatter(embedding[:, 0], embedding[:, 1], alpha=0.6, s=scatter_size)
            ax.set_xlabel("Component 0")
            ax.set_ylabel("Component 1")
            ax.set_title(f"{method.upper()}")
            ax.grid(True, alpha=0.3)

        # Third row: trajectory visualization
        if with_trajectory:
            ax3 = fig.add_subplot(gs[2, i])

            ax3.plot(
                embedding[:, 0],
                embedding[:, 1],
                color=traj_opts["color"],
                alpha=traj_opts["alpha"],
                linewidth=traj_opts["linewidth"],
            )

            n_samples = len(embedding)
            step = max(1, n_samples // traj_opts["arrow_spacing"])
            # Size arrow heads relative to the data extent (1% of the smaller
            # range, falling back to the larger one for a flat trajectory).
            x_range = embedding[:, 0].max() - embedding[:, 0].min()
            y_range = embedding[:, 1].max() - embedding[:, 1].min()
            head_size = 0.01 * (min(x_range, y_range) or max(x_range, y_range))

            # j + 1 is always a valid index because j <= n_samples - step - 1
            for j in range(0, n_samples - step, step):
                dx = embedding[j + 1, 0] - embedding[j, 0]
                dy = embedding[j + 1, 1] - embedding[j, 1]
                # Only plot arrow if movement is significant
                if np.sqrt(dx**2 + dy**2) > 0.001:
                    ax3.arrow(
                        embedding[j, 0],
                        embedding[j, 1],
                        dx * traj_opts["arrow_scale"],
                        dy * traj_opts["arrow_scale"],
                        head_width=head_size,
                        head_length=head_size,
                        fc="red",
                        ec="red",
                        alpha=0.6,
                    )

            # Mark start and end points
            ax3.scatter(
                embedding[0, 0],
                embedding[0, 1],
                c="green",
                s=traj_opts["marker_size"],
                marker=traj_opts["start_marker"],
                edgecolors="black",
                linewidth=2,
                label="Start",
                zorder=5,
            )
            ax3.scatter(
                embedding[-1, 0],
                embedding[-1, 1],
                c="red",
                s=traj_opts["marker_size"],
                marker=traj_opts["end_marker"],
                edgecolors="black",
                linewidth=2,
                label="End",
                zorder=5,
            )

            ax3.set_xlabel("Component 0")
            ax3.set_ylabel("Component 1")
            ax3.set_title(f"{method.upper()} - Trajectory")
            ax3.grid(True, alpha=0.3)
            ax3.legend(loc="best", fontsize=8)
            ax3.set_aspect("equal", adjustable="datalim")

    title = "Population Embeddings: Behavioral Features"
    if with_trajectory:
        title += " and Trajectories"
    # Operate on this figure explicitly rather than whatever pyplot considers current
    fig.suptitle(title, fontsize=16)

    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches="tight")

    return fig
def plot_trajectories(
    embeddings: Dict[str, np.ndarray],
    methods: Optional[List[str]] = None,
    trajectory_kwargs: Optional[Dict] = None,
    figsize: Optional[Tuple[float, float]] = None,
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI,
) -> plt.Figure:
    """
    Create figure showing trajectories in embedding space for multiple methods.

    Each method gets one panel with the trajectory through its first two
    components, direction arrows, and start/end markers.

    Parameters
    ----------
    embeddings : dict
        Dictionary mapping method names to embedding arrays. Arrays must be 2D,
        non-empty, with at least 2 components.
    methods : list of str, optional
        List of methods to plot (if None, uses all keys in embeddings)
    trajectory_kwargs : dict, optional
        Overrides for trajectory styling. Recognized keys are ``linewidth``,
        ``alpha``, ``color``, ``arrow_spacing`` (the step between arrows is
        ``n_samples // arrow_spacing``), ``arrow_scale`` (fraction of the step
        vector drawn as the arrow), ``start_marker``, ``end_marker`` and
        ``marker_size``.
    figsize : tuple, optional
        Figure size
    save_path : str, optional
        Path to save the figure
    dpi : int, default DEFAULT_DPI
        DPI resolution for saved figure

    Returns
    -------
    fig : matplotlib.figure.Figure
        The generated figure

    Raises
    ------
    ValueError
        If embeddings are not 2D arrays with at least 2 components, or are empty

    Notes
    -----
    Methods not found in embeddings are skipped (their panel is left blank).
    Arrow heads are 1% of the smaller data range, falling back to the larger
    range when the trajectory is flat along one axis.
    """
    # Validate embeddings
    for method, embedding in embeddings.items():
        if embedding.ndim != 2:
            raise ValueError(f"Embedding for {method} must be 2D, got shape {embedding.shape}")
        if embedding.shape[1] < 2:
            raise ValueError(
                f"Embedding for {method} must have at least 2 components, got {embedding.shape[1]}"
            )
        if len(embedding) == 0:
            raise ValueError(f"Embedding for {method} is empty")

    if methods is None:
        methods = list(embeddings.keys())

    n_methods = len(methods)

    if figsize is None:
        figsize = (6 * n_methods, 5)

    fig = plt.figure(figsize=figsize)

    opts = {
        "linewidth": 0.8,
        "alpha": 0.3,
        "color": "k",
        "arrow_spacing": 20,
        "arrow_scale": 0.3,
        "start_marker": "o",
        "end_marker": "s",
        "marker_size": 100,
    }
    if trajectory_kwargs:
        opts.update(trajectory_kwargs)

    for i, method in enumerate(methods):
        if method not in embeddings:
            continue

        embedding = embeddings[method]
        ax = fig.add_subplot(1, n_methods, i + 1)

        ax.plot(
            embedding[:, 0],
            embedding[:, 1],
            color=opts["color"],
            alpha=opts["alpha"],
            linewidth=opts["linewidth"],
        )

        n_samples = len(embedding)
        step = max(1, n_samples // opts["arrow_spacing"])

        # Arrow head size depends only on the data extent, so compute it once
        # per embedding instead of once per arrow.
        x_range = embedding[:, 0].max() - embedding[:, 0].min()
        y_range = embedding[:, 1].max() - embedding[:, 1].min()
        head_size = 0.01 * (min(x_range, y_range) or max(x_range, y_range))

        # j + 1 is always a valid index because j <= n_samples - step - 1
        for j in range(0, n_samples - step, step):
            dx = embedding[j + 1, 0] - embedding[j, 0]
            dy = embedding[j + 1, 1] - embedding[j, 1]
            if np.sqrt(dx**2 + dy**2) > 0.001:
                ax.arrow(
                    embedding[j, 0],
                    embedding[j, 1],
                    dx * opts["arrow_scale"],
                    dy * opts["arrow_scale"],
                    head_width=head_size,
                    head_length=head_size,
                    fc="red",
                    ec="red",
                    alpha=0.6,
                )

        # Mark start and end
        ax.scatter(
            embedding[0, 0],
            embedding[0, 1],
            c="green",
            s=opts["marker_size"],
            marker=opts["start_marker"],
            edgecolors="black",
            linewidth=2,
            label="Start",
            zorder=5,
        )
        ax.scatter(
            embedding[-1, 0],
            embedding[-1, 1],
            c="red",
            s=opts["marker_size"],
            marker=opts["end_marker"],
            edgecolors="black",
            linewidth=2,
            label="End",
            zorder=5,
        )

        ax.set_xlabel("Component 0")
        ax.set_ylabel("Component 1")
        ax.set_title(f"{method.upper()} - Trajectory")
        ax.grid(True, alpha=0.3)
        ax.legend(loc="best", fontsize=8)
        ax.set_aspect("equal", adjustable="datalim")

    # Operate on this figure explicitly rather than pyplot's current figure
    fig.suptitle("Temporal Trajectories in Embedding Space", fontsize=16)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches="tight")

    return fig
def plot_component_interpretation(
    mi_matrices: Dict[str, np.ndarray],
    feature_names: List[str],
    methods: Optional[List[str]] = None,
    n_components: Optional[int] = None,
    metadata: Optional[Dict[str, Dict]] = None,
    compute_metrics: bool = True,
    figsize: Optional[Tuple[float, float]] = None,
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI,
) -> plt.Figure:
    """
    Create figure showing mutual information between embedding components and features.

    Each method gets a heatmap of shape (n_features, n_components_shown) with
    the MI value printed in every cell.

    Parameters
    ----------
    mi_matrices : dict
        Dictionary mapping method names to MI matrices (n_features, n_components).
        MI values should be non-negative.
    feature_names : list of str
        Names of features for y-axis labels; length must equal the number of
        rows of every MI matrix.
    methods : list of str, optional
        List of methods to plot (if None, uses all keys in mi_matrices)
    n_components : int, optional
        Number of components to show (default: min of 5 and the number
        available). Must be >= 1 when given.
    metadata : dict, optional
        Dictionary of metadata for each method. Currently only
        ``metadata["pca"]["explained_variance_ratio"]`` (case-insensitive
        method name) is used, to print explained variance under the PCA panel.
    compute_metrics : bool, default True
        Whether to show additional metrics (e.g., explained variance)
    figsize : tuple, optional
        Figure size
    save_path : str, optional
        Path to save the figure
    dpi : int, default DEFAULT_DPI
        DPI resolution for saved figure

    Returns
    -------
    fig : matplotlib.figure.Figure
        The generated figure

    Raises
    ------
    ValueError
        If MI matrices are not 2D, their row count does not match
        ``feature_names``, they contain negative values, or ``n_components``
        is less than 1
    """
    if n_components is not None and n_components < 1:
        raise ValueError(f"n_components must be a positive integer, got {n_components}")

    # Validate MI matrices
    for method, mi_matrix in mi_matrices.items():
        if mi_matrix.ndim != 2:
            raise ValueError(f"MI matrix for {method} must be 2D, got shape {mi_matrix.shape}")
        if len(feature_names) != mi_matrix.shape[0]:
            raise ValueError(
                f"MI matrix for {method} has {mi_matrix.shape[0]} features but "
                f"feature_names has {len(feature_names)} names"
            )
        if np.any(mi_matrix < 0):
            raise ValueError(f"MI matrix for {method} contains negative values")

    if methods is None:
        methods = list(mi_matrices.keys())

    n_methods = len(methods)

    if figsize is None:
        figsize = (8 * n_methods, 6)

    fig = plt.figure(figsize=figsize)

    for idx, method in enumerate(methods):
        if method not in mi_matrices:
            continue

        mi_matrix = mi_matrices[method]

        # Determine number of components to show
        limit = 5 if n_components is None else n_components
        n_comp_show = min(limit, mi_matrix.shape[1])

        ax = fig.add_subplot(1, n_methods, idx + 1)

        mi_subset = mi_matrix[:, :n_comp_show]
        # Fall back to vmax=1 for all-zero (or empty) matrices so the color
        # scale is never degenerate.
        peak = np.max(mi_subset) if mi_subset.size else 0
        max_mi = peak if peak > 0 else 1
        im = ax.imshow(mi_subset, aspect="auto", cmap="YlOrRd", vmin=0, vmax=max_mi)

        # Component labels: "PC<i>" for PCA, otherwise "<METHOD><i>"
        prefix = "PC" if method.lower() == "pca" else method.upper()
        ax.set_xticks(range(n_comp_show))
        ax.set_xticklabels([f"{prefix}{c}" for c in range(n_comp_show)])
        ax.set_yticks(range(len(feature_names)))
        ax.set_yticklabels(feature_names)
        ax.set_xlabel(f"{method.upper()} Components")
        ax.set_title(f"{method.upper()} Component-Feature MI")

        # Print MI values; switch to white text on the darker half of the scale
        for i in range(len(feature_names)):
            for j in range(n_comp_show):
                text_color = "black" if mi_subset[i, j] < max_mi * 0.5 else "white"
                ax.text(
                    j,
                    i,
                    f"{mi_subset[i, j]:.3f}",
                    ha="center",
                    va="center",
                    color=text_color,
                    fontsize=9,
                )

        plt.colorbar(im, ax=ax, label="Mean MI (bits)")

        # Add method-specific metrics if available
        if compute_metrics and metadata is not None and method in metadata:
            method_meta = metadata[method]

            # For PCA, show explained variance below the heatmap
            if method.lower() == "pca" and "explained_variance_ratio" in method_meta:
                var_exp = method_meta["explained_variance_ratio"][:n_comp_show]
                var_text = "Var explained: " + ", ".join([f"{v*100:.1f}%" for v in var_exp])
                ax.text(
                    0.5,
                    -0.15,
                    var_text,
                    transform=ax.transAxes,
                    ha="center",
                    va="top",
                    fontsize=8,
                    style="italic",
                )

    fig.suptitle(
        "Component Interpretation: Mutual Information between Components and Features",
        fontsize=16,
    )
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches="tight")

    return fig
def plot_embeddings_grid(
    embeddings: Dict[str, Dict[str, np.ndarray]],
    labels: Optional[Union[np.ndarray, Dict[str, np.ndarray]]] = None,
    methods: Optional[List[str]] = None,
    scenarios: Optional[List[str]] = None,
    metrics: Optional[Dict[str, Dict[str, Dict[str, float]]]] = None,
    colormap: str = "viridis",
    figsize: Optional[Tuple[float, float]] = None,
    n_cols: int = 4,
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI,
) -> Optional[plt.Figure]:
    """
    Create grid of embeddings for multiple methods and scenarios.

    Parameters
    ----------
    embeddings : dict of dict
        Nested dictionary: {method: {scenario: embedding_array}}. Arrays must be
        2D with at least 2 components; ``None`` entries are skipped.
    labels : array or dict, optional
        Color labels for points. Either one array used for every embedding, or
        a nested dict ``{method: {scenario: array}}``. Embeddings without labels
        are colored by sample index.
    methods : list, optional
        Methods to plot (default: all in embeddings)
    scenarios : list, optional
        Scenarios to plot (default: all available)
    metrics : dict, optional
        Nested dict of metrics: {method: {scenario: {metric_name: value}}}.
        Only floating-point values (Python or NumPy floats) are shown, at most
        2 per subplot.
    colormap : str
        Colormap for scatter plots
    figsize : tuple, optional
        Figure size
    n_cols : int
        Number of columns in grid; must be >= 1
    save_path : str, optional
        Path to save figure
    dpi : int, default DEFAULT_DPI
        DPI resolution for saved figure

    Returns
    -------
    fig : matplotlib.figure.Figure or None
        The generated figure, or None if no valid embeddings to plot

    Raises
    ------
    ValueError
        If embeddings are not 2D, label lengths mismatch, or ``n_cols < 1``
    """
    if methods is None:
        methods = list(embeddings.keys())

    # Collect (method, scenario) pairs that have an embedding to draw
    all_plots = []
    for method in methods:
        if method not in embeddings:
            continue
        method_embeddings = embeddings[method]

        if scenarios is None:
            method_scenarios = list(method_embeddings.keys())
        else:
            method_scenarios = [s for s in scenarios if s in method_embeddings]

        for scenario in method_scenarios:
            embedding = method_embeddings[scenario]
            if embedding is None:
                continue
            if embedding.ndim != 2 or embedding.shape[1] < 2:
                raise ValueError(
                    f"Embedding for {method}/{scenario} must be 2D with "
                    f"at least 2 components, got shape {embedding.shape}"
                )
            all_plots.append((method, scenario))

    if not all_plots:
        return None

    if n_cols < 1:
        raise ValueError(f"n_cols must be a positive integer, got {n_cols}")

    # Calculate grid dimensions
    n_plots = len(all_plots)
    n_rows = (n_plots + n_cols - 1) // n_cols

    if figsize is None:
        figsize = (4 * n_cols, 4 * n_rows)

    # squeeze=False always returns a 2D array of Axes, including the 1x1 case
    # where the previous reshape-based handling failed on a bare Axes object.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)

    for idx, (method, scenario) in enumerate(all_plots):
        ax = axes[idx // n_cols, idx % n_cols]
        embedding = embeddings[method][scenario]

        # Resolve labels used for coloring
        if labels is None:
            color_labels = np.arange(len(embedding))
        elif isinstance(labels, dict):
            if method in labels and scenario in labels[method]:
                color_labels = labels[method][scenario]
                if len(color_labels) != len(embedding):
                    raise ValueError(
                        f"Labels for {method}/{scenario} have length {len(color_labels)} "
                        f"but embedding has {len(embedding)} samples"
                    )
            else:
                color_labels = np.arange(len(embedding))
        else:
            color_labels = labels
            if len(color_labels) != len(embedding):
                raise ValueError(
                    f"Labels have length {len(color_labels)} but embedding "
                    f"for {method}/{scenario} has {len(embedding)} samples"
                )

        ax.scatter(
            embedding[:, 0],
            embedding[:, 1],
            c=color_labels,
            cmap=colormap,
            s=10,
            alpha=0.7,
            edgecolors="none",
        )

        # Add title with metrics if available
        title = f"{method} - {scenario}"
        if metrics and method in metrics and scenario in metrics[method]:
            metric_strs = [
                f"{metric_name}: {value:.3f}"
                for metric_name, value in metrics[method][scenario].items()
                if isinstance(value, (float, np.floating))
            ]
            if metric_strs:
                title += "\n" + ", ".join(metric_strs[:2])  # Show max 2 metrics

        ax.set_title(title, fontsize=10)
        ax.set_xlabel("Component 0")
        ax.set_ylabel("Component 1")
        ax.grid(True, alpha=0.3)

    # Hide unused subplots in the last row
    for unused_ax in axes.ravel()[n_plots:]:
        unused_ax.set_visible(False)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches="tight")

    return fig
# Note: plot_quality_metrics_comparison and plot_quality_vs_speed_tradeoff were removed because they were too specific to individual examples and were not reused elsewhere.
def plot_neuron_selectivity_summary(
    selectivity_counts: Dict[str, int],
    total_neurons: int,
    colors: Optional[Dict[str, str]] = None,
    figsize: Tuple[float, float] = (8, 6),
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI,
) -> plt.Figure:
    """
    Create bar plot summarizing neuron selectivity categories.

    Each bar is labeled with its count as a percentage of ``total_neurons``.

    Parameters
    ----------
    selectivity_counts : dict
        Dictionary mapping category names to counts. Counts should be
        non-negative integers with sum <= total_neurons. May be empty.
    total_neurons : int
        Total number of neurons. Must be positive.
    colors : dict, optional
        Dictionary mapping category names to colors. Categories without an
        entry are drawn in 'steelblue'. If None, a built-in palette for common
        category names is used.
    figsize : tuple
        Figure size
    save_path : str, optional
        Path to save figure
    dpi : int, default DEFAULT_DPI
        DPI resolution for saved figure

    Returns
    -------
    fig : matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If total_neurons <= 0 or counts are invalid
    """
    check_positive(total_neurons=total_neurons)

    for category, count in selectivity_counts.items():
        if not isinstance(count, (int, np.integer)):
            raise ValueError(
                f"Count for category '{category}' must be an integer, got {type(count).__name__}"
            )

    check_nonnegative(**selectivity_counts)

    total_count = sum(selectivity_counts.values())
    if total_count > total_neurons:
        raise ValueError(f"Sum of counts ({total_count}) exceeds total_neurons ({total_neurons})")

    if colors is None:
        # Default colors for common categories
        colors = {
            "Spatial": "darkgreen",
            "spatial": "darkgreen",
            "position_2d": "darkgreen",
            "x_position": "green",
            "y_position": "lightgreen",
            "head_direction": "blue",
            "speed": "orange",
            "task_type": "red",
            "reward": "purple",
            "Non-spatial": "gray",
            "non_spatial": "gray",
            "Non-selective": "lightgray",
        }

    fig, ax = plt.subplots(figsize=figsize)

    categories = list(selectivity_counts.keys())
    counts = list(selectivity_counts.values())
    bar_colors = [colors.get(cat, "steelblue") for cat in categories]

    bars = ax.bar(categories, counts, color=bar_colors, alpha=0.7)

    # Percentage labels sit a fixed number of points above each bar, so they
    # stay inside the headroom regardless of the data scale (a +1 data-unit
    # offset pushed them out of the axes for small counts).
    for bar, count in zip(bars, counts):
        ax.annotate(
            f"{count / total_neurons * 100:.1f}%",
            xy=(bar.get_x() + bar.get_width() / 2.0, bar.get_height()),
            xytext=(0, 2),
            textcoords="offset points",
            ha="center",
            va="bottom",
        )

    ax.set_ylabel("Number of neurons")
    ax.set_title("Neuron Selectivity Categories")
    # 15% headroom for labels; guard against empty input / all-zero counts,
    # which would otherwise give max() of nothing or identical y-limits.
    tallest = max(counts, default=0)
    ax.set_ylim(0, max(tallest * 1.15, 1))

    ax.text(
        0.02,
        0.98,
        f"Total neurons: {total_neurons}",
        transform=ax.transAxes,
        ha="left",
        va="top",
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
    )

    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches="tight")

    return fig
def plot_component_selectivity_heatmap(
    selectivity_matrix: np.ndarray,
    methods: List[str],
    n_components_per_method: Optional[Dict[str, int]] = None,
    figsize: Optional[Tuple[float, float]] = None,
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI,
) -> plt.Figure:
    """
    Create heatmap showing neuron selectivity to embedding components.

    Columns of ``selectivity_matrix`` are assigned to methods consecutively in
    the order given by ``methods``; each method gets one panel with components
    on the y-axis and neurons on the x-axis.

    Parameters
    ----------
    selectivity_matrix : ndarray
        Matrix of shape (n_neurons, total_components) with MI values.
        Must be 2D with finite, non-negative values.
    methods : list of str
        List of DR method names. Cannot be empty.
    n_components_per_method : dict, optional
        Number of components for each method. Every entry in ``methods`` must
        be present with a positive count, and the counts for ``methods`` must
        sum to the number of matrix columns (extra keys are ignored). If None,
        the columns are split equally across methods.
    figsize : tuple, optional
        Figure size (width, height)
    save_path : str, optional
        Path to save figure
    dpi : int, default DEFAULT_DPI
        DPI resolution for saved figure

    Returns
    -------
    fig : matplotlib.figure.Figure
        The generated figure

    Raises
    ------
    ValueError
        If selectivity_matrix is not 2D, contains negative, NaN or infinite
        values, methods list is empty, or component counts don't match matrix
    """
    if selectivity_matrix.ndim != 2:
        raise ValueError(f"selectivity_matrix must be 2D, got shape {selectivity_matrix.shape}")

    if len(methods) == 0:
        raise ValueError("methods list cannot be empty")

    if np.any(selectivity_matrix < 0):
        raise ValueError("selectivity_matrix cannot contain negative values")

    if np.any(~np.isfinite(selectivity_matrix)):
        raise ValueError("selectivity_matrix contains NaN or infinite values")

    total_components = selectivity_matrix.shape[1]
    n_methods = len(methods)
    user_counts = n_components_per_method is not None

    if not user_counts:
        # Assume equal components per method
        if total_components % n_methods != 0:
            raise ValueError(
                f"Total components ({total_components}) not evenly divisible by number of methods ({n_methods})"
            )
        n_comp_each = total_components // n_methods
        n_components_per_method = {m: n_comp_each for m in methods}
    else:
        # Report missing methods before the sum check so the error is specific
        for method in methods:
            if method not in n_components_per_method:
                raise ValueError(f"Missing component count for method '{method}'")

    for method in methods:
        check_positive(**{f"{method}_components": n_components_per_method[method]})

    if user_counts:
        # Only the plotted methods consume matrix columns
        total_specified = sum(n_components_per_method[m] for m in methods)
        if total_specified != total_components:
            raise ValueError(
                f"Sum of components ({total_specified}) doesn't match matrix columns ({total_components})"
            )

    if figsize is None:
        figsize = (5 * n_methods, 8)

    # squeeze=False keeps a 2D array even for a single method
    fig, axes = plt.subplots(1, n_methods, figsize=figsize, squeeze=False)

    comp_start = 0
    for ax, method in zip(axes[0], methods):
        n_comp = n_components_per_method[method]
        method_matrix = selectivity_matrix[:, comp_start : comp_start + n_comp]

        # Transpose so components run along the y-axis
        im = ax.imshow(method_matrix.T, aspect="auto", cmap="hot", interpolation="nearest")

        ax.set_xlabel("Neuron ID")
        ax.set_ylabel("Component")
        ax.set_title(f"{method.upper()} Component Selectivity")

        cbar = plt.colorbar(im, ax=ax)
        cbar.set_label("Mutual Information (bits)")

        ax.set_yticks(range(n_comp))
        ax.set_yticklabels([f"C{i}" for i in range(n_comp)])

        comp_start += n_comp

    fig.suptitle("Neuron Selectivity to Embedding Components", fontsize=14)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches="tight")

    return fig
def compute_circular_coordinates(embedding: np.ndarray) -> np.ndarray:
    """
    Map each embedded point to its polar angle around the embedding centroid.

    The embedding is centered on its per-column mean, and the angle of the
    first two (centered) components is measured with ``arctan2`` and wrapped
    into the non-negative range.

    Parameters
    ----------
    embedding : ndarray
        2D embedding array of shape (n_samples, 2). Only the first two columns
        enter the angle computation.

    Returns
    -------
    angles : ndarray
        Angles in radians [0, 2pi] of shape (n_samples,)
    """
    center = np.mean(embedding, axis=0)
    offsets = embedding - center
    theta = np.arctan2(offsets[:, 1], offsets[:, 0])
    # arctan2 yields (-pi, pi]; shift negative angles up by a full turn
    return np.remainder(theta, 2.0 * np.pi)
def visualize_circular_manifold(
    embeddings: List[np.ndarray],
    true_angles: np.ndarray,
    method_names: List[str],
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI,
) -> plt.Figure:
    """
    Visualize circular manifold extraction from different DR methods.

    Creates a figure with two rows:
    - Top row: 2D embeddings colored by true head direction
    - Bottom row: True vs reconstructed angle scatter plots

    The reconstructed angle of each sample is its polar angle around the
    embedding centroid (``compute_circular_coordinates``), negated if the
    alignment reports a reflection, then shifted by the reported rotation
    offset. Points within 0.5 rad of the 0 / 2*pi seam are drawn a second time,
    shifted by 2*pi, so samples that wrap across the seam still appear next
    to the diagonal.

    Parameters
    ----------
    embeddings : list of ndarray
        List of 2D embedding arrays, each of shape (n_samples, 2)
    true_angles : ndarray
        Ground truth angles in radians, shape (n_samples,)
    method_names : list of str
        Names of DR methods for plot titles; must have the same length as
        ``embeddings``.
    save_path : str, optional
        Path to save figure
    dpi : int, default DEFAULT_DPI
        DPI resolution for saved figure

    Returns
    -------
    fig : matplotlib.figure.Figure
        The generated figure

    Raises
    ------
    ValueError
        If ``embeddings`` is empty or its length differs from ``method_names``.
    """
    from driada.dim_reduction.manifold_metrics import compute_embedding_alignment_metrics

    n_methods = len(embeddings)
    if n_methods == 0:
        raise ValueError("embeddings list cannot be empty")
    if len(method_names) != n_methods:
        # zip() would otherwise silently drop the extra entries
        raise ValueError(f"Got {n_methods} embeddings but {len(method_names)} method names")

    two_pi = 2 * np.pi

    # Ground-truth wrapping and seam masks are identical for every method
    true_wrapped = np.mod(true_angles, two_pi)
    threshold = 0.5  # rad; band next to the 0 / 2*pi seam that gets duplicated
    near_zero_true = true_wrapped < threshold
    near_2pi_true = true_wrapped > (two_pi - threshold)

    # squeeze=False keeps a (2, n_methods) array even for a single method
    fig, axes = plt.subplots(2, n_methods, figsize=(5 * n_methods, 10), squeeze=False)

    for i, (embedding, method) in enumerate(zip(embeddings, method_names)):
        # Top row: 2D embedding colored by true angle
        ax = axes[0, i]
        scatter = ax.scatter(
            embedding[:, 0], embedding[:, 1], c=true_angles, cmap="hsv", s=20, alpha=0.7
        )
        ax.set_title(f"{method} Embedding")
        ax.set_xlabel("Component 1")
        ax.set_ylabel("Component 2")
        if i == n_methods - 1:
            cbar = plt.colorbar(scatter, ax=ax)
            cbar.set_label("True head direction (rad)")

        # Optimal alignment between embedding angles and ground truth
        alignment_metrics = compute_embedding_alignment_metrics(
            embedding, true_angles, "circular"
        )
        error = alignment_metrics["error"]
        correlation = alignment_metrics["correlation"]
        rotation_offset = alignment_metrics["rotation_offset"]
        is_reflected = alignment_metrics["is_reflected"]

        recon_angles = compute_circular_coordinates(embedding)
        if is_reflected:
            recon_angles = -recon_angles
        recon_angles = recon_angles + rotation_offset

        # Bottom row: true vs reconstructed angles
        ax = axes[1, i]
        recon_wrapped = np.mod(recon_angles, two_pi)
        near_zero_recon = recon_wrapped < threshold
        near_2pi_recon = recon_wrapped > (two_pi - threshold)

        ax.scatter(true_wrapped, recon_wrapped, alpha=0.5, s=10, color="blue")

        # Wrapped copies for continuity across the seam
        for mask, t_offset, r_offset in [
            (near_zero_true & near_2pi_recon, 0, -two_pi),
            (near_2pi_true & near_zero_recon, 0, two_pi),
            (near_zero_true & near_zero_recon, two_pi, two_pi),
            (near_2pi_true & near_2pi_recon, -two_pi, -two_pi),
        ]:
            if np.any(mask):
                ax.scatter(
                    true_wrapped[mask] + t_offset,
                    recon_wrapped[mask] + r_offset,
                    alpha=0.5,
                    s=10,
                    color="blue",
                )

        ax.plot([0, two_pi], [0, two_pi], "r--", alpha=0.5)
        ax.set_xlabel("True angle (rad)")
        ax.set_ylabel("Reconstructed angle (rad)")
        ax.set_title(f"r = {correlation:.3f}, error = {error:.3f} rad")
        ax.set_xlim([-0.5, two_pi + 0.5])
        ax.set_ylim([-0.5, two_pi + 0.5])

        for val in [0, two_pi]:
            ax.axvline(val, color="gray", alpha=0.3, linestyle=":")
            ax.axhline(val, color="gray", alpha=0.3, linestyle=":")

    # Act on this figure explicitly; the external metrics call above could
    # otherwise change what pyplot considers the current figure.
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches="tight")

    return fig