Source code for driada.intense.visual

import numpy as np

import matplotlib.pyplot as plt
from ..utils.plot import create_default_figure, make_beautiful
from ..utils.data import rescale
from scipy.stats import rankdata, gaussian_kde, wasserstein_distance
import seaborn as sns


[docs] def plot_pc_activity( exp, cell_ind, place_key=("x", "y"), ds=5, ax=None, show_trajectory=False, show_spikes=True, cmap="plasma", marker_size=100, marker_style="*", marker_color="k", scatter_alpha=0.8, trajectory_alpha=0.3, trajectory_color="gray", figsize_base=6, title_format="Cell {cell_ind}, Rel MI={rel_mi:.4f}, pval={pval:.2e}", xlabel=None, ylabel=None, show_stats=True, ): """ Plot place cell activity overlaid on spatial trajectory. Parameters ---------- exp : Experiment Experiment object with spatial data and neurons. cell_ind : int Index of the neuron to plot. place_key : tuple or str, optional Feature key for spatial data. Default: ("x", "y"). Can be tuple like ("x", "y") or string like "position". ds : int, optional Downsampling factor. Default: 5. ax : matplotlib.axes.Axes, optional Axes to plot on. If None, creates new figure. show_trajectory : bool, optional Whether to show trajectory line. Default: False. show_spikes : bool, optional Whether to show spike markers. Default: True. cmap : str, optional Colormap for activity. Default: "plasma". marker_size : int, optional Size of spike markers. Default: 100. marker_style : str, optional Marker style for spikes. Default: "*". marker_color : str, optional Color for spike markers. Default: "k". scatter_alpha : float, optional Alpha for activity scatter. Default: 0.8. trajectory_alpha : float, optional Alpha for trajectory line. Default: 0.3. trajectory_color : str, optional Color for trajectory. Default: "gray". figsize_base : float, optional Base figure size (adjusted by aspect ratio). Default: 6. title_format : str, optional Format string for title. Available keys: cell_ind, rel_mi, pval. Default: "Cell {cell_ind}, Rel MI={rel_mi:.4f}, pval={pval:.2e}" xlabel : str, optional X-axis label. Default: first element of place_key or "x". ylabel : str, optional Y-axis label. Default: second element of place_key or "y". show_stats : bool, optional Whether to show statistics in title. Default: True. Returns ------- ax : matplotlib.axes.Axes Axes with the plot. Raises ------ KeyError If cell_ind or place_key not found in stats_table IndexError If cell_ind >= number of neurons ValueError If place data is not 2D AttributeError If required attributes missing from experiment Notes ----- - Uses log-transformed calcium data for color mapping - Figure aspect ratio automatically adjusted based on spatial extent - Stats (MI and p-value) retrieved from experiment's stats_table - Supports both tuple place keys like ("x", "y") and single feature names Examples -------- >>> import matplotlib.pyplot as plt >>> from driada.experiment import load_demo_experiment >>> >>> # Basic place cell plot with position data >>> exp = load_demo_experiment(verbose=False) >>> # Use x_pos and y_pos which are available in demo data >>> # Note: Demo data doesn't have spike data, so we disable spike display >>> ax = plot_pc_activity(exp, cell_ind=5, place_key=("x_pos", "y_pos"), ... show_spikes=False, show_stats=False) >>> plt.close() # Suppress display >>> >>> # Using separate x,y features with custom styling >>> ax = plot_pc_activity(exp, cell_ind=10, place_key=("x_pos", "y_pos"), ... show_trajectory=True, cmap='viridis', ... marker_color='red', show_stats=False, ... show_spikes=False, ds=20) >>> plt.close() >>> >>> # Custom styling example >>> ax = plot_pc_activity(exp, cell_ind=3, place_key=("x_pos", "y_pos"), ... ds=10, show_stats=False, show_spikes=False) >>> plt.close() """ # Validate inputs if cell_ind < 0 or cell_ind >= exp.n_cells: raise IndexError(f"cell_ind {cell_ind} out of range [0, {exp.n_cells})") # Get spatial data if isinstance(place_key, tuple) and len(place_key) == 2: x_key, y_key = place_key x_data = getattr(exp, x_key).data y_data = getattr(exp, y_key).data default_xlabel = x_key default_ylabel = y_key else: # Single feature - try to get x,y components place_feat = getattr(exp, place_key) if hasattr(place_feat, "data") and place_feat.data.ndim == 2: x_data = place_feat.data[:, 0] y_data = place_feat.data[:, 1] default_xlabel = f"{place_key}_x" default_ylabel = f"{place_key}_y" else: raise ValueError(f"Place feature {place_key} must be 2D or provide tuple of features") # Get statistics if available if show_stats and hasattr(exp, "stats_table") and place_key in exp.stats_table: try: pc_stats = exp.stats_table[place_key][cell_ind] pval = pc_stats.get("pval", None) rel_mi = pc_stats.get("rel_me_beh", pc_stats.get("rel_mi_beh", None)) except (KeyError, IndexError): pval = None rel_mi = None show_stats = False else: show_stats = False # Create figure if needed if ax is None: lenx = max(x_data) - min(x_data) leny = max(y_data) - min(y_data) xyratio = max(lenx / leny, leny / lenx) fig, ax = create_default_figure(figsize=(figsize_base * xyratio, figsize_base)) # Get neural activity neur = rescale(np.log(exp.neurons[cell_ind].ca.data + 1e-10)) # Plot trajectory if requested if show_trajectory: ax.plot(x_data[::ds], y_data[::ds], c=trajectory_color, alpha=trajectory_alpha, zorder=1) # Plot activity ax.scatter(x_data[::ds], y_data[::ds], c=neur[::ds], cmap=cmap, alpha=scatter_alpha, zorder=2) # Plot spikes if requested if show_spikes and hasattr(exp.neurons[cell_ind], "sp"): spinds = np.where(exp.neurons[cell_ind].sp.data != 0)[0] if len(spinds) > 0: ax.scatter( x_data[spinds], y_data[spinds], c=marker_color, alpha=1, marker=marker_style, linewidth=2, s=marker_size, zorder=3, ) # Labels ax.set_xlabel(xlabel or default_xlabel) ax.set_ylabel(ylabel or default_ylabel) # Title if show_stats and pval is not None and rel_mi is not None: title = title_format.format(cell_ind=cell_ind, rel_mi=rel_mi, pval=pval) else: title = f"Cell {cell_ind}" ax.set_title(title) return ax
[docs] def plot_neuron_feature_density( exp, data_type, cell_id, featname, ind1=0, ind2=100000, ds=1, shift=None, ax=None, compute_wsd=False, ): """ Plot density distribution of neural activity conditioned on feature values. Parameters ---------- exp : Experiment Experiment object containing neurons and features. data_type : str Type of neural data: 'calcium' or 'spikes'. cell_id : int Index of the neuron. featname : str Name of the behavioral feature. ind1 : int, optional Start frame index. Default: 0. ind2 : int, optional End frame index. Default: 100000. ds : int, optional Downsampling factor. Default: 1. shift : int, optional Temporal shift in frames. Currently not implemented. Default: None. ax : matplotlib.axes.Axes, optional Axes to plot on. If None, creates new figure. compute_wsd : bool, optional Whether to compute Wasserstein distance for binary features. Default: False. Returns ------- ax : matplotlib.axes.Axes Axes with the plot. Raises ------ NotImplementedError If data_type='spikes' with binary feature IndexError If cell_id >= number of neurons AttributeError If feature not found in experiment Notes ----- - Binary features: Uses KDE with bw_adjust=0.5, log10 transform - Continuous features: Adds 1e-8 noise, uses 100x100 grid - Uses .scdata attribute for scaled data access - shift parameter is accepted but not used Examples -------- >>> import matplotlib.pyplot as plt >>> from driada.experiment import load_demo_experiment >>> >>> # Plot calcium vs feature density >>> exp = load_demo_experiment(verbose=False) >>> # Using 'speed' feature which is available in demo data >>> ax = plot_neuron_feature_density(exp, 'calcium', 5, 'speed') >>> plt.close() # Suppress display >>> >>> # For binary features (if available in your data) >>> # ax = plot_neuron_feature_density(exp, 'calcium', 10, 'licking', >>> # compute_wsd=True) """ ind2 = min(exp.n_frames, ind2) if data_type == "calcium": sig = exp.neurons[cell_id].ca.scdata[ind1:ind2][::ds] if data_type == "spikes": sig = exp.neurons[cell_id].sp.scdata[ind1:ind2][::ds] feature = getattr(exp, featname) bdata = feature.scdata[ind1:ind2][::ds] rbdata = rescale(rankdata(bdata)) if ax is None: fig, ax = plt.subplots(figsize=(6, 6)) if feature.is_binary: if data_type == "calcium": vals0 = np.log10(sig[np.where((rbdata == min(rbdata)) & (sig > 0))]) vals1 = np.log10(sig[np.where((rbdata == max(rbdata)) & (sig > 0))]) if compute_wsd and len(vals0) > 0 and len(vals1) > 0: wsd = wasserstein_distance(vals0, vals1) title_text = f"wsd={wsd:.3f}" else: title_text = "" _ = sns.kdeplot(vals0, ax=ax, c="b", label=f"{featname}=0", linewidth=3, bw_adjust=0.5) _ = sns.kdeplot(vals1, ax=ax, c="r", label=f"{featname}=1", linewidth=3, bw_adjust=0.5) ax.legend(loc="upper right") ax.set_xlabel("log(dF/F)", fontsize=20) ax.set_ylabel("density", fontsize=20) if title_text: ax.set_title(title_text) if data_type == "spikes": raise NotImplementedError( "Binary feature density plot for spike data not yet implemented" ) else: x0, y0 = np.log10(sig + np.random.random(size=len(sig)) * 1e-8), np.log( bdata + np.random.random(size=len(bdata)) * 1e-8 ) jdata = np.vstack([x0, y0]).T # jplot = sns.jointplot(jdata, x=jdata[:,0], y=jdata[:,1], kind='hist', bins=100) nbins = 100 k = gaussian_kde(jdata.T) xi, yi = np.mgrid[x0.min() : x0.max() : nbins * 1j, y0.min() : y0.max() : nbins * 1j] zi = k(np.vstack([xi.flatten(), yi.flatten()])) # plot a density ax.set_title("Density") ax.pcolormesh(xi, yi, zi.reshape(xi.shape), shading="auto", cmap="coolwarm") ax.set_xlabel("log(signals)", fontsize=20) ax.set_ylabel(f"log({featname})", fontsize=20) return ax
[docs] def plot_shadowed_groups( ax, xvals, binary_series, color="gray", alpha=0.3, label="shadowed", ymin=0.0, ymax=1.0, ): """ Shade regions where binary series equals 1. Parameters ---------- ax : matplotlib.axes.Axes Axes to plot on. xvals : array-like X-axis values corresponding to binary_series. binary_series : array-like Binary array (0s and 1s) indicating regions to shade. color : str, optional Color for shaded regions. Default: 'gray'. alpha : float, optional Transparency of shaded regions. Default: 0.3. label : str, optional Label for legend. Default: 'shadowed'. ymin : float, optional Lower y extent in axes coordinates (0–1). Default: 0.0. ymax : float, optional Upper y extent in axes coordinates (0–1). Default: 1.0. Returns ------- ax : matplotlib.axes.Axes Modified axes object. """ x = np.arange(len(binary_series)) # Find and shadow groups of 1s i = 0 n = len(binary_series) labelled = False while i < n: if binary_series[i] == 1: x_min = xvals[i] # Start of group while i < n and binary_series[i] == 1: i += 1 x_max = xvals[i - 1] if i > 0 else xvals[0] # End of group # Shadow the region ax.axvspan( x_min, x_max + 1, ymin=ymin, ymax=ymax, alpha=alpha, color=color, label=label if not labelled else "", ) labelled = True else: i += 1 return ax
[docs] def plot_neuron_feature_pair( exp, cell_id, featname, ind1=0, ind2=100000, ds=1, add_density_plot=True, ax=None, axs=None, style=None, panel_size=None, skip_tight_layout=False, title=None, bcolor="g", neuron_label=None, feature_label=None, non_feature_label=None, ): """ Plot neural activity alongside behavioral feature with density analysis. Parameters ---------- exp : Experiment Experiment object containing neurons and features. cell_id : int Index of the neuron. featname : str Name of the behavioral feature. ind1 : int, optional Start frame index. Default: 0. ind2 : int, optional End frame index. Default: 100000. ds : int, optional Downsampling factor. Default: 1. add_density_plot : bool, optional Whether to add density subplot. Default: True. ax : matplotlib.axes.Axes, optional Single axis to plot on. Forces add_density_plot=False when provided. axs : tuple of matplotlib.axes.Axes, optional Tuple of (ax_timeseries, ax_density) for external 2-panel layout. Enables density plot with external axes. style : StylePreset, optional Style preset from publication framework. If None, uses StylePreset.from_make_beautiful() for backward compatibility. panel_size : tuple of float, optional (width, height) in cm for style scaling. Default: (20, 6). skip_tight_layout : bool, optional If True, skip plt.tight_layout() call. Useful for subfigure layouts. Default: False. title : str, optional Custom title for the plot. bcolor : str, optional Color for feature visualization. Default: 'g'. neuron_label : str, optional Custom label for neuron. Default: f'neuron {cell_id}'. feature_label : str, optional Custom label for feature. Default: featname. non_feature_label : str, optional Custom label for non-feature state. Default: f'non-{featname}'. Returns ------- fig : matplotlib.figure.Figure Figure containing the plot(s). Raises ------ IndexError If cell_id >= number of neurons AttributeError If featname not found in experiment Notes ----- - Discrete features shown as shaded regions where active - Uses StylePreset for axis styling with legends below - Y-axis formatted to 1 decimal place - Dark gray for non-feature distribution in density plot - Calls plt.tight_layout() with bottom adjustment for legends (unless skip_tight_layout=True) Examples -------- >>> import matplotlib.pyplot as plt >>> from driada.experiment import load_demo_experiment >>> >>> # Basic time series plot >>> exp = load_demo_experiment(verbose=False) >>> fig = plot_neuron_feature_pair(exp, 5, 'speed') >>> plt.close(fig) # Suppress display >>> >>> # With custom labels >>> fig = plot_neuron_feature_pair(exp, 10, 'object1', ... feature_label='Object Interaction', ... non_feature_label='No Object') >>> plt.close(fig) """ from ..utils.publication import StylePreset # Default to standalone preset if not provided if style is None: style = StylePreset.from_make_beautiful() # Default panel size for standalone if panel_size is None: panel_size = (20, 6) # Get scaled style values line_width = style.get_line_width(panel_size, 'cm') kde_line_width = line_width * 1.5 ind2 = min(exp.n_frames, ind2) ca = exp.neurons[cell_id].ca.scdata[ind1:ind2][::ds] rca = rescale(rankdata(ca)) feature = getattr(exp, featname) if feature.data.ndim > 1: raise ValueError( f"plot_neuron_feature_pair requires a 1D feature, but " f"'{featname}' has shape {feature.data.shape}. " f"Use the original feature name (e.g., 'head_direction' " f"instead of 'head_direction_2d')." ) bdata = feature.scdata[ind1:ind2][::ds] rbdata = rescale(rankdata(bdata)) # Set default labels if not provided if neuron_label is None: neuron_label = f"neuron {cell_id}" if feature_label is None: feature_label = featname if non_feature_label is None: non_feature_label = f"non-{featname}" # Handle axes: axs tuple > ax single > create new if axs is not None: # External 2-panel axes provided ax0, ax1 = axs fig = ax0.figure elif ax is not None: # Single axis provided - no density plot ax0 = ax ax1 = None add_density_plot = False fig = ax0.figure else: # Create new figure if add_density_plot: fig, axes_tuple = plt.subplots(1, 2, figsize=(20, 6), width_ratios=[0.7, 0.3], dpi=300) ax0, ax1 = axes_tuple else: fig, ax0 = plt.subplots(figsize=(20, 6), dpi=300) ax1 = None xvals = np.arange(ind1, ind2)[::ds] / 20.0 ax0.plot(xvals, ca, c="b", linewidth=line_width, alpha=0.6, label=neuron_label) if feature.discrete: ax0 = plot_shadowed_groups( ax0, xvals, feature.scdata[ind1:ind2][::ds], color=bcolor, alpha=0.5, label=feature_label, ) else: ax0.plot(xvals, rbdata, c="r", linewidth=line_width * 0.7, alpha=0.5, label=feature_label) # Set axis labels ax0.set_xlabel("time, s") ax0.set_ylabel("signal") if title is None: title = f"{exp.signature} Neuron {cell_id}, feature {featname}" # Apply framework styling style.apply_to_axes(ax0, panel_size, 'cm') # Legend for time series ax0.legend( fontsize=style.base_legend_fontsize, loc="upper center", bbox_to_anchor=(0.5, -0.25), ncol=1, frameon=style.legend_frameon, ) # Format y-axis tick labels to 1 decimal place with proper rounding ax0.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f"{round(x, 6):.1f}")) if add_density_plot and ax1 is not None: if feature.discrete: vals0 = np.log10(ca[np.where((rbdata == min(rbdata)) & (ca > 0))] + 1e-10) vals1 = np.log10(ca[np.where((rbdata == max(rbdata)) & (ca > 0))] + 1e-10) if len(vals0) > 0 and len(vals1) > 0: wsd = wasserstein_distance(vals0, vals1) _ = sns.kdeplot( vals0, ax=ax1, c="dimgray", label=non_feature_label, linewidth=kde_line_width, bw_adjust=0.1 ) _ = sns.kdeplot( vals1, ax=ax1, c=bcolor, label=feature_label, linewidth=kde_line_width, bw_adjust=0.1 ) ax1.set_xlabel(r"$\log$(signal)") ax1.set_ylabel("density") # Apply framework styling style.apply_to_axes(ax1, panel_size, 'cm') ax1.legend( loc="upper center", bbox_to_anchor=(0.5, -0.25), fontsize=style.base_legend_fontsize, ncol=1, frameon=style.legend_frameon, ) ax1.set_xlim(-4.0, 0.5) # Format y-axis tick labels to 1 decimal place for density plot ax1.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f"{round(x, 6):.1f}")) else: x0 = np.log10(ca + np.random.random(size=len(ca)) * 1e-8) y0 = np.log(bdata + np.random.random(size=len(bdata)) * 1e-8) jdata = np.vstack([x0, y0]).T nbins = 100 k = gaussian_kde(jdata.T) xi, yi = np.mgrid[x0.min() : x0.max() : nbins * 1j, y0.min() : y0.max() : nbins * 1j] zi = k(np.vstack([xi.flatten(), yi.flatten()])) # plot a density ax1.pcolormesh(xi, yi, zi.reshape(xi.shape), shading="auto", cmap="coolwarm") ax1.set_xlabel(r"$\log$(signal)") ax1.set_ylabel(rf"$\log$({featname})") # Apply framework styling style.apply_to_axes(ax1, panel_size, 'cm') # Format y-axis tick labels to 1 decimal place for density plot ax1.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f"{x:.1f}")) if not skip_tight_layout: plt.tight_layout() # Add extra space at bottom for legends plt.subplots_adjust(bottom=0.15) return fig
[docs] def plot_disentanglement_heatmap( disent_matrix, count_matrix, feat_names, title=None, figsize=(12, 10), dpi=100, cmap=None, vmin=0, vmax=100, cbar_label="Disentanglement score (%)", fontsize=14, title_fontsize=18, show_grid=True, grid_alpha=0.3, ): """Plot disentanglement analysis results as a heatmap. Creates a heatmap showing the relative disentanglement scores between feature pairs. Each cell (i,j) shows the percentage of neurons where feature i was primary when paired with feature j. Parameters ---------- disent_matrix : ndarray Disentanglement matrix from disentangle_all_selectivities. count_matrix : ndarray Count matrix from disentangle_all_selectivities. feat_names : list of str Feature names corresponding to matrix indices. title : str, optional Plot title. Default: 'Disentanglement Analysis'. figsize : tuple, optional Figure size (width, height). Default: (12, 10). dpi : int, optional Figure DPI. Default: 100. cmap : str or Colormap, optional Colormap to use. Default: custom red-white-green gradient. vmin : float, optional Minimum value for colormap. Default: 0. vmax : float, optional Maximum value for colormap. Default: 100. cbar_label : str, optional Colorbar label. Default: 'Disentanglement score (%)'. fontsize : int, optional Font size for tick labels. Default: 14. title_fontsize : int, optional Font size for title. Default: 18. show_grid : bool, optional Whether to show grid lines. Default: True. grid_alpha : float, optional Grid transparency. Default: 0.3. Returns ------- fig : matplotlib.figure.Figure Figure containing the heatmap. ax : matplotlib.axes.Axes Axes containing the heatmap. Raises ------ ImportError If seaborn, pandas, or matplotlib.colors not available ValueError If matrix dimensions don't match or feat_names length doesn't match matrices Notes ----- The heatmap uses a diverging colormap where: - Red indicates low disentanglement (feature is redundant) - Gray (0.7, 0.7, 0.7) indicates balanced contribution (~50%) - Green indicates high disentanglement (feature is primary) Cells are masked (shown in white) where no data is available. Uses pandas DataFrame internally for seaborn compatibility. Calls plt.tight_layout() which affects figure state. Examples -------- >>> import numpy as np >>> import matplotlib.pyplot as plt >>> >>> # Create synthetic data for demonstration >>> n_features = 4 >>> features = ['speed', 'position', 'direction', 'licking'] >>> >>> # Create synthetic matrices >>> # disent_matrix[i,j] = how many times feature i was primary vs j >>> disent_mat = np.array([ ... [0, 15, 8, 20], ... [5, 0, 12, 18], ... [12, 8, 0, 10], ... [10, 7, 15, 0] ... ]) >>> >>> # count_matrix[i,j] = total comparisons between features i and j >>> count_mat = np.array([ ... [0, 20, 20, 30], ... [20, 0, 20, 25], ... [20, 20, 0, 25], ... [30, 25, 25, 0] ... ]) >>> >>> # Basic heatmap >>> fig, ax = plot_disentanglement_heatmap(disent_mat, count_mat, features) >>> plt.close(fig) # Suppress display >>> >>> # Custom styling >>> fig, ax = plot_disentanglement_heatmap( ... disent_mat, count_mat, features, ... title="My Analysis", cmap='RdYlGn', ... figsize=(8, 6), dpi=150 ... ) >>> plt.close(fig) """ import seaborn as sns from matplotlib.colors import LinearSegmentedColormap import pandas as pd # Calculate relative disentanglement matrix (as percentage) with np.errstate(divide="ignore", invalid="ignore"): rel_disent_matrix = np.divide(disent_matrix, count_matrix) * 100 rel_disent_matrix[count_matrix == 0] = np.nan # Create default colormap if not provided if cmap is None: # Red -> Gray -> Green gradient # Gray at 50% represents equal selectivity (no disentanglement) colors = [(1, 0, 0), (0.7, 0.7, 0.7), (0, 1, 0)] n_bins = 100 cmap = LinearSegmentedColormap.from_list("disentanglement_cmap", colors, N=n_bins) # Create DataFrame for seaborn df_heatmap = pd.DataFrame(rel_disent_matrix, columns=feat_names, index=feat_names) # Create figure fig, ax = plt.subplots(figsize=figsize, dpi=dpi) # Create heatmap sns.heatmap( df_heatmap, ax=ax, cmap=cmap, vmin=vmin, vmax=vmax, cbar_kws={"label": cbar_label}, mask=np.isnan(rel_disent_matrix), square=True, linewidths=0.5, linecolor="gray", ) # Add grid if requested if show_grid: ax.grid(True, linestyle="-", alpha=grid_alpha, color="black") # Set title if title is None: title = "Disentanglement Analysis" ax.set_title(title, fontsize=title_fontsize, pad=20) # Configure tick labels ax.set_xticks(np.arange(len(feat_names)) + 0.5) ax.set_xticklabels(feat_names, fontsize=fontsize, rotation=45, ha="right") ax.set_yticks(np.arange(len(feat_names)) + 0.5) ax.set_yticklabels(feat_names, fontsize=fontsize, rotation=0) # Set axis labels ax.set_xlabel("Feature (as secondary)", fontsize=fontsize + 2) ax.set_ylabel("Feature (as primary)", fontsize=fontsize + 2) plt.tight_layout() return fig, ax
[docs] def plot_disentanglement_summary( disent_matrix, count_matrix, feat_names, experiments=None, title_prefix="", figsize=(14, 10), dpi=100, ): """Plot comprehensive disentanglement analysis with multiple views. Creates a figure with multiple subplots showing: 1. Disentanglement heatmap 2. Feature dominance scores 3. Pairwise interaction counts Parameters ---------- disent_matrix : ndarray or list of ndarray Disentanglement matrix(es). If list, matrices are summed. count_matrix : ndarray or list of ndarray Count matrix(es). If list, matrices are summed. feat_names : list of str Feature names corresponding to matrix indices. experiments : list of str, optional Experiment names if multiple matrices provided. Currently not used. title_prefix : str, optional Prefix for the main title. figsize : tuple, optional Figure size. Default: (14, 10). dpi : int, optional Figure DPI. Default: 100. Returns ------- fig : matplotlib.figure.Figure Figure containing all subplots. Raises ------ ImportError If matplotlib.colors, seaborn, or pandas not available ValueError If matrix dimensions don't match or feat_names length doesn't match matrices TypeError If disent_matrix/count_matrix not ndarray or list of ndarrays Notes ----- - Creates 2x2 grid layout with custom ratios (3:1 for both dimensions) - Main heatmap uses red-white-green colormap (different from plot_disentanglement_heatmap) - Dominance scores show how often each feature is primary - Only displays feature pairs with non-zero counts - experiments parameter is accepted but not used in current implementation - Calls plt.tight_layout() which affects figure state Examples -------- >>> import numpy as np >>> import matplotlib.pyplot as plt >>> >>> # Create synthetic data >>> features = ['speed', 'position', 'direction', 'licking'] >>> >>> # Synthetic matrices as before >>> disent_mat = np.array([ ... [0, 15, 8, 20], ... [5, 0, 12, 18], ... [12, 8, 0, 10], ... [10, 7, 15, 0] ... ]) >>> count_mat = np.array([ ... [0, 20, 20, 30], ... [20, 0, 20, 25], ... [20, 20, 0, 25], ... [30, 25, 25, 0] ... ]) >>> >>> # Single experiment summary >>> fig = plot_disentanglement_summary(disent_mat, count_mat, features) >>> plt.close(fig) # Suppress display >>> >>> # Multiple experiments (matrices will be summed) >>> disent2 = disent_mat * 0.8 # Second synthetic experiment >>> count2 = count_mat # Same comparison counts >>> fig = plot_disentanglement_summary( ... [disent_mat, disent2], [count_mat, count2], features, ... title_prefix="Combined: " ... ) >>> plt.close(fig) """ # Handle multiple experiments if isinstance(disent_matrix, list): total_disent = np.sum(disent_matrix, axis=0) total_count = np.sum(count_matrix, axis=0) n_exp = len(disent_matrix) else: total_disent = disent_matrix total_count = count_matrix n_exp = 1 # Create figure with subplots fig = plt.figure(figsize=figsize, dpi=dpi) gs = fig.add_gridspec(2, 2, height_ratios=[3, 1], width_ratios=[3, 1]) # Main heatmap ax_main = fig.add_subplot(gs[0, 0]) # Calculate relative disentanglement matrix with np.errstate(divide="ignore", invalid="ignore"): rel_disent_matrix = np.divide(total_disent, total_count) * 100 rel_disent_matrix[total_count == 0] = np.nan # Create colormap from matplotlib.colors import LinearSegmentedColormap import seaborn as sns import pandas as pd colors = [(1, 0, 0), (1, 1, 1), (0, 1, 0)] cmap = LinearSegmentedColormap.from_list("disentanglement_cmap", colors, N=100) # Create DataFrame and plot df_heatmap = pd.DataFrame(rel_disent_matrix, columns=feat_names, index=feat_names) sns.heatmap( df_heatmap, ax=ax_main, cmap=cmap, vmin=0, vmax=100, cbar_kws={"label": "Disentanglement score (%)"}, mask=np.isnan(rel_disent_matrix), square=True, linewidths=0.5, linecolor="gray", ) ax_main.set_title("Disentanglement Heatmap") # Feature dominance scores (how often each feature is primary) ax_dom = fig.add_subplot(gs[0, 1]) with np.errstate(divide="ignore", invalid="ignore"): dominance_scores = np.nansum(total_disent / total_count, axis=1) y_pos = np.arange(len(feat_names)) ax_dom.barh(y_pos, dominance_scores, color="green", alpha=0.7) ax_dom.set_yticks(y_pos) ax_dom.set_yticklabels(feat_names) ax_dom.set_xlabel("Dominance Score") ax_dom.set_title("Feature Dominance") ax_dom.grid(True, alpha=0.3) # Interaction counts ax_counts = fig.add_subplot(gs[1, :]) pair_counts = [] pair_labels = [] for i in range(len(feat_names)): for j in range(i + 1, len(feat_names)): if total_count[i, j] > 0: pair_counts.append(total_count[i, j]) pair_labels.append(f"{feat_names[i]}-{feat_names[j]}") x_pos = np.arange(len(pair_counts)) ax_counts.bar(x_pos, pair_counts, color="blue", alpha=0.7) ax_counts.set_xticks(x_pos) ax_counts.set_xticklabels(pair_labels, rotation=45, ha="right") ax_counts.set_ylabel("Number of neurons") ax_counts.set_title("Pairwise interaction counts") ax_counts.grid(True, axis="y", alpha=0.3) # Main title if n_exp > 1: title = f"{title_prefix}Disentanglement Analysis ({n_exp} experiments)" else: title = f"{title_prefix}Disentanglement Analysis" fig.suptitle(title, fontsize=16, y=0.98) plt.tight_layout() return fig
[docs] def plot_selectivity_heatmap( exp, significant_neurons, metric="mi", cmap="viridis", use_log_scale=False, vmin=None, vmax=None, figsize=(10, 8), significance_threshold=None, ax=None, ): """Create a heatmap showing metric values for selective neuron-feature pairs. Parameters ---------- exp : Experiment The experiment object containing all data and results significant_neurons : dict Dictionary mapping neuron IDs to lists of significant features metric : str, optional Which metric to display ('mi' for mutual information, 'corr' for correlation) Default: 'mi' cmap : str, optional Colormap to use. Default: 'viridis' use_log_scale : bool, optional Whether to use log scale for metric values. Default: False vmin : float, optional Minimum value for colormap. If None, auto-determined from data vmax : float, optional Maximum value for colormap. If None, auto-determined from data figsize : tuple, optional Figure size (ignored if ax provided). Default: (10, 8) significance_threshold : float, optional If provided, only show pairs with p-value below this threshold ax : matplotlib.axes.Axes, optional Axes to plot on. If None, creates new figure. Returns ------- fig : matplotlib.figure.Figure Figure containing the heatmap ax : matplotlib.axes.Axes Axes containing the heatmap stats : dict Dictionary containing statistics about the data: - n_selective: number of selective neurons - n_pairs: total number of significant pairs - selectivity_rate: percentage of selective neurons - metric_values: list of all non-zero metric values - sparsity: percentage of zero entries in the matrix Raises ------ AttributeError If experiment missing required attributes (dynamic_features, n_cells, get_neuron_feature_pair_stats) KeyError If neuron or feature not found in experiment data Notes ----- - Only processes string-type features (tuple features are ignored) - Always uses mode='calcium' when retrieving stats - Calls plt.tight_layout() which affects figure state Examples -------- >>> import matplotlib.pyplot as plt >>> from driada.experiment import load_demo_experiment >>> >>> # Load demo experiment and create synthetic selectivity data >>> exp = load_demo_experiment(verbose=False) >>> >>> # Create synthetic significant_neurons dict >>> # In real usage, this comes from INTENSE analysis >>> significant_neurons = { ... 5: ['speed', 'x_pos'], # Neuron 5 selective for speed and x_pos ... 10: ['speed'], # Neuron 10 selective for speed only ... 15: ['x_pos', 'y_pos'], # Neuron 15 selective for spatial features ... 20: ['speed', 'y_pos'], ... 25: ['x_pos'] ... } >>> >>> # Initialize stats_tables if not present >>> if not hasattr(exp, 'stats_tables'): ... exp.stats_tables = {} >>> if 'calcium' not in exp.stats_tables: ... exp.stats_tables['calcium'] = {} >>> >>> # Add minimal stats to experiment for the example >>> # Using features that exist in demo data >>> # Each stat entry needs data_hash and other required fields >>> import numpy as np >>> hash_val = 'demo_hash' >>> exp.stats_tables['calcium']['speed'] = { ... 5: {'me': 0.3, 'pval': 0.001, 'rval': 0.5, 'data_hash': hash_val, ... 'opt_delay': 0, 'pre_pval': 0.1, 'pre_rval': 0.3, ... 'rel_me_beh': 0.2, 'rel_me_ca': 0.15}, ... 10: {'me': 0.4, 'pval': 0.0001, 'rval': 0.6, 'data_hash': hash_val, ... 'opt_delay': 0, 'pre_pval': 0.05, 'pre_rval': 0.4, ... 'rel_me_beh': 0.3, 'rel_me_ca': 0.2}, ... 20: {'me': 0.25, 'pval': 0.005, 'rval': 0.4, 'data_hash': hash_val, ... 'opt_delay': 0, 'pre_pval': 0.2, 'pre_rval': 0.25, ... 'rel_me_beh': 0.15, 'rel_me_ca': 0.1} ... } >>> exp.stats_tables['calcium']['x_pos'] = { ... 5: {'me': 0.35, 'pval': 0.002, 'rval': 0.55, 'data_hash': hash_val, ... 'opt_delay': 0, 'pre_pval': 0.15, 'pre_rval': 0.35, ... 'rel_me_beh': 0.25, 'rel_me_ca': 0.2}, ... 15: {'me': 0.45, 'pval': 0.0001, 'rval': 0.7, 'data_hash': hash_val, ... 'opt_delay': 0, 'pre_pval': 0.08, 'pre_rval': 0.5, ... 'rel_me_beh': 0.35, 'rel_me_ca': 0.3}, ... 25: {'me': 0.3, 'pval': 0.003, 'rval': 0.5, 'data_hash': hash_val, ... 'opt_delay': 0, 'pre_pval': 0.18, 'pre_rval': 0.3, ... 'rel_me_beh': 0.2, 'rel_me_ca': 0.15} ... } >>> exp.stats_tables['calcium']['y_pos'] = { ... 15: {'me': 0.2, 'pval': 0.01, 'rval': 0.3, 'data_hash': hash_val, ... 'opt_delay': 0, 'pre_pval': 0.25, 'pre_rval': 0.15, ... 'rel_me_beh': 0.1, 'rel_me_ca': 0.08}, ... 20: {'me': 0.15, 'pval': 0.02, 'rval': 0.25, 'data_hash': hash_val, ... 'opt_delay': 0, 'pre_pval': 0.3, 'pre_rval': 0.12, ... 'rel_me_beh': 0.08, 'rel_me_ca': 0.05} ... } >>> >>> # Basic selectivity heatmap >>> fig, ax, stats = plot_selectivity_heatmap(exp, significant_neurons) >>> plt.close(fig) # Suppress display >>> >>> # With log scale and p-value filtering >>> fig, ax, stats = plot_selectivity_heatmap( ... exp, significant_neurons, ... use_log_scale=True, ... significance_threshold=0.01 ... ) >>> plt.close(fig) >>> >>> # Custom visualization >>> fig, ax, stats = plot_selectivity_heatmap( ... exp, significant_neurons, ... cmap='hot', vmin=0, vmax=0.5, ... figsize=(12, 10) ... ) >>> plt.close(fig) """ # Get all features and create ordered lists all_features = sorted([f for f in exp.dynamic_features.keys() if isinstance(f, str)]) all_neurons = list(range(exp.n_cells)) # Create matrix with metric values (0 for non-selective pairs) selectivity_matrix = np.zeros((len(all_neurons), len(all_features))) # Collect all metric values for statistics all_metric_values = [] for neuron_idx, cell_id in enumerate(all_neurons): for feat_idx, feat_name in enumerate(all_features): # Check if this neuron-feature pair is significant if cell_id in significant_neurons and feat_name in significant_neurons[cell_id]: # Get the statistics for this pair try: pair_stats = exp.get_neuron_feature_pair_stats( cell_id, feat_name, mode="calcium" ) # Skip if stats not available if pair_stats is None: continue # Check significance threshold if provided if significance_threshold is not None: pval = pair_stats.get("pval", None) # Skip if pval is None (failed stage 1) or above threshold if pval is None or pval > significance_threshold: continue # Get the metric value - 'me' contains the metric value for whichever metric was used value = pair_stats.get("me", 0) selectivity_matrix[neuron_idx, feat_idx] = value all_metric_values.append(value) except (KeyError, AttributeError): # Skip if stats not available for this pair continue # Apply log scale if requested if use_log_scale and len(all_metric_values) > 0: # Add small epsilon to avoid log(0) epsilon = 1e-10 selectivity_matrix = np.log10(selectivity_matrix + epsilon) # Set zeros back to a special value for visualization selectivity_matrix[selectivity_matrix < np.log10(epsilon * 2)] = np.nan # Create figure if needed if ax is None: fig, ax = plt.subplots(figsize=figsize) else: fig = ax.figure # Determine color limits if len(all_metric_values) > 0: if vmin is None: vmin = 0 if not use_log_scale else np.log10(min(all_metric_values)) if vmax is None: vmax = max(all_metric_values) if not use_log_scale else np.log10(max(all_metric_values)) else: vmin = 0 vmax = 1 # Create masked array to handle NaN values properly masked_matrix = np.ma.masked_invalid(selectivity_matrix) # Plot heatmap im = ax.imshow( masked_matrix, cmap=cmap, aspect="auto", interpolation="nearest", vmin=vmin, vmax=vmax, ) # Set ticks and labels ax.set_xticks(range(len(all_features))) ax.set_xticklabels(all_features, rotation=45, ha="right") ax.set_yticks( range(0, len(all_neurons), max(1, len(all_neurons) // 20)) ) # Show ~20 neuron labels ax.set_yticklabels(range(0, len(all_neurons), max(1, len(all_neurons) // 20))) # Labels and title ax.set_xlabel("Features", fontsize=12) ax.set_ylabel("Neurons", fontsize=12) metric_name = "Mutual Information" if metric == "mi" else "Correlation" scale_text = " (log₁₀)" if use_log_scale else "" ax.set_title( f"Neuronal Selectivity: {metric_name}{scale_text}", fontsize=14, fontweight="bold", ) # Add colorbar with appropriate label cbar = plt.colorbar(im, ax=ax) cbar.set_label(f"{metric_name}{scale_text}", rotation=270, labelpad=20) # Calculate statistics n_selective = len(significant_neurons) n_pairs = sum(len(features) for features in significant_neurons.values()) selectivity_rate = (n_selective / exp.n_cells) * 100 sparsity = (1 - n_pairs / (len(all_neurons) * len(all_features))) * 100 # Add summary text summary_lines = [ f"Selective neurons: {n_selective}/{exp.n_cells} ({selectivity_rate:.1f}%)", f"Total selective pairs: {n_pairs}", ] if len(all_metric_values) > 0: summary_lines.extend( [ f"{metric.upper()} range: [{min(all_metric_values):.3f}, {max(all_metric_values):.3f}]", f"Mean {metric.upper()}: {np.mean(all_metric_values):.3f}", ] ) summary_text = "\n".join(summary_lines) # Position text in the lower right corner to avoid colorbar overlap fig.text( 0.98, 0.02, summary_text, transform=fig.transFigure, fontsize=10, verticalalignment="bottom", horizontalalignment="right", bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5), ) # Add grid for better readability ax.set_xticks(np.arange(len(all_features)) - 0.5, minor=True) ax.set_yticks(np.arange(len(all_neurons)) - 0.5, minor=True) ax.grid(which="minor", color="gray", linestyle="-", linewidth=0.5, alpha=0.3) plt.tight_layout() # Return statistics stats = { "n_selective": n_selective, "n_pairs": n_pairs, "selectivity_rate": selectivity_rate, "metric_values": all_metric_values, "sparsity": sparsity, } return fig, ax, stats