Source code for driada.experiment.exp_base

import numpy as np
import warnings
import tqdm
import pickle
import logging
from typing import Optional, Union, List

from ..information.info_base import TimeSeries
from ..information.info_base import MultiTimeSeries
from .neuron import (
    DEFAULT_MIN_BEHAVIOUR_TIME,
    DEFAULT_T_OFF,
    DEFAULT_FPS,
    Neuron,
)
from .event_detection import CA_SHIFT_N_TOFF
from .event_detection import MIN_FEAT_SHIFT_SEC
from ..utils.data import get_hash, populate_nested_dict
from ..information.info_base import get_1d_mi
from ..intense.intense_base import get_multicomp_correction_thr

# Minimum neurons for parallelization to be beneficial
MIN_NEURONS_FOR_PARALLEL = 50

STATS_VARS = [
    "data_hash",
    "opt_delay",
    "pre_pval",
    "pre_rval",
    "pval",
    "rval",
    "me",
    "rel_me_beh",
    "rel_me_ca",
]
SIGNIFICANCE_VARS = [
    "stage1",
    "shuffles1",
    "stage2",
    "shuffles2",
    "final_p_thr",
    "multicomp_corr",
    "pairwise_pval_thr",
]
DEFAULT_STATS = dict(zip(STATS_VARS, [None for _ in STATS_VARS]))
DEFAULT_SIGNIFICANCE = dict(zip(SIGNIFICANCE_VARS, [None for _ in SIGNIFICANCE_VARS]))


def check_dynamic_features(dynamic_features):
    """Validate that all dynamic features have the same length.

    Parameters
    ----------
    dynamic_features : dict
        Dictionary mapping feature names (str) to feature data. Supported data types:
        - TimeSeries: Length determined by len(data) attribute
        - MultiTimeSeries: Length determined by n_points attribute
        - numpy.ndarray: Length is last dimension (shape[-1]). Must be at least 1D.
        Empty dict is allowed and will return without error.

    Returns
    -------
    None
        Function returns None. Validation is done via exceptions.

    Raises
    ------
    ValueError
        If features have different lengths. Error message includes detailed
        listing of each feature name and its length in timepoints.
    TypeError
        If a feature has an unsupported type (not TimeSeries, MultiTimeSeries,
        or numpy array), or if a numpy array is 0-dimensional (scalar).
        Error message includes the problematic feature name and its type.

    Notes
    -----
    For numpy arrays, the last dimension is always interpreted as the time
    dimension, consistent with the shape convention (n_features, n_timepoints)."""
    if not dynamic_features:
        return  # Handle empty features gracefully

    dfeat_lengths = {}
    for feat_id, current_ts in dynamic_features.items():
        # Only accept specific types - reject everything else
        if isinstance(current_ts, TimeSeries):
            len_ts = len(current_ts.data)
        elif isinstance(current_ts, MultiTimeSeries):
            # MultiTimeSeries inherits from MVData which has n_points attribute
            len_ts = current_ts.n_points
        elif isinstance(current_ts, np.ndarray):
            # Handle raw numpy arrays - last dimension is time
            if current_ts.ndim == 0:
                raise TypeError(
                    f"Feature '{feat_id}' is a scalar numpy array. "
                    f"Expected TimeSeries, MultiTimeSeries, or numpy array with at least 1 dimension."
                )
            len_ts = current_ts.shape[-1]
        else:
            # Reject all other types
            raise TypeError(
                f"Feature '{feat_id}' has unsupported type: {type(current_ts).__name__}. "
                f"Expected TimeSeries, MultiTimeSeries, or numpy array only."
            )

        dfeat_lengths[feat_id] = len_ts

    # Check all features have same length
    unique_lengths = set(dfeat_lengths.values())
    if len(unique_lengths) != 1:
        # Create informative error message
        length_info = [f"  {feat}: {length} timepoints" for feat, length in dfeat_lengths.items()]
        raise ValueError("Dynamic features have different lengths:\n" + "\n".join(length_info))


[docs] class Experiment: """Base class for calcium imaging and spike train experiments. This class provides a unified interface for analyzing neural activity data (calcium imaging or spike trains) in relation to various experimental features (behavioral variables, stimuli, etc.). It handles data organization, feature extraction, mutual information analysis, and statistical significance testing. Parameters ---------- signature : str Unique identifier for the experiment. calcium : numpy.ndarray Calcium imaging data of shape (n_neurons, n_timepoints). Required parameter. spikes : numpy.ndarray or None Spike train data. Can be provided directly or reconstructed from calcium. exp_identificators : dict Experiment metadata and identifiers. Each key-value pair becomes an attribute of the object (e.g., exp_identificators={'mouse_id': 'M1'} creates self.mouse_id = 'M1'). static_features : dict Time-invariant features (e.g., cell types, anatomical properties). Should include 'fps', 't_rise_sec', 't_off_sec' or defaults will be used. The dict is stored as self.static_features, and each key also becomes an individual attribute (e.g., static_features={'fps': 20} creates both self.static_features['fps'] and self.fps = 20). dynamic_features : dict Time-varying features (e.g., behavior, stimuli). Keys are feature names, values are TimeSeries, MultiTimeSeries, or numpy arrays. All features must have the same length (number of timepoints). **kwargs : dict Additional parameters including: - optimize_kinetics : bool or str, optimize kinetics per neuron (default: False). If True, uses 'lbfgs' method. Can specify 'lbfgs' or 'grid' explicitly. - reconstruct_spikes : str or bool, spike reconstruction method (default: 'wavelet') - bad_frames_mask : array-like, boolean mask where True indicates bad frames - spike_kwargs : dict, parameters for spike reconstruction - verbose : bool, print progress messages (default: True) Attributes ---------- signature : str Experiment identifier. neurons : list List of Neuron objects, indexed by cell ID (0-based). Access with self.neurons[cell_id]. n_cells : int Number of neurons in the experiment. n_frames : int Number of time points (frames) in the experiment. calcium : MultiTimeSeries Calcium imaging data as MultiTimeSeries object. spikes : MultiTimeSeries Spike train data as MultiTimeSeries object. static_features : dict Time-invariant experimental features as originally provided. dynamic_features : dict Time-varying experimental features as TimeSeries/MultiTimeSeries objects. stats_tables : dict Nested dict storing mutual information statistics. Structure: stats_tables[mode][feat_id][cell_id] = stats_dict. significance_tables : dict Nested dict storing statistical significance data. Structure: significance_tables[mode][feat_id][cell_id] = sig_dict. embeddings : dict Stored dimensionality reduction results by data type and method. Structure: embeddings[data_type][method_name] = embedding_data. verbose : bool Whether to print progress messages. spike_reconstruction_method : str or None Method used for spike reconstruction if applicable. filtered_flag : bool Whether bad frames were filtered out. selectivity_tables_initialized : bool Whether selectivity tables have been initialized. exp_identificators : dict Original experiment identifiers dictionary. _data_hashes : dict Private attribute storing hash representations for caching. _rdm_cache : dict Private attribute caching representational dissimilarity matrices. Methods ------- check_ds(ds) Validate downsampling rate for behavioral analysis. get_neuron_feature_pair_stats(cell_id, feat_id, mode='calcium') Get selectivity statistics for a neuron-feature pair. get_neuron_feature_pair_significance(cell_id, feat_id, mode='calcium') Get statistical significance data for a neuron-feature pair. update_neuron_feature_pair_stats(stats, cell_id, feat_id, mode='calcium', ...) Update statistics for a neuron-feature pair. update_neuron_feature_pair_significance(sig, cell_id, feat_id, mode='calcium') Update significance data for a neuron-feature pair. get_multicell_shuffled_calcium(cbunch=None, method='roll_based', **kwargs) Get shuffled calcium data for specified neurons. get_multicell_shuffled_spikes(cbunch=None, method='isi_based', **kwargs) Get shuffled spike data for specified neurons. get_stats_slice(cell_ids, feat_ids, mode='calcium', vars=None) Extract statistics for multiple neuron-feature pairs. get_significance_slice(cell_ids, feat_ids, mode='calcium', vars=None) Extract significance data for multiple neuron-feature pairs. get_feature_entropy(feat_id, ds=1) Calculate Shannon entropy of a feature. get_significant_neurons(min_nspec=1, cbunch=None, fbunch=None, mode='calcium', ...) Find neurons with significant selectivity to one or more features. store_embedding(embedding, method_name, data_type='calcium', metadata=None) Store dimensionality reduction results. get_embedding(method_name, data_type='calcium') Retrieve stored dimensionality reduction embedding. compute_rdm(items, activity_type='calcium', metric='correlation', **kwargs) Compute representational dissimilarity matrix. clear_rdm_cache() Clear the RDM computation cache. Notes ----- The class supports both calcium imaging and spike train analysis through the 'mode' parameter in various methods. Results are cached using hash-based lookups to avoid redundant computations. Statistical significance is determined using the INTENSE algorithm with two-stage hypothesis testing. Spike reconstruction is performed automatically if spikes are not provided and reconstruct_spikes is not False or None. The 'wavelet' method is recommended for calcium imaging data. Set reconstruct_spikes to False or None to disable spike reconstruction entirely. Individual static and dynamic features can be accessed as attributes. For example, if 'position' is a dynamic feature, access it via self.position. Protected attribute names will have an underscore prefix if conflicts occur.""" _PROTECTED_ATTRS = frozenset({ "neurons", "calcium", "spikes", "n_cells", "n_frames", "stats_tables", "significance_tables", "embeddings", "_rdm_cache", "_data_hashes", "signature", "exp_identificators", "static_features", "dynamic_features", })
[docs] def __init__( self, signature, calcium, spikes, exp_identificators, static_features, dynamic_features, ground_truth=None, **kwargs, ): """Initialize experiment with neural data and behavioral features. Creates an experiment object that integrates calcium imaging data, spike trains, and behavioral features for neural population analysis. Handles data validation, spike reconstruction, and sets up internal data structures for statistical analysis. Parameters ---------- signature : str Unique identifier for the experiment (e.g., 'mouse123_session1'). calcium : array-like Calcium imaging data with shape (n_cells, n_frames). Required. Each row is a neuron's calcium trace over time. spikes : array-like or None Spike train data with same shape as calcium. If None and reconstruct_spikes is specified in kwargs, spikes will be reconstructed from calcium data. exp_identificators : dict or None Metadata about the experiment (e.g., subject_id, session_date). Keys become attributes of the experiment object. static_features : dict or None Time-invariant features. Expected keys include: - 'fps': sampling rate (frames per second) - 't_rise_sec': calcium rise time in seconds - 't_off_sec': calcium decay time in seconds - Other experiment-wide parameters dynamic_features : dict or None Time-varying behavioral features (e.g., position, speed). Values should be array-like with time dimension matching n_frames. Keys become accessible via self.dynamic_features. ground_truth : dict or None, optional Ground truth for synthetic experiments. None for real data. For synthetic data, typically contains: - "expected_pairs": list of (neuron_idx, feature_name) tuples - "tuning_parameters": dict of per-neuron tuning params - "neuron_types": dict mapping neuron_idx to group name - "population_config": original population configuration **kwargs Additional parameters: - optimize_kinetics (bool or str): Optimize kinetics per neuron. Default False. If True, uses 'lbfgs' method. Can specify 'lbfgs' or 'grid' explicitly. - reconstruct_spikes (str, False, or None): Method for spike reconstruction. Options: 'wavelet' (default), 'threshold', False, or None. If False or None, spike reconstruction is disabled. Only used if spikes is None. - bad_frames_mask (array-like): Boolean mask of frames to exclude. - spike_kwargs (dict): Parameters for spike reconstruction method. - verbose (bool): Print progress messages. Default True. - create_circular_2d (bool): If True, automatically create `_2d` versions of circular features (detected via type_info.is_circular) as (cos, sin) MultiTimeSeries. Original features are preserved. E.g., 'headdirection' -> also creates 'headdirection_2d'. Default True. Raises ------ ValueError If calcium is None, if data shapes are inconsistent, if feature names conflict with protected attributes, or if data appears transposed (n_cells > n_frames). TypeError If dynamic features have incompatible types. Warnings -------- UserWarning If both spikes and reconstruct_spikes are provided (spikes will be overwritten), or if static feature names conflict with existing attributes (will be prefixed with underscore). Notes ----- Protected attribute names that cannot be used as feature names: 'spikes', 'calcium', 'neurons', 'n_cells', 'n_frames', 'static_features', 'dynamic_features', 'downsampling', 'significance_tables', 'stats_tables', '_data_hashes', 'embeddings', '_rdm_cache', 'intense_results' The initialization process: 1. Validates and stores basic data (calcium required) 2. Handles spike data or reconstruction 3. Creates Neuron objects for each cell 4. Processes static and dynamic features 5. Builds internal data structures for caching 6. Validates data consistency Examples -------- >>> import numpy as np >>> from driada.information.info_base import TimeSeries >>> >>> # Basic initialization with calcium data only >>> calcium_data = np.random.randn(10, 1000) # 10 neurons, 1000 timepoints >>> exp = Experiment('exp001', calcium_data, None, {}, {}, {}, ... reconstruct_spikes=None, verbose=False) >>> # With spikes and behavioral features >>> spike_data = np.random.poisson(0.05, (10, 1000)) # spike trains >>> speed_trace = TimeSeries(np.random.rand(1000)) # behavioral data >>> exp = Experiment( ... 'exp002', ... calcium_data, ... spike_data, ... {}, ... {'fps': 30.0}, ... {'speed': speed_trace}, ... verbose=False ... ) >>> # Without spike reconstruction (faster for doctests) >>> exp = Experiment( ... 'exp003', ... calcium_data, ... None, ... {}, ... {'fps': 30.0}, ... {}, ... reconstruct_spikes=None, ... verbose=False ... ) """ optimize_kinetics = kwargs.get("optimize_kinetics", False) reconstruct_spikes = kwargs.get("reconstruct_spikes", "wavelet") bad_frames_mask = kwargs.get("bad_frames_mask", None) spike_kwargs = kwargs.get("spike_kwargs", None) self.verbose = kwargs.get("verbose", True) create_circular_2d = kwargs.get("create_circular_2d", True) # Parallelization settings self._n_jobs = kwargs.get("n_jobs", -1) self._enable_parallelization = kwargs.get("enable_parallelization", False) # Extract new data fields asp = kwargs.get("asp", None) reconstructions = kwargs.get("reconstructions", None) metadata = kwargs.get("metadata", None) # Extract per-neuron metrics from metadata neuron_metrics_df = None if metadata is not None and 'metrics_df' in metadata: metrics_df = metadata['metrics_df'] # Handle edge case: metrics_df might have more entries than neurons # (e.g., BOWL: 495 metrics vs 370 neurons due to preprocessing filtering) # For now, we only support direct indexing (metrics_df length == n_neurons) n_neurons = calcium.shape[0] first_key = list(metrics_df.keys())[0] metrics_len = len(metrics_df[first_key]) if metrics_len == n_neurons: neuron_metrics_df = metrics_df elif self.verbose: warnings.warn( f"metrics_df length ({metrics_len}) doesn't match neuron count ({n_neurons}). " f"Per-neuron metrics will not be loaded. This can happen when the NPZ contains " f"all preprocessing components but only a subset was kept." ) check_dynamic_features(dynamic_features) self.exp_identificators = exp_identificators self.signature = signature self.ground_truth = ground_truth # None for real data, dict for synthetic for idx in exp_identificators: setattr(self, idx, exp_identificators[idx]) if calcium is None: raise ValueError( "Calcium data is required. Please provide a numpy array with shape (n_neurons, n_timepoints)." ) # Handle spike reconstruction based on reconstruct_spikes parameter if reconstruct_spikes is None or reconstruct_spikes is False: # No spike reconstruction requested if spikes is None and self.verbose: warnings.warn( "No spike data provided, spikes reconstruction from Ca2+ data disabled" ) else: # Spike reconstruction requested if spikes is not None: warnings.warn( f"Spike data will be overridden by reconstructed spikes from Ca2+ data with method={reconstruct_spikes}" ) # Store the reconstruction method for potential future use self.spike_reconstruction_method = reconstruct_spikes # Reconstruct spikes spikes = self._reconstruct_spikes( calcium, reconstruct_spikes, static_features.get("fps"), spike_kwargs ) self.filtered_flag = False # Only call _trim_data if there are actual bad frames to filter if bad_frames_mask is not None and np.any(bad_frames_mask): calcium, spikes, dynamic_features = self._trim_data( calcium, spikes, dynamic_features, bad_frames_mask ) else: for feat_id in dynamic_features.copy(): feat_data = dynamic_features[feat_id] # Skip if already a TimeSeries or MultiTimeSeries if not isinstance(feat_data, (TimeSeries, MultiTimeSeries)): # Convert numpy arrays based on dimensionality if isinstance(feat_data, np.ndarray): if feat_data.ndim == 1: # 1D array -> TimeSeries dynamic_features[feat_id] = TimeSeries(feat_data, name=feat_id) elif feat_data.ndim == 2: # 2D array -> MultiTimeSeries (each row is a component) ts_list = [ TimeSeries(feat_data[i, :], discrete=False, name=f"{feat_id}_{i}") for i in range(feat_data.shape[0]) ] dynamic_features[feat_id] = MultiTimeSeries(ts_list, name=feat_id) else: raise ValueError( f"Feature {feat_id} has unsupported dimensionality: {feat_data.ndim}D" ) else: # Assume it's 1D data if not numpy array dynamic_features[feat_id] = TimeSeries(feat_data, name=feat_id) self.n_cells = calcium.shape[0] self.n_frames = calcium.shape[1] self.neurons = [] if self.verbose: print("Building neurons...") # Extract static feature parameters t_rise = static_features.get("t_rise_sec") t_off = static_features.get("t_off_sec") fps = static_features.get("fps") for i in tqdm.tqdm( np.arange(self.n_cells), position=0, leave=True, disable=not self.verbose ): # Build neuron kwargs for new optional data neuron_kwargs = { 'default_t_rise': t_rise, 'default_t_off': t_off, 'fps': fps, 'optimize_kinetics': optimize_kinetics, } if asp is not None: neuron_kwargs['asp'] = asp[i, :] if reconstructions is not None: neuron_kwargs['reconstructed'] = reconstructions[i, :] if neuron_metrics_df is not None: neuron_kwargs['metrics'] = {k: v[i] for k, v in neuron_metrics_df.items()} cell = Neuron( str(i), calcium[i, :], spikes[i, :] if spikes is not None else None, **neuron_kwargs ) self.neurons.append(cell) # Now create MultiTimeSeries from neurons to preserve their shuffle masks calcium_ts_list = [neuron.ca for neuron in self.neurons] spikes_ts_list = [ ( neuron.sp if neuron.sp is not None else TimeSeries(np.zeros(self.n_frames), discrete=True, name=f"neuron_{i}_sp_zero") ) for i, neuron in enumerate(self.neurons) ] # Create MultiTimeSeries from the TimeSeries objects in neurons # This preserves the individual shuffle masks created by each Neuron self.calcium = MultiTimeSeries(calcium_ts_list, name="calcium") # Allow zero columns for spikes since many neurons might not spike self.spikes = MultiTimeSeries(spikes_ts_list, allow_zero_columns=True, name="spikes") self.dynamic_features = dynamic_features # Create _2d versions of circular features for MI computation if create_circular_2d: self._create_circular_2d_features(verbose=self.verbose) # Set shuffle masks on dynamic features to exclude near-zero shifts min_feat_shift = int(MIN_FEAT_SHIFT_SEC * fps) for feat_ts in self.dynamic_features.values(): if isinstance(feat_ts, (TimeSeries, MultiTimeSeries)): feat_ts.shuffle_mask[:min_feat_shift] = False feat_ts.shuffle_mask[-min_feat_shift:] = False # Store experiment-level metadata self.metadata = metadata # Check for protected attributes in dynamic features conflicting_features = [ feat_id for feat_id in dynamic_features if isinstance(feat_id, str) and feat_id in self._PROTECTED_ATTRS ] if conflicting_features: raise ValueError( f"Dynamic feature names conflict with protected attributes: {conflicting_features}. " f"Protected attributes are: {sorted(self._PROTECTED_ATTRS)}" ) # Store static_features as an attribute for consistency self.static_features = static_features # Set dynamic features as attributes for feat_id in dynamic_features: if isinstance(feat_id, str): if hasattr(self, feat_id): warnings.warn(f"Feature name '{feat_id}' overwrites existing attribute.") setattr(self, feat_id, dynamic_features[feat_id]) # Skip tuples (multifeatures) as they can't be attribute names # Also set static features as individual attributes for backward compatibility for sfeat_name in static_features: if sfeat_name in self._PROTECTED_ATTRS: warnings.warn( f"Static feature name '{sfeat_name}' conflicts with protected attribute. " f"Access via attribute name with underscore: _{sfeat_name}" ) setattr(self, f"_{sfeat_name}", static_features[sfeat_name]) else: setattr(self, sfeat_name, static_features[sfeat_name]) # for selectivity data from INTENSE self.stats_tables = {} self.significance_tables = {} self.selectivity_tables_initialized = False # for dimensionality reduction embeddings self.embeddings = {"calcium": {}, "spikes": {}} # Cache for RDM computations self._rdm_cache = {} if self.verbose: print("Building data hashes...") self._build_data_hashes(mode="calcium") # Only build spike hashes if we have actual spike data (not just zeros) # Check if any neuron has non-None spikes has_spikes = any(neuron.sp is not None for neuron in self.neurons) if has_spikes: self._build_data_hashes(mode="spikes") if self.verbose: print("Final checkpoint...") self._checkpoint() # self._load_precomputed_data(**kwargs) if self.verbose: print( f'Experiment "{self.signature}" constructed successfully with {self.n_cells} neurons and {len(self.dynamic_features)} features' )
[docs] def check_ds(self, ds): """Check if downsampling rate is appropriate for behavior analysis. Validates that the downsampling rate won't cause time gaps larger than the minimum behavior time interval (0.25 seconds), which could lead to missed behavioral events. Parameters ---------- ds : int Downsampling factor. The data will be sampled every ds frames. Must be a positive integer (>= 1). Returns ------- None Raises ------ ValueError If fps (frames per second) is not set for this experiment, or if ds is less than 1. Error messages include relevant context. TypeError If ds is not an integer. Warnings -------- UserWarning Issued if the time gap created by downsampling exceeds DEFAULT_MIN_BEHAVIOUR_TIME (0.25 seconds). The warning includes the current threshold, downsampling factor, and resulting time gap. Notes ----- The time gap is calculated as (1/fps) * ds seconds. For example, with fps=20 and ds=10, the time gap would be 0.5 seconds, which exceeds the 0.25 second threshold and triggers a warning.""" # Validate ds parameter if not isinstance(ds, (int, np.integer)): raise TypeError(f"Downsampling factor ds must be an integer, got {type(ds).__name__}") if ds < 1: raise ValueError(f"Downsampling factor ds must be >= 1, got {ds}") if not hasattr(self, "fps"): raise ValueError(f"fps not set for {self.signature}") time_step = 1.0 / self.fps if time_step * ds > DEFAULT_MIN_BEHAVIOUR_TIME: warnings.warn( f"Downsampling constant is too high: some behaviour acts may be skipped. " f"Current minimal behaviour time interval is set to {DEFAULT_MIN_BEHAVIOUR_TIME} sec, " f"downsampling {ds} will create time gaps of {time_step*ds:.3f} sec", UserWarning, )
def _set_selectivity_tables(self, mode, fbunch=None, cbunch=None): """Create or reset selectivity statistics tables for the specified mode. Creates nested dictionaries for storing mutual information statistics. Overwrites any existing tables for the given mode. Parameters ---------- mode : str Table identifier (typically 'calcium' or 'spikes'). Not validated. fbunch : None, str, or iterable of str, optional Feature(s) to include. If None, includes all dynamic features. cbunch : None, int, or iterable of int, optional Cell ID(s) to include. If None, includes all cells. Notes ----- Creates two nested dictionaries {feature: {cell: dict}}: - self.stats_tables[mode]: MI statistics (initialized from DEFAULT_STATS) - self.significance_tables[mode]: Significance data (from DEFAULT_SIGNIFICANCE) Sets self.selectivity_tables_initialized to True. Warning: Overwrites existing tables without preserving data.""" # neuron-feature pair statistics stats_table = self._populate_cell_feat_dict(DEFAULT_STATS, fbunch=fbunch, cbunch=cbunch) # neuron-feature pair significance-related data significance_table = self._populate_cell_feat_dict( DEFAULT_SIGNIFICANCE, fbunch=fbunch, cbunch=cbunch ) self.stats_tables[mode] = stats_table self.significance_tables[mode] = significance_table self.selectivity_tables_initialized = True def _build_pair_hash(self, cell_id, feat_id, mode="calcium"): """Build a unique hash representation of neuron-feature pair data. Creates a hash tuple that uniquely identifies the combination of neural activity data and feature data for caching computations. Parameters ---------- cell_id : int Neuron index. Must exist in self.neurons list. feat_id : str or iterable of str Feature name(s). Single-element iterables are converted to strings. All features must exist in self.dynamic_features. mode : {'calcium', 'spikes'}, optional Type of neural activity data. Default is 'calcium'. Returns ------- tuple Hash tuple containing: - Single feature: (activity_hash, feature_hash) - Multiple features: (activity_hash, feature1_hash, feature2_hash, ...) - Empty iterable: (activity_hash,) Raises ------ ValueError If mode is not 'calcium' or 'spikes'. KeyError If cell_id not in self.neurons or feat_id not in self.dynamic_features. AttributeError If neuron lacks .ca/.sp attributes or feature lacks .data attribute. Notes ----- Uses SHA256 hashing on raw numpy arrays. Multiple features are sorted alphabetically before hashing to ensure order-independent results.""" if mode == "calcium": act = self.neurons[cell_id].ca.data elif mode == "spikes": act = self.neurons[cell_id].sp.data else: raise ValueError('"mode" can be either "calcium" or "spikes"') act_hash = get_hash(act) if (not isinstance(feat_id, str)) and len(feat_id) == 1: feat_id = feat_id[0] if isinstance(feat_id, str): dyn_data = self.dynamic_features[feat_id].data dyn_data_hash = get_hash(dyn_data) pair_hash = (act_hash, dyn_data_hash) else: ordered_fnames = tuple(sorted(list(feat_id))) list_of_hashes = [act_hash] for fname in ordered_fnames: dyn_data = self.dynamic_features[fname].data dyn_data_hash = get_hash(dyn_data) list_of_hashes.append(dyn_data_hash) pair_hash = tuple(list_of_hashes) return pair_hash def _build_data_hashes(self, mode="calcium"): """Build hash representations for all neuron-feature pairs. Pre-computes SHA256 hashes for all combinations of neurons and features to enable efficient caching of mutual information calculations. Parameters ---------- mode : {'calcium', 'spikes'}, optional Type of neural activity data. Default is 'calcium'. Not validated. Notes ----- Creates nested dictionary structure: self._data_hashes[mode][feat_id][cell_id] = hash_tuple Where hash_tuple is from _build_pair_hash(cell_id, feat_id, mode). Warning: Calling this method multiple times for the same mode will completely recreate all hashes, overwriting existing data.""" # Create default hashes structure for this mode only if not hasattr(self, "_data_hashes"): self._data_hashes = {} # Create independent dictionary for each mode to avoid aliasing mode_hashes = { dfeat: dict(zip(range(self.n_cells), [None for _ in range(self.n_cells)])) for dfeat in self.dynamic_features.keys() } self._data_hashes[mode] = mode_hashes # Populate hashes for this mode for feat_id in self.dynamic_features: for cell_id in range(self.n_cells): self._data_hashes[mode][feat_id][cell_id] = self._build_pair_hash( cell_id, feat_id, mode=mode )
[docs] def add_feature(self, name, data, ts_type=None, **ts_kwargs): """Add a dynamic feature to the experiment after initialization. Mirrors the feature registration performed during ``__init__``: wraps raw arrays, validates length, applies shuffle masks, stores in ``dynamic_features`` dict, and sets a direct attribute. Parameters ---------- name : str Feature name. Must not conflict with protected attributes. data : array_like or TimeSeries or MultiTimeSeries Feature data. Length must match ``n_frames``. Raw arrays are wrapped in TimeSeries automatically. ts_type : str or None, optional Passed to ``TimeSeries(ts_type=...)`` when wrapping raw arrays. Common values: ``'discrete'``, ``'circular'``. Ignored if *data* is already a TimeSeries/MultiTimeSeries. **ts_kwargs Additional keyword arguments forwarded to the TimeSeries constructor (e.g. ``circular_period``). """ if not isinstance(name, str): raise TypeError(f"Feature name must be a string, got {type(name).__name__}") if name in self._PROTECTED_ATTRS: raise ValueError( f"Feature name '{name}' conflicts with a protected attribute. " f"Protected attributes are: {sorted(self._PROTECTED_ATTRS)}" ) # Wrap raw arrays if isinstance(data, (TimeSeries, MultiTimeSeries)): ts = data elif isinstance(data, np.ndarray): if data.ndim == 1: ts = TimeSeries(data, name=name, ts_type=ts_type, **ts_kwargs) elif data.ndim == 2: ts_list = [ TimeSeries(data[i, :], discrete=False, name=f"{name}_{i}") for i in range(data.shape[0]) ] ts = MultiTimeSeries(ts_list, name=name) else: raise ValueError( f"Feature '{name}' has unsupported dimensionality: {data.ndim}D" ) else: ts = TimeSeries(np.asarray(data), name=name, ts_type=ts_type, **ts_kwargs) # Validate length n = ts.data.shape[1] if isinstance(ts, MultiTimeSeries) else ts.data.shape[-1] if n != self.n_frames: raise ValueError( f"Feature '{name}' has {n} timepoints but experiment has {self.n_frames}" ) # Apply shuffle mask fps = getattr(self, "fps", None) if fps is not None: min_feat_shift = int(MIN_FEAT_SHIFT_SEC * fps) ts.shuffle_mask[:min_feat_shift] = False ts.shuffle_mask[-min_feat_shift:] = False self.dynamic_features[name] = ts setattr(self, name, ts) self._build_data_hashes(mode="calcium") warnings.warn( f"Feature '{name}' added after initialization. " "Cached INTENSE stats will not reflect this feature; re-run analysis if needed.", stacklevel=2, )
def _create_circular_2d_features(self, verbose=True): """Create _2d (cos, sin) versions of circular features during __init__. Called automatically during construction if create_circular_2d=True. This is safe because no caching has occurred yet. Parameters ---------- verbose : bool, optional Print information about created features. Default is True. Returns ------- list of tuple List of (original_name, name_2d, period) for each created feature. """ from ..information.circular_transform import circular_to_cos_sin circular_2d_created = [] for f, ts in list(self.dynamic_features.items()): # Skip if already MultiTimeSeries or already a _2d feature if isinstance(ts, MultiTimeSeries) or f.endswith("_2d"): continue # Check if continuous AND circular (discrete features should not be transformed) if (isinstance(ts, TimeSeries) and not ts.discrete and hasattr(ts, "type_info") and ts.type_info and ts.type_info.is_circular): # Create _2d version (cos, sin) MultiTimeSeries name_2d = f"{f}_2d" period = ts.type_info.circular_period transformed = circular_to_cos_sin(ts.data, period=period, name=name_2d) self.dynamic_features[name_2d] = transformed circular_2d_created.append((f, name_2d, period)) if verbose and circular_2d_created: print("Circular features -> _2d versions created:") for orig, name_2d, period in circular_2d_created: period_str = "2pi" if period and abs(period - 2 * np.pi) < 0.1 else f"{period}" print(f" '{orig}' (period={period_str}) -> '{name_2d}' [cos, sin]") return circular_2d_created
[docs] def get_circular_2d_feature(self, name): """Get the _2d (cos, sin) version of a circular feature. Parameters ---------- name : str Original circular feature name (without _2d suffix). Returns ------- MultiTimeSeries or None The _2d version, or None if not available. Examples -------- >>> # Assuming exp has circular feature 'headdirection' >>> mts = exp.get_circular_2d_feature('headdirection') # doctest: +SKIP >>> mts.n_dim # doctest: +SKIP 2 """ name_2d = f"{name}_2d" return self.dynamic_features.get(name_2d)
[docs] def has_circular_2d(self, name): """Check if a feature has a _2d (cos, sin) counterpart. Parameters ---------- name : str Feature name to check. Returns ------- bool True if the _2d version exists. """ return f"{name}_2d" in self.dynamic_features
[docs] def get_circular_features(self): """Get all circular features (excluding _2d versions). Returns ------- dict Dictionary mapping feature names to TimeSeries objects for all circular features. """ circular = {} for name, ts in self.dynamic_features.items(): if name.endswith("_2d"): continue if isinstance(ts, TimeSeries) and ts.type_info and ts.type_info.is_circular: circular[name] = ts return circular
def _trim_data(self, calcium, spikes, dynamic_features, bad_frames_mask, force_filter=False): """Filter out bad frames from all data arrays. Removes frames marked as bad in bad_frames_mask from calcium, spikes, and all dynamic features, maintaining temporal alignment. Parameters ---------- calcium : numpy.ndarray Calcium data, shape (n_neurons, n_frames). spikes : numpy.ndarray or None Spike data, same shape as calcium. Can be None. dynamic_features : dict Time-varying features as arrays, TimeSeries, or MultiTimeSeries. bad_frames_mask : array-like of bool Boolean mask where True indicates bad frames to remove. force_filter : bool, optional Force re-filtering even if already filtered. Default False. Returns ------- tuple (filtered_calcium, filtered_spikes, filtered_dynamic_features) Raises ------ AttributeError If data already filtered and force_filter is False. Side Effects ------------ Sets self.filtered_flag = True and self.bad_frames_mask. Notes ----- For multi-dimensional arrays, assumes time is the second dimension. For 1D arrays or unknown types, assumes time is the last dimension.""" if not force_filter and self.filtered_flag: raise AttributeError( 'Data is already filtered, if you want to force filtering it again, set "force_filter = True"' ) f_calcium = calcium[:, ~bad_frames_mask] if spikes is not None: f_spikes = spikes[:, ~bad_frames_mask] else: f_spikes = None f_dynamic_features = {} for feat_id in dynamic_features: current_ts = dynamic_features[feat_id] if isinstance(current_ts, TimeSeries): # Preserve ts_type to maintain circular/linear type info ts_type = current_ts.type_info if hasattr(current_ts, 'type_info') else None f_ts = TimeSeries( current_ts.data[~bad_frames_mask], discrete=current_ts.discrete, ts_type=ts_type, name=current_ts.name if hasattr(current_ts, 'name') else feat_id ) elif isinstance(current_ts, MultiTimeSeries): # Handle MultiTimeSeries by trimming each component filtered_components = [] for i in range(current_ts.n_dim): component_data = current_ts.data[i, ~bad_frames_mask] # Preserve component name and type_info component_name = None component_type = None if hasattr(current_ts, 'ts_list') and i < len(current_ts.ts_list): orig_component = current_ts.ts_list[i] component_name = orig_component.name if hasattr(orig_component, 'name') else f"{feat_id}_{i}" component_type = orig_component.type_info if hasattr(orig_component, 'type_info') else None else: component_name = f"{feat_id}_{i}" filtered_components.append( TimeSeries(component_data, discrete=current_ts.discrete, ts_type=component_type, name=component_name) ) # Preserve parent name if it exists parent_name = current_ts.name if hasattr(current_ts, 'name') else feat_id f_ts = MultiTimeSeries(filtered_components, name=parent_name) elif isinstance(current_ts, np.ndarray): # Handle raw arrays if current_ts.ndim == 1: f_ts = TimeSeries(current_ts[~bad_frames_mask], name=feat_id) else: # Multi-dimensional array f_ts = current_ts[:, ~bad_frames_mask] else: # Fallback for other types f_ts = TimeSeries(current_ts[~bad_frames_mask], name=feat_id) f_dynamic_features[feat_id] = f_ts self.filtered_flag = True self.bad_frames_mask = bad_frames_mask return f_calcium, f_spikes, f_dynamic_features def _checkpoint(self): """Validate experiment data integrity and consistency. Performs comprehensive checks to ensure the experiment data is properly formatted and meets minimum requirements for analysis. Raises ------ ValueError If any of the following conditions are met: - Signal is too short for shuffle mask creation - Number of cells exceeds number of time frames (likely transposed) - Feature shapes are inconsistent with experiment duration Notes ----- Checks include: - Minimum signal length based on decay time and shuffle requirements - Proper data orientation (neurons × timepoints) - Consistency of all feature dimensions with n_frames""" # Check minimal length for proper shuffle mask creation t_off_sec = getattr(self, "t_off_sec", DEFAULT_T_OFF) fps = getattr(self, "fps", DEFAULT_FPS) t_off_frames = t_off_sec * fps min_required_frames = ( int(t_off_frames * CA_SHIFT_N_TOFF * 2) + 10 ) # Need space for both ends + some valid positions if self.n_frames < min_required_frames: raise ValueError( f"Signal too short: {self.n_frames} frames. " f"Minimum {min_required_frames} frames required for proper shuffle mask creation " f"(based on t_off={t_off_sec}s, fps={fps}Hz)." ) if self.n_cells > self.n_frames: raise ValueError( f"Number of cells ({self.n_cells}) > number of time frames ({self.n_frames}). " f"Data appears to be transposed. Expected shape: (n_neurons, n_timepoints)." ) for dfeat in ["calcium", "spikes"]: if self.n_frames not in getattr(self, dfeat).shape: raise ValueError( f'"{dfeat}" feature has inappropriate shape: {getattr(self, dfeat).data.shape}' f"inconsistent with data length {self.n_frames}" ) for dfeat in self.dynamic_features.keys(): if isinstance(dfeat, str): if self.n_frames not in getattr(self, dfeat).data.shape: raise ValueError( f'"{dfeat}" feature has inappropriate shape: {getattr(self, dfeat).data.shape}' f"inconsistent with data length {self.n_frames}" ) else: # For tuple features (multifeatures), check the underlying data feat_data = self.dynamic_features[dfeat] if hasattr(feat_data, "data") and self.n_frames not in feat_data.data.shape: raise ValueError( f'"{dfeat}" feature has inappropriate shape: {feat_data.data.shape}' f"inconsistent with data length {self.n_frames}" ) def _populate_cell_feat_dict(self, content, fbunch=None, cbunch=None): """Create nested dictionary structure for cell-feature pairs. Builds a two-level dictionary where the outer level contains features and the inner level contains cells, with each entry initialized to the specified content value. Parameters ---------- content : any Default value to populate in each cell-feature entry. fbunch : None, str, or iterable of str, optional Feature(s) to include. If None, includes all features. cbunch : None, int, or iterable of int, optional Cell ID(s) to include. If None, includes all cells. Returns ------- dict Nested dictionary with structure: {feature_id: {cell_id: content}}""" cell_ids = self._process_cbunch(cbunch) feat_ids = self._process_fbunch(fbunch, allow_multifeatures=True) nested_dict = populate_nested_dict(content, feat_ids, cell_ids) return nested_dict def _process_cbunch(self, cbunch): """Convert cell specification to list of cell IDs. Parameters ---------- cbunch : None, int, or iterable of int Cell specification: - None: returns all cell IDs - int: returns single cell ID in a list - iterable: returns list of specified cell IDs Returns ------- list of int List of cell IDs to process.""" if isinstance(cbunch, int): cell_ids = [cbunch] elif cbunch is None: cell_ids = list(np.arange(self.n_cells)) else: cell_ids = list(cbunch) return cell_ids def _process_fbunch(self, fbunch, allow_multifeatures=False, mode="calcium"): """Convert feature specification to list of feature IDs. Parameters ---------- fbunch : None, str, iterable of str, or iterable of tuples Feature specification: - None: returns all feature IDs - str: returns single feature ID in a list - iterable of str: returns list of feature IDs - iterable of tuples: multi-feature combinations (if allowed) allow_multifeatures : bool, optional Whether to allow multi-feature tuples. Default is False. mode : {'calcium', 'spikes'}, optional Activity mode for filtering relevant features. Default is 'calcium'. Returns ------- list List of feature IDs or feature ID tuples to process. Raises ------ ValueError If multi-features are provided but not allowed.""" if isinstance(fbunch, str): feat_ids = [fbunch] elif fbunch is None: # default set of features if allow_multifeatures: try: # stats table contains up-to-date set of features, including multifeatures feat_ids = list(self.stats_tables[mode].keys()) except KeyError: # if stats is not available, take pre-defined full set of features feat_ids = list(self.dynamic_features.keys()) else: feat_ids = list(self.dynamic_features.keys()) else: feat_ids = [] # check for multifeatures for fname in fbunch: if isinstance(fname, str): if fname in self.dynamic_features: feat_ids.append(fname) else: if allow_multifeatures: feat_ids.append(tuple(sorted(list(fname)))) else: raise ValueError( 'Multifeature detected in "allow_multifeatures=False" mode' ) return feat_ids def _process_sbunch(self, sbunch, significance_mode=False): """Process statistics bunch specification into filtered list. Converts input formats for specifying statistics types into a standardized list, filtering out any invalid entries. Parameters ---------- sbunch : None, str, or iterable of str Statistics specification. If None, returns all valid stats. Invalid entries in iterables are silently filtered out. significance_mode : bool, optional If True, uses SIGNIFICANCE_VARS, else uses STATS_VARS. Default False. Returns ------- list of str Valid statistics variable names only. Notes ----- Single strings are returned as-is without validation. Iterables have invalid entries filtered out silently.""" if significance_mode: default_list = SIGNIFICANCE_VARS else: default_list = STATS_VARS if isinstance(sbunch, str): return [sbunch] elif sbunch is None: return default_list else: return [st for st in sbunch if st in default_list] def _add_single_feature_to_data_hashes(self, feat_id, mode="calcium"): """Add hash mapping for a single feature. Creates hash representations for all neuron-feature pairs for the specified feature and adds them to the data hashes structure. Parameters ---------- feat_id : str Feature identifier. Must exist in self.dynamic_features. mode : {'calcium', 'spikes'}, optional Neural activity type. Default 'calcium'. Side Effects ------------ Modifies self._data_hashes[mode][feat_id] in place. Notes ----- Only adds if feature not already in data hashes. No validation performed on feat_id existence.""" if feat_id not in self._data_hashes[mode]: self._data_hashes[mode][feat_id] = {} # Add hashes for all cells for cell_id in range(self.n_cells): pair_hash = self._build_pair_hash(cell_id, feat_id, mode=mode) self._data_hashes[mode][feat_id][cell_id] = pair_hash def _add_single_feature_to_stats(self, feat_id, mode="calcium"): """Add empty stats and significance tables for a single feature. Initializes the statistics and significance tracking structures for all neurons for the specified feature. Parameters ---------- feat_id : str Feature identifier to add tables for. mode : {'calcium', 'spikes'}, optional Neural activity type. Default 'calcium'. Side Effects ------------ Modifies self.stats_tables[mode][feat_id] and self.significance_tables[mode][feat_id] in place. Notes ----- Only adds if feature not already in stats tables. Creates deep copies of DEFAULT_STATS and DEFAULT_SIGNIFICANCE for each cell to prevent aliasing.""" if feat_id not in self.stats_tables[mode]: # Initialize stats for all cells self.stats_tables[mode][feat_id] = {} self.significance_tables[mode][feat_id] = {} for cell_id in range(self.n_cells): self.stats_tables[mode][feat_id][cell_id] = DEFAULT_STATS.copy() self.significance_tables[mode][feat_id][cell_id] = DEFAULT_SIGNIFICANCE.copy() def _add_multifeature_to_data_hashes(self, feat_id, mode="calcium"): """Add hash mapping for a multi-feature combination. Creates hash representations for the specified multi-feature combination across all neurons for joint mutual information calculations. .. deprecated:: The multifeature mechanism using tuples is marked for deprecation. Future versions will use a different approach for joint MI. Parameters ---------- feat_id : list or tuple of str Multiple feature names (at least 2). Must not be a string or single-element collection. mode : {'calcium', 'spikes'}, optional Neural activity type. Default 'calcium'. Side Effects ------------ Modifies self._data_hashes[mode] by adding sorted tuple key. Raises ------ ValueError If feat_id is a string or single-element collection. Notes ----- Multi-features are stored as sorted tuples. Existing entries ignored. The mode parameter is now properly passed to _build_pair_hash.""" if isinstance(feat_id, str): raise ValueError( "This method is for multifeature update only. Use _add_single_feature_to_data_hashes for single features." ) if len(feat_id) == 1: raise ValueError( f"Single-element list {feat_id} provided. " "Use _add_single_feature_to_data_hashes for single features or provide multiple features." ) ordered_fnames = tuple(sorted(list(feat_id))) if ordered_fnames not in self._data_hashes[mode]: all_hashes = [ self._build_pair_hash(cell_id, ordered_fnames, mode=mode) for cell_id in range(self.n_cells) ] new_dict = {ordered_fnames: dict(zip(range(self.n_cells), all_hashes))} self._data_hashes[mode].update(new_dict) def _add_multifeature_to_stats(self, feat_id, mode="calcium"): """Add empty stats and significance tables for a multi-feature combination. Initializes the statistics and significance tracking structures for all neurons for the specified multi-feature combination. .. deprecated:: The multifeature mechanism using tuples is marked for deprecation. Future versions will use a different approach for joint MI. Parameters ---------- feat_id : list or tuple of str Multiple feature names (at least 2). Must not be a string or single-element collection. mode : {'calcium', 'spikes'}, optional Neural activity type. Default 'calcium'. Side Effects ------------ - Modifies self.stats_tables[mode] by adding sorted tuple key - Modifies self.significance_tables[mode] by adding sorted tuple key - Prints to stdout if self.verbose=True and feature is new Raises ------ ValueError If feat_id is a string or single-element collection. Notes ----- Multi-features are normalized to sorted tuples. Only prints verbose message for new features using the sorted tuple representation.""" if isinstance(feat_id, str): raise ValueError( "This method is for multifeature update only. Use _add_single_feature_to_stats for single features." ) if len(feat_id) == 1: raise ValueError( f"Single-element list {feat_id} provided. " "Use _add_single_feature_to_stats for single features or provide multiple features." ) ordered_fnames = tuple(sorted(list(feat_id))) if ordered_fnames not in self.stats_tables[mode]: if self.verbose: print(f"Multifeature {ordered_fnames} is new, it will be added to stats table") self.stats_tables[mode][ordered_fnames] = { cell_id: DEFAULT_STATS.copy() for cell_id in range(self.n_cells) } self.significance_tables[mode][ordered_fnames] = { cell_id: DEFAULT_SIGNIFICANCE.copy() for cell_id in range(self.n_cells) } def _check_stats_relevance(self, cell_id, feat_id, mode="calcium"): """Check if stats exist and are current, adding new features if needed. Verifies if statistics for a neuron-feature pair exist and match current data hashes. Can add new features to tables with deprecation warnings. .. note:: This method has side effects - it can add features to stats_tables, significance_tables, and _data_hashes. This behavior will be removed after the tuple multifeature mechanism is fully deprecated. Parameters ---------- cell_id : int Neuron index. feat_id : str or tuple of str Single feature or tuple of features for joint MI. mode : {'calcium', 'spikes'}, optional Neural activity type. Default 'calcium'. Returns ------- bool True if stats exist/were added and hashes match, False if data changed. Side Effects ------------ May add features to stats_tables, significance_tables, and _data_hashes. Issues deprecation warnings for dynamic feature additions. Raises ------ ValueError If single feature not in self.dynamic_features.""" if not isinstance(feat_id, str): feat_id = tuple(sorted(list(feat_id))) if feat_id not in self.stats_tables[mode]: # DEPRECATED: Dynamic stats table updates are discouraged due to statistical validity concerns # Multiple comparison correction depends on the total number of hypotheses tested import warnings if isinstance(feat_id, tuple): warnings.warn( "DEPRECATED: Adding tuple features to stats table after initialization is discouraged. " "This can invalidate multiple comparison corrections. " "Consider using use_precomputed_stats=False for proper statistical analysis.", DeprecationWarning, stacklevel=2, ) # Still allow it for backward compatibility self._add_multifeature_to_stats(feat_id, mode=mode) if feat_id not in self._data_hashes[mode]: self._add_multifeature_to_data_hashes(feat_id, mode=mode) return True else: # Single feature case - this shouldn't happen if experiment is properly initialized if feat_id not in self.dynamic_features: raise ValueError( f"Feature {feat_id} is not present in dynamic_features. " "Check the feature name." ) warnings.warn( f"Feature {feat_id} was not included in initial stats computation. " "This suggests the experiment was not properly initialized with all features. " "Adding it now may invalidate multiple comparison corrections.", UserWarning, stacklevel=2, ) # Add new feature to stats table self._add_single_feature_to_stats(feat_id, mode=mode) # Also ensure it's in data hashes if feat_id not in self._data_hashes[mode]: self._add_single_feature_to_data_hashes(feat_id, mode=mode) return True pair_hash = self._data_hashes[mode][feat_id][cell_id] existing_hash = self.stats_tables[mode][feat_id][cell_id]["data_hash"] # if (stats does not exist yet) or (stats exists and data is the same): if (existing_hash is None) or (pair_hash == existing_hash): return True else: if self.verbose: print( f"Looks like the data for the pair (cell {cell_id}, feature {feat_id}) " "has been changed since the last calculation)" ) return False def _update_stats_and_significance(self, stats, mode, cell_id, feat_id, stage2_only): """ Updates stats table and linked significance table to erase irrelevant data properly. Parameters ---------- stats : dict Statistics dictionary to update the table with mode : str Mode of data processing (e.g., 'calcium') cell_id : int Cell identifier feat_id : str or tuple Feature identifier or tuple of feature identifiers stage2_only : bool If True, only update stage 2 significance data """ # update statistics self.stats_tables[mode][feat_id][cell_id].update(stats) if not stage2_only: # erase significance data completely since stats for stage 1 has been modified self.significance_tables[mode][feat_id][cell_id].update(DEFAULT_SIGNIFICANCE.copy()) else: # erase significance data for stage 2 since stats for stage 2 has been modified self.significance_tables[mode][feat_id][cell_id].update( {"stage2": None, "shuffles2": None} )
[docs] def update_neuron_feature_pair_stats( self, stats, cell_id, feat_id, mode="calcium", force_update=False, stage2_only=False, ): """ Updates calcium-feature pair statistics. feat_id should be a string or an iterable of strings (in case of joint MI calculation). This function allows multifeatures. Parameters ---------- stats : dict Statistics dictionary containing the updates cell_id : int Cell identifier feat_id : str or iterable of str Feature identifier(s). Can be a string or iterable of strings for joint MI calculation mode : str, optional Data processing mode. Default is "calcium" force_update : bool, optional If True, force update even if data hashes match. Default is False stage2_only : bool, optional If True, only update stage 2 statistics. Default is False """ if not isinstance(feat_id, str): self._add_multifeature_to_data_hashes(feat_id, mode=mode) self._add_multifeature_to_stats(feat_id, mode=mode) if self._check_stats_relevance(cell_id, feat_id, mode=mode): self._update_stats_and_significance( stats, mode, cell_id, feat_id, stage2_only=stage2_only ) else: if not force_update: if self.verbose: print('To forcefully update the stats, set "force_update=True"') else: self._update_stats_and_significance( stats, mode, cell_id, feat_id, stage2_only=stage2_only )
[docs] def update_neuron_feature_pair_significance(self, sig, cell_id, feat_id, mode="calcium"): """ Updates calcium-feature pair significance data. feat_id should be a string or an iterable of strings (in case of joint MI calculation). This function allows multifeatures. Parameters ---------- sig : dict Significance data to update cell_id : int Cell identifier feat_id : str or iterable of str Feature identifier(s). Can be a string or iterable of strings for joint MI calculation mode : str, optional Data processing mode. Default is "calcium" """ if not isinstance(feat_id, str): self._add_multifeature_to_data_hashes(feat_id, mode=mode) self._add_multifeature_to_stats(feat_id, mode=mode) if self._check_stats_relevance(cell_id, feat_id, mode=mode): self.significance_tables[mode][feat_id][cell_id].update(sig) else: raise ValueError( "Can not update significance table until the collision between actual data hashes and " "saved stats data hashes is resolved. Use update_neuron_feature_pair_stats" 'with "force_update=True" to forcefully rewrite statistics' )
[docs] def get_neuron_feature_pair_stats(self, cell_id, feat_id, mode="calcium"): """Get selectivity statistics for a neuron-feature pair. Retrieves pre-computed statistics measuring the relationship between neural activity and behavioral/experimental features. Supports both single features and multi-feature analysis. Parameters ---------- cell_id : int Neuron/cell identifier. feat_id : str or tuple of str Feature identifier(s). Can be a single feature name or tuple of feature names for joint analysis. mode : {'calcium', 'spikes'}, optional Type of neural activity. Default is 'calcium'. Returns ------- dict or None Dictionary containing various statistical measures of the neuron-feature relationship. Returns None if statistics have not been computed or if data has changed since computation. Notes ----- Statistics must be pre-computed using the selectivity analysis pipeline. The method checks data integrity using hashes to ensure statistics are up-to-date with the current data. .. note:: Despite the 'get' name, this method can trigger side effects via _check_stats_relevance which may add new features to tables. This behavior will be removed after tuple multifeature deprecation.""" stats = None if self._check_stats_relevance(cell_id, feat_id, mode=mode): stats = self.stats_tables[mode][feat_id][cell_id] else: if self.verbose: print("Consider recalculating stats") return stats
[docs] def get_neuron_feature_pair_significance(self, cell_id, feat_id, mode="calcium"): """Get statistical significance data for a neuron-feature pair. Retrieves significance testing results for the neuron-feature relationship, typically from shuffle-based permutation tests. Parameters ---------- cell_id : int Neuron/cell identifier. feat_id : str or tuple of str Feature identifier(s). Can be a single feature name or tuple of feature names for joint analysis. mode : {'calcium', 'spikes'}, optional Type of neural activity. Default is 'calcium'. Returns ------- dict or None Dictionary containing significance test results including p-values, shuffle distributions, and multiple comparison corrections. Returns None if significance has not been computed or if underlying statistics are outdated. Notes ----- Significance testing typically uses shuffle tests where temporal relationships are destroyed while preserving marginal distributions. Results include both single-stage and two-stage testing procedures. .. note:: Despite the 'get' name, this method can trigger side effects via _check_stats_relevance which may add new features to tables. This behavior will be removed after tuple multifeature deprecation.""" sig = None if self._check_stats_relevance(cell_id, feat_id, mode=mode): sig = self.significance_tables[mode][feat_id][cell_id] else: if self.verbose: print("Consider recalculating stats") return sig
[docs] def get_multicell_shuffled_calcium( self, cbunch=None, method="roll_based", return_array=True, **kwargs ): """ Get shuffled calcium data for multiple cells. Parameters ---------- cbunch : int, list, or None Cell indices. If None, all cells are used. method : {'roll_based', 'waveform_based', 'chunks_based'}, default='roll_based' Shuffling method to use return_array : bool, default=True If True, return numpy array. If False, return MultiTimeSeries object. **kwargs Additional parameters passed to the shuffling method Returns ------- np.ndarray or MultiTimeSeries If return_array=True: Shuffled calcium data with shape (n_cells, n_frames) If return_array=False: MultiTimeSeries object containing shuffled data""" # Validate method valid_methods = ["roll_based", "waveform_based", "chunks_based"] if method not in valid_methods: raise ValueError( f"Invalid shuffling method '{method}'. Must be one of: {valid_methods}" ) cell_list = self._process_cbunch(cbunch) # Validate cell indices if any(idx < 0 or idx >= self.n_cells for idx in cell_list): raise ValueError(f"Invalid cell indices. Must be between 0 and {self.n_cells-1}") if return_array: agg_sh_data = np.zeros((len(cell_list), self.n_frames)) for i, cell_idx in enumerate(cell_list): cell = self.neurons[cell_idx] sh_data = cell.get_shuffled_calcium(method=method, return_array=True, **kwargs) agg_sh_data[i, :] = sh_data return agg_sh_data else: # Return MultiTimeSeries object ts_list = [] for cell_idx in cell_list: cell = self.neurons[cell_idx] sh_ts = cell.get_shuffled_calcium(method=method, return_array=False, **kwargs) ts_list.append(sh_ts) # Create MultiTimeSeries from list of TimeSeries from ..information.info_base import MultiTimeSeries return MultiTimeSeries(ts_list, name="shuffled_calcium")
[docs] def get_multicell_shuffled_spikes( self, cbunch=None, method="isi_based", return_array=True, **kwargs ): """ Get shuffled spike data for multiple cells. Parameters ---------- cbunch : int, list, or None Cell indices. If None, all cells are used. method : {'isi_based'}, default='isi_based' Shuffling method. Currently only 'isi_based' is supported for spikes. return_array : bool, default=True If True, return numpy array. If False, return MultiTimeSeries object. **kwargs Additional parameters passed to the shuffling method Returns ------- np.ndarray or MultiTimeSeries If return_array=True: Shuffled spike data with shape (n_cells, n_frames) If return_array=False: MultiTimeSeries object containing shuffled spike data""" # Check if spikes data is meaningful (not all zeros) if not np.any(self.spikes.data): raise AttributeError("Unable to shuffle spikes without meaningful spikes data") # Validate method valid_methods = ["isi_based"] if method not in valid_methods: raise ValueError( f"Invalid spike shuffling method '{method}'. Must be one of: {valid_methods}" ) cell_list = self._process_cbunch(cbunch) # Validate cell indices if any(idx < 0 or idx >= self.n_cells for idx in cell_list): raise ValueError(f"Invalid cell indices. Must be between 0 and {self.n_cells-1}") if return_array: agg_sh_data = np.zeros((len(cell_list), self.n_frames)) for i, cell_idx in enumerate(cell_list): cell = self.neurons[cell_idx] sh_data = cell.get_shuffled_spikes(method=method, return_array=True, **kwargs) agg_sh_data[i, :] = sh_data return agg_sh_data else: # Return MultiTimeSeries object ts_list = [] for cell_idx in cell_list: cell = self.neurons[cell_idx] sh_ts = cell.get_shuffled_spikes(method=method, return_array=False, **kwargs) ts_list.append(sh_ts) # Create MultiTimeSeries from list of TimeSeries from ..information.info_base import MultiTimeSeries return MultiTimeSeries(ts_list, name="shuffled_spikes")
[docs] def get_stats_slice( self, table_to_scan=None, cbunch=None, fbunch=None, sbunch=None, significance_mode=False, mode="calcium", ): """ Returns slice of accumulated statistics data (or significance data if "significance_mode=True"). Parameters ---------- table_to_scan : dict, optional Specific table to scan. If None, uses default stats or significance table cbunch : int, list, or None, optional Cell identifiers to include. If None, includes all cells fbunch : str, list, or None, optional Feature identifiers to include. If None, includes all features sbunch : str, list, or None, optional Statistics keys to include. If None, includes all statistics significance_mode : bool, optional If True, returns significance data instead of statistics. Default is False mode : str, optional Data processing mode. Default is "calcium" Returns ------- dict Nested dictionary with structure table[feat_id][cell_id][stat_key] """ cell_ids = self._process_cbunch(cbunch) feat_ids = self._process_fbunch(fbunch, allow_multifeatures=True, mode=mode) slist = self._process_sbunch(sbunch, significance_mode=significance_mode) if table_to_scan is None: if significance_mode: full_table = self.significance_tables[mode] else: full_table = self.stats_tables[mode] else: full_table = table_to_scan out_table = self._populate_cell_feat_dict(dict(), fbunch=fbunch, cbunch=cbunch) for feat_id in feat_ids: for cell_id in cell_ids: out_table[feat_id][cell_id] = {s: full_table[feat_id][cell_id][s] for s in slist} return out_table
[docs] def get_significance_slice(self, cbunch=None, fbunch=None, sbunch=None, mode="calcium"): """Extract significance test results for selected cells and features. Convenience method that retrieves statistical significance data (p-values, test statistics, etc.) for specific cell-feature combinations. This is equivalent to calling get_stats_slice with significance_mode=True. Parameters ---------- cbunch : int, list of int, or None, optional Cell indices to include. None means all cells. fbunch : int, str, list, or None, optional Feature indices/names to include. None means all features. sbunch : str, list of str, or None, optional Significance measures to extract (e.g., 'pval', 'qval', 'statistic'). None means all available measures. mode : {'calcium', 'spikes'}, default='calcium' Which data type's significance tables to use. Returns ------- dict Nested dictionary with structure: {feature: {cell: {measure: value}}}. Contains only the requested significance test results. See Also -------- get_stats_slice : More general method for extracting any statistics Examples -------- >>> import numpy as np >>> from driada.experiment.exp_base import Experiment >>> from driada.information.info_base import TimeSeries >>> >>> # Create an example experiment with selectivity data >>> calcium_data = np.random.randn(20, 1000) >>> running_speed = TimeSeries(np.random.rand(1000)) >>> exp = Experiment( ... 'example', ... calcium_data, ... None, ... {}, ... {'fps': 30.0}, ... {'running_speed': running_speed}, ... verbose=False ... ) >>> >>> # Initialize stats tables >>> exp._set_selectivity_tables('calcium') >>> >>> # Get p-values for cells 0-5 and feature 'running_speed' >>> sig_data = exp.get_significance_slice( ... cbunch=[0, 1, 2, 3, 4, 5], ... fbunch=['running_speed'], ... sbunch=['pval'] ... ) >>> # Returns nested dict structure >>> sorted(sig_data.keys()) ['running_speed'] >>> sorted(sig_data['running_speed'].keys()) [0, 1, 2, 3, 4, 5] """ return self.get_stats_slice( cbunch=cbunch, fbunch=fbunch, sbunch=sbunch, significance_mode=True, mode=mode, )
[docs] def get_feature_entropy(self, feat_id, ds=1): """ Calculates entropy of a single dynamic feature or joint entropy of multiple features. Parameters ---------- feat_id : str or tuple - str: Name of a single feature (TimeSeries or MultiTimeSeries) - tuple: Names of exactly 2 features for joint entropy calculation ds : int Downsampling factor Returns ------- float Entropy value in bits (or nats for continuous variables) Notes ----- - Single features use their native get_entropy() method - Tuples calculate joint entropy for exactly 2 variables - Joint entropy of 3+ variables is not supported (use MultiTimeSeries instead) - Continuous variables may return negative entropy values""" if isinstance(feat_id, str): # Single feature - use its get_entropy method fts = self.dynamic_features[feat_id] return fts.get_entropy(ds=ds) elif isinstance(feat_id, (tuple, list)): # Joint entropy of multiple features if len(feat_id) != 2: raise ValueError( f"Joint entropy is only supported for exactly 2 variables, got {len(feat_id)}. " f"For {len(feat_id)} variables, create a MultiTimeSeries instead." ) # Get the two features feat1_name, feat2_name = feat_id feat1 = self.dynamic_features[feat1_name] feat2 = self.dynamic_features[feat2_name] # Check for continuous components has_continuous = False if isinstance(feat1, TimeSeries) and not feat1.discrete: has_continuous = True elif isinstance(feat1, MultiTimeSeries) and not feat1.discrete: has_continuous = True if isinstance(feat2, TimeSeries) and not feat2.discrete: has_continuous = True elif isinstance(feat2, MultiTimeSeries) and not feat2.discrete: has_continuous = True if has_continuous: warnings.warn( "One or both features contain continuous components. " "Joint differential entropy may be negative and is scale-dependent." ) # For joint entropy of 2 variables, use H(X,Y) = H(X) + H(Y) - MI(X,Y) h_x = feat1.get_entropy(ds=ds) h_y = feat2.get_entropy(ds=ds) mi_xy = get_1d_mi(feat1, feat2, ds=ds) return h_x + h_y - mi_xy else: raise TypeError(f"feat_id must be str or tuple of 2 feature names, got {type(feat_id)}")
def _reconstruct_spikes( self, calcium, method, fps, spike_kwargs=None, wavelet=None, rel_wvt_times=None, use_gpu=False, ): """ Reconstruct spikes from calcium signals using specified method. Parameters ---------- calcium : np.ndarray Calcium traces, shape (n_neurons, n_timepoints) method : str or callable Reconstruction method: 'wavelet' or a callable function fps : float Sampling rate in frames per second spike_kwargs : dict, optional Method-specific parameters wavelet : Wavelet, optional Pre-computed wavelet object for batch optimization rel_wvt_times : array-like, optional Pre-computed time resolutions for batch optimization use_gpu : bool Use GPU acceleration for wavelet transform. Default False. Returns ------- spikes : np.ndarray Reconstructed spike trains""" from .spike_reconstruction import reconstruct_spikes # Convert calcium to MultiTimeSeries if needed if isinstance(calcium, np.ndarray): # Create temporary MultiTimeSeries from numpy array from ..information.info_base import TimeSeries, MultiTimeSeries # Calcium data is always continuous, so explicitly set discrete=False ts_list = [TimeSeries(calcium[i, :], discrete=False, name=f"calcium_{i}") for i in range(calcium.shape[0])] calcium_mts = MultiTimeSeries(ts_list, allow_zero_columns=True, name="calcium") else: calcium_mts = calcium # Call the unified reconstruction function with optimization support spikes_mts, metadata = reconstruct_spikes( calcium_mts, method=method, fps=fps, params=spike_kwargs, wavelet=wavelet, rel_wvt_times=rel_wvt_times, use_gpu=use_gpu, ) # Store metadata self._reconstruction_metadata = metadata # Return numpy array for backward compatibility return spikes_mts.data
[docs] def reconstruct_all_neurons( self, method="wavelet", n_iter=3, optimize_kinetics=True, hybrid_kinetics=True, wavelet=None, rel_wvt_times=None, show_progress=True, use_gpu=False, n_jobs=None, enable_parallelization=None, **kwargs, ): """Batch reconstruct spikes for all neurons with wavelet optimization. Populates neuron.asp, neuron.sp, neuron.events for each neuron. Parameters ---------- method : str Reconstruction method ('wavelet' or 'threshold'). Default 'wavelet'. n_iter : int Number of iterations for iterative detection. Default 3. optimize_kinetics : bool Optimize calcium kinetics per neuron. Default True. hybrid_kinetics : bool Optimize kinetics BEFORE reconstruction (hybrid mode). Default True. wavelet : Wavelet, optional Pre-computed wavelet object. If None, creates once and reuses. rel_wvt_times : array-like, optional Pre-computed time resolutions. If None, computes once and reuses. show_progress : bool Show progress bar. Default True. use_gpu : bool Use GPU acceleration for wavelet transform. Default False. Requires PyTorch and CuPy. Ridge extraction remains CPU-only. n_jobs : int, optional Number of parallel jobs. Default uses experiment's n_jobs setting. Use -1 for all available cores. enable_parallelization : bool, optional Enable parallel processing. Default uses experiment's setting. **kwargs Additional parameters passed to neuron.reconstruct_spikes() Returns ------- None Updates neuron objects in-place and syncs to exp.spikes """ import tqdm fps = self.fps if self.fps is not None else 20.0 # Use experiment settings if not explicitly provided if n_jobs is None: n_jobs = self._n_jobs if enable_parallelization is None: enable_parallelization = self._enable_parallelization # Pre-compute wavelet ONCE if needed if method == "wavelet": from ssqueezepy.wavelets import Wavelet, time_resolution from .wavelet_event_detection import get_adaptive_wavelet_scales if method == "wavelet" and wavelet is None: wavelet = Wavelet(("gmw", {"gamma": 3, "beta": 2, "centered_scale": True}), N=8196) manual_scales = get_adaptive_wavelet_scales(fps) rel_wvt_times = [ time_resolution(wavelet, scale=sc, nondim=False, min_decay=200) for sc in manual_scales ] # Helper functions for parallel execution def _optimize_kinetics(neuron, fps): """Optimize calcium kinetics for a single neuron. Parameters ---------- neuron : Neuron Neuron object whose kinetics will be optimized in-place. fps : float Sampling rate in frames per second. Returns ------- Neuron The same neuron object, after kinetics optimization. """ neuron.optimize_kinetics(method="direct", fps=fps) return neuron def _reconstruct_spikes(neuron, method, n_iter, fps, wavelet, rel_wvt_times, use_gpu, kwargs): """Reconstruct spikes for a single neuron. Parameters ---------- neuron : Neuron Neuron object whose spikes will be reconstructed in-place. method : str Reconstruction method ('wavelet' or 'threshold'). n_iter : int Number of iterations for iterative detection. fps : float Sampling rate in frames per second. wavelet : Wavelet or None Pre-computed wavelet object for wavelet-based reconstruction. rel_wvt_times : list or None Pre-computed time resolutions for each wavelet scale. use_gpu : bool Whether to use GPU acceleration for wavelet transform. kwargs : dict Additional keyword arguments passed to ``neuron.reconstruct_spikes()``. Returns ------- Neuron The same neuron object, after spike reconstruction. """ neuron.reconstruct_spikes( method=method, iterative=True, n_iter=n_iter, fps=fps, wavelet=wavelet, rel_wvt_times=rel_wvt_times, use_gpu=use_gpu, **kwargs, ) return neuron # Hybrid kinetics: optimize BEFORE reconstruction _should_parallelize = ( enable_parallelization and len(self.neurons) >= MIN_NEURONS_FOR_PARALLEL ) if hybrid_kinetics and optimize_kinetics: if _should_parallelize: from ..utils.parallel import parallel_executor, delayed with parallel_executor(n_jobs=n_jobs) as parallel: self.neurons = parallel( delayed(_optimize_kinetics)(neuron, fps) for neuron in tqdm.tqdm(self.neurons, disable=not show_progress, desc="Kinetics") ) else: for neuron in tqdm.tqdm(self.neurons, disable=not show_progress, desc="Kinetics"): _optimize_kinetics(neuron, fps) # Reconstruct all neurons with shared wavelet if _should_parallelize: from ..utils.parallel import parallel_executor, delayed with parallel_executor(n_jobs=n_jobs) as parallel: self.neurons = parallel( delayed(_reconstruct_spikes)( neuron, method, n_iter, fps, wavelet, rel_wvt_times, use_gpu, kwargs ) for neuron in tqdm.tqdm(self.neurons, disable=not show_progress, desc="Reconstruction") ) else: for neuron in tqdm.tqdm(self.neurons, disable=not show_progress, desc="Reconstruction"): _reconstruct_spikes( neuron, method, n_iter, fps, wavelet, rel_wvt_times, use_gpu, kwargs ) # Post-reconstruction kinetics if not hybrid if not hybrid_kinetics and optimize_kinetics: if _should_parallelize: from ..utils.parallel import parallel_executor, delayed with parallel_executor(n_jobs=n_jobs) as parallel: self.neurons = parallel( delayed(_optimize_kinetics)(neuron, fps) for neuron in tqdm.tqdm(self.neurons, disable=not show_progress, desc="Kinetics") ) else: for neuron in tqdm.tqdm(self.neurons, disable=not show_progress, desc="Kinetics"): _optimize_kinetics(neuron, fps) # Sync neuron.sp → exp.spikes self._update_spike_data_from_neurons()
def _update_spike_data_from_neurons(self): """Sync neuron.sp arrays back to exp.spikes.""" from ..information.info_base import TimeSeries, MultiTimeSeries import numpy as np spike_ts_list = [] for i, neuron in enumerate(self.neurons): if neuron.sp is not None: spike_ts_list.append(neuron.sp) else: spike_ts_list.append(TimeSeries(np.zeros(self.n_frames), discrete=True, name=f"neuron_{i}_sp_zero")) self.spikes = MultiTimeSeries(spike_ts_list, allow_zero_columns=True, name="spikes")
[docs] def get_significant_neurons( self, min_nspec=1, cbunch=None, fbunch=None, mode="calcium", override_intense_significance=False, pval_thr=0.05, multicomp_correction=None, significance_update=False, ): """ Returns a dict with neuron ids as keys and their significantly correlated features as values Only neurons with "min_nspec" or more significantly correlated features will be returned Parameters ---------- min_nspec : int Minimum number of significantly correlated features required cbunch : int, list or None Cell indices to analyze. By default (None), all neurons will be checked fbunch : str, list or None Feature names to check. By default (None), all features will be checked mode : str Data type: 'calcium' or 'spikes' override_intense_significance : bool, optional If True, recompute significance using pval_thr and multicomp_correction instead of using pre-computed INTENSE significance. Default is False. pval_thr : float, optional P-value threshold for significance testing. Default is 0.05. Only used if override_intense_significance=True. multicomp_correction : str or None, optional Multiple comparison correction method. Default is None (no correction). Options: None, 'bonferroni', 'holm', 'fdr_bh' Only used if override_intense_significance=True. significance_update : bool, optional If True, update the significance tables with new thresholds. If False (default), only use new thresholds for current query. Only used if override_intense_significance=True. Returns ------- dict Dictionary with neuron IDs as keys and lists of significant features as values""" cell_ids = self._process_cbunch(cbunch) feat_ids = self._process_fbunch(fbunch, allow_multifeatures=True, mode=mode) # Check relevance only for requested cells and features relevance = [ self._check_stats_relevance(cell_id, feat_id, mode=mode) for cell_id in cell_ids for feat_id in feat_ids ] if not np.all(np.array(relevance)): raise ValueError("Stats relevance error") cell_feat_dict = {cell_id: [] for cell_id in cell_ids} if override_intense_significance: # Collect all p-values for multiple comparison correction all_pvals = [] cell_feat_pvals = {} for cell_id in cell_ids: cell_feat_pvals[cell_id] = {} for feat_id in feat_ids: pval = self.stats_tables[mode][feat_id][cell_id].get("pval", 1.0) cell_feat_pvals[cell_id][feat_id] = pval all_pvals.append(pval) # Calculate corrected threshold if multicomp_correction is None: corrected_threshold = pval_thr elif multicomp_correction == "bonferroni": corrected_threshold = get_multicomp_correction_thr( pval_thr, mode="bonferroni", nhyp=len(all_pvals) ) elif multicomp_correction in ["holm", "fdr_bh"]: corrected_threshold = get_multicomp_correction_thr( pval_thr, mode=multicomp_correction, all_pvals=all_pvals ) else: raise ValueError( f"Unknown multicomp_correction method: {multicomp_correction}. " "Options: None, 'bonferroni', 'holm', 'fdr_bh'" ) # Determine significance based on new threshold for cell_id in cell_ids: for feat_id in feat_ids: pval = cell_feat_pvals[cell_id][feat_id] # Check if significant according to new threshold is_significant = pval < corrected_threshold if is_significant: cell_feat_dict[cell_id].append(feat_id) # Update significance tables if requested if significance_update: self.significance_tables[mode][feat_id][cell_id]["stage2"] = is_significant self.significance_tables[mode][feat_id][cell_id]["pval_thr"] = pval_thr self.significance_tables[mode][feat_id][cell_id][ "multicomp_correction" ] = multicomp_correction self.significance_tables[mode][feat_id][cell_id][ "corrected_pval_thr" ] = corrected_threshold else: # Use pre-computed INTENSE significance for cell_id in cell_ids: for feat_id in feat_ids: if self.significance_tables[mode][feat_id][cell_id]["stage2"]: cell_feat_dict[cell_id].append(feat_id) # filter out cells without enough specializations final_cell_feat_dict = { cell_id: cell_feat_dict[cell_id] for cell_id in cell_ids if len(cell_feat_dict[cell_id]) >= min_nspec } return final_cell_feat_dict
[docs] def store_embedding(self, embedding, method_name, data_type="calcium", metadata=None): """ Store dimensionality reduction embedding in the experiment. This method stores a computed embedding in the experiment's internal embeddings dictionary. Previous embeddings with the same method_name and data_type will be overwritten without warning. Parameters ---------- embedding : np.ndarray The embedding array, shape (n_timepoints, n_components). The number of timepoints must match self.n_frames or self.n_frames//ds if downsampling was used. method_name : str Name of the DR method (e.g., 'pca', 'umap', 'isomap'). This serves as the key for storing and retrieving the embedding. data_type : str, optional Type of data used ('calcium' or 'spikes'). Default is 'calcium'. metadata : dict, optional Additional metadata about the embedding. Common keys include: - 'ds': Downsampling factor used - 'n_components': Number of components - 'neuron_indices': Indices of neurons used - 'method_params': Parameters specific to the DR method Raises ------ ValueError If data_type is not 'calcium' or 'spikes'. ValueError If embedding timepoints don't match expected frames. The expected number is self.n_frames//ds where ds is extracted from metadata (default 1). Notes ----- The embedding is stored in self.embeddings[data_type][method_name] as a dictionary containing: - 'data': The embedding array - 'metadata': The provided metadata dict (or empty dict) - 'timestamp': Current time when stored (np.datetime64) - 'shape': Shape tuple of the embedding array Previous embeddings with the same method_name are silently overwritten. Examples -------- >>> import numpy as np >>> from driada.experiment.exp_base import Experiment >>> >>> # Create a simple experiment >>> calcium_data = np.random.randn(10, 1000) >>> exp = Experiment('test', calcium_data, None, {}, ... {'fps': 30.0}, {}, verbose=False) >>> >>> # Store a PCA embedding >>> embedding = np.random.randn(1000, 3) # 1000 timepoints, 3 components >>> exp.store_embedding(embedding, 'pca', metadata={'n_components': 3}) >>> >>> # Verify storage >>> 'pca' in exp.embeddings['calcium'] True >>> >>> # Store a downsampled UMAP embedding >>> # If experiment has 1000 frames and ds=5, embedding should have 200 rows >>> downsampled_embedding = np.random.randn(200, 2) >>> exp.store_embedding(downsampled_embedding, 'umap', ... metadata={'ds': 5, 'n_neighbors': 30}) See Also -------- get_embedding : Retrieve stored embeddings create_embedding : Create and store embeddings in one step""" if data_type not in ["calcium", "spikes"]: raise ValueError("data_type must be 'calcium' or 'spikes'") # Check if embedding matches expected timepoints (accounting for downsampling) ds = metadata.get("ds", 1) if metadata else 1 # data[:, ::ds] gives ceil(n_frames / ds) timepoints expected_frames = -(-self.n_frames // ds) # ceiling division if embedding.shape[0] != expected_frames: raise ValueError( f"Embedding timepoints ({embedding.shape[0]}) must match expected frames " f"({expected_frames} = ceil({self.n_frames} / ds={ds}))" ) self.embeddings[data_type][method_name] = { "data": embedding, "metadata": metadata or {}, "timestamp": np.datetime64("now"), "shape": embedding.shape, }
[docs] def create_embedding( self, method: str, n_components: int = 2, data_type: str = "calcium", neuron_selection: Optional[Union[str, List[int]]] = None, **dr_kwargs, ) -> np.ndarray: """ Create dimensionality reduction embedding and store it. Notes ----- This method modifies the experiment's state by storing the computed embedding. Previous embeddings with the same method name will be overwritten. The method uses MultiTimeSeries internally for data handling and applies the dimensionality reduction through the MVData interface. Parameters ---------- method : str DR method name ('pca', 'umap', 'isomap', etc.). n_components : int, optional Number of embedding dimensions. Default is 2. data_type : str, optional Type of data to use ('calcium' or 'spikes'). Default is 'calcium'. neuron_selection : str, list or None, optional How to select neurons: - None or 'all': Use all neurons - 'significant': Use only significantly selective neurons - List of integers: Use specific neuron indices **dr_kwargs Additional arguments for the DR method (e.g., n_neighbors, min_dist). Returns ------- embedding : np.ndarray The embedding array, shape (n_timepoints, n_components). Raises ------ ValueError If n_components is not positive, data_type is invalid, downsampling factor 'ds' is not an integer, neuron indices are out of bounds, significant neurons requested without selectivity analysis, or embedding method drops timepoints. AttributeError If spike data requested but not available. Examples -------- >>> import numpy as np >>> from driada.experiment.exp_base import Experiment >>> from driada.information.info_base import TimeSeries >>> >>> # Create experiment with some data >>> calcium_data = np.random.randn(25, 1000) >>> speed = TimeSeries(np.random.rand(1000)) >>> exp = Experiment('test', calcium_data, None, {}, ... {'fps': 30.0}, {'speed': speed}, verbose=False) >>> >>> # Create PCA embedding using all neurons >>> embedding = exp.create_embedding('pca', n_components=10) Calculating PCA embedding... >>> embedding.shape (1000, 10) >>> >>> # Create downsampled PCA with specific neurons >>> embedding = exp.create_embedding('pca', n_components=2, ... neuron_selection=[0, 1, 2, 3, 4], ds=10) Calculating PCA embedding... See Also -------- store_embedding : Store computed embeddings get_embedding : Retrieve stored embeddings get_significant_neurons : Get neurons with significant selectivity """ from ..information.info_base import MultiTimeSeries from ..utils.data import check_positive # Validate inputs check_positive(n_components=n_components) if data_type not in ["calcium", "spikes"]: raise ValueError("data_type must be 'calcium' or 'spikes'") # Select neurons if neuron_selection is None or neuron_selection == "all": neuron_indices = np.arange(self.n_cells) elif neuron_selection == "significant": has_selectivity = ( hasattr(self, "stats_tables") and self.stats_tables is not None and data_type in self.stats_tables and len(self.stats_tables[data_type]) > 0 ) if not has_selectivity: raise ValueError("Cannot select significant neurons without selectivity analysis") sig_neurons = self.get_significant_neurons() neuron_indices = np.array(list(sig_neurons.keys())) if len(neuron_indices) == 0: logging.warning("No significant neurons found, using all neurons") neuron_indices = np.arange(self.n_cells) else: neuron_indices = np.array(neuron_selection) # Validate neuron indices are within bounds if len(neuron_indices) > 0: if np.any(neuron_indices < 0) or np.any(neuron_indices >= self.n_cells): raise ValueError(f"Neuron indices must be in range [0, {self.n_cells-1}]") # Get neural data - calcium and spikes are already MultiTimeSeries if data_type == "calcium": multi_ts = self.calcium else: if not hasattr(self, "spikes") or self.spikes is None: raise AttributeError("Experiment has no spike data") multi_ts = self.spikes # Create subset MultiTimeSeries with selected neurons if len(neuron_indices) != self.n_cells: subset_data = multi_ts.data[neuron_indices, :] multi_ts = MultiTimeSeries( subset_data, discrete=multi_ts.discrete, allow_zero_columns=(data_type == "spikes"), name=multi_ts.name if hasattr(multi_ts, 'name') else data_type ) # Apply downsampling if requested ds = dr_kwargs.pop("ds", 1) # Remove 'ds' from dr_kwargs if ds > 1: check_positive(ds=ds) if not isinstance(ds, int): raise ValueError("Downsampling factor 'ds' must be an integer") # Create downsampled MultiTimeSeries downsampled_data = multi_ts.data[:, ::ds] multi_ts = MultiTimeSeries( downsampled_data, discrete=multi_ts.discrete, allow_zero_columns=(data_type == "spikes"), name=multi_ts.name if hasattr(multi_ts, 'name') else data_type ) logging.info(f"Downsampling data by factor {ds}: {multi_ts.data.shape[1]} timepoints") # Prepare parameters for dimensionality reduction params = {"dim": n_components} params.update(dr_kwargs) # Add all additional parameters # Get embedding using MultiTimeSeries/MVData method embedding_obj = multi_ts.get_embedding(method=method, **params) embedding = embedding_obj.coords.T # Transpose to (n_timepoints, n_components) # Check if embedding has all timepoints (accounting for downsampling) expected_frames = -(-self.n_frames // ds) # ceiling division if embedding.shape[0] < expected_frames: n_missing = expected_frames - embedding.shape[0] raise ValueError( f"{method} embedding dropped {n_missing} timepoints due to graph disconnection. " f"This is not supported for INTENSE analysis. Try increasing n_neighbors or using a different method." ) # Store metadata metadata = { "method": method, "n_components": n_components, "neuron_selection": neuron_selection, "neuron_indices": neuron_indices.tolist(), "n_neurons": len(neuron_indices), "dr_params": dr_kwargs, "data_type": data_type, "ds": ds, # Store downsampling factor } # Store in experiment self.store_embedding(embedding, method, data_type, metadata) logging.info( f"Created {method} embedding with {n_components} components " f"using {len(neuron_indices)} neurons" ) return embedding
[docs] def get_embedding(self, method_name, data_type="calcium"): """ Retrieve stored embedding. This method retrieves a previously stored dimensionality reduction embedding from the experiment's embeddings dictionary. The returned dictionary contains the embedding data along with metadata and timestamp. Parameters ---------- method_name : str Name of the DR method to retrieve (e.g., 'pca', 'umap'). data_type : str, optional Type of data used ('calcium' or 'spikes'). Default is 'calcium'. Returns ------- dict Dictionary containing: - 'data': The embedding array (n_timepoints, n_components) - 'metadata': Dict with embedding parameters and settings - 'timestamp': np.datetime64 when the embedding was stored - 'shape': Tuple with shape of the embedding array Raises ------ ValueError If data_type is not 'calcium' or 'spikes'. KeyError If no embedding found for the specified method and data type. Notes ----- To see available embeddings, check exp.embeddings[data_type].keys(). The returned dictionary is a reference to the stored data, so modifications will affect the stored embedding. Examples -------- >>> import numpy as np >>> from driada.experiment.exp_base import Experiment >>> >>> # Create experiment and store an embedding >>> calcium_data = np.random.randn(10, 1000) >>> exp = Experiment('test', calcium_data, None, {}, ... {'fps': 30.0}, {}, verbose=False) >>> >>> # Store an embedding first >>> embedding = np.random.randn(1000, 3) >>> exp.store_embedding(embedding, 'pca', metadata={'n_components': 3}) >>> >>> # Retrieve the stored PCA embedding >>> embedding_dict = exp.get_embedding('pca') >>> embedding_data = embedding_dict['data'] >>> print(f"Embedding shape: {embedding_dict['shape']}") Embedding shape: (1000, 3) >>> >>> # Check available embeddings before retrieval >>> available = list(exp.embeddings['calcium'].keys()) >>> print(f"Available embeddings: {available}") Available embeddings: ['pca'] See Also -------- store_embedding : Store embeddings in the experiment create_embedding : Create and store embeddings in one step""" if data_type not in ["calcium", "spikes"]: raise ValueError("data_type must be 'calcium' or 'spikes'") if method_name not in self.embeddings[data_type]: raise KeyError( f"No embedding found for method '{method_name}' with data_type '{data_type}'" ) return self.embeddings[data_type][method_name]
[docs] def compute_rdm( self, items, data_type="calcium", metric="correlation", average_method="mean", use_cache=True, ): """ Compute RDM with caching support. Parameters ---------- items : str Name of dynamic feature to use as condition labels data_type : str, default 'calcium' Type of data to use ('calcium' or 'spikes') metric : str, default 'correlation' Distance metric for RDM computation average_method : str, default 'mean' How to average within conditions ('mean' or 'median') use_cache : bool, default True Whether to use cached results Returns ------- rdm : np.ndarray Representational dissimilarity matrix labels : np.ndarray The unique labels/conditions""" # Generate cache key cache_key = (items, data_type, metric, average_method) # Check cache if use_cache and cache_key in self._rdm_cache: return self._rdm_cache[cache_key] # Import here to avoid circular dependency from ..rsa.integration import compute_experiment_rdm # Compute RDM result = compute_experiment_rdm( self, items, data_type=data_type, metric=metric, average_method=average_method, ) # Cache result if use_cache: self._rdm_cache[cache_key] = result return result
[docs] def clear_rdm_cache(self): """Clear the representational dissimilarity matrix (RDM) cache. Removes all cached RDM computations to free memory or force recalculation with updated data. This is necessary after modifying the underlying neural data or when memory usage is a concern. Notes ----- The RDM cache stores previously computed dissimilarity matrices to avoid expensive recomputation. Clear the cache when: - Neural data has been modified or reprocessed - Embeddings have been updated - Memory usage needs to be reduced - You want to force fresh computation with different parameters After clearing, subsequent calls to compute_rdm() will recalculate the RDM from scratch, which may be computationally expensive for large datasets. See Also -------- compute_rdm : Method that uses and populates the RDM cache Examples -------- >>> import numpy as np >>> from driada.experiment.exp_base import Experiment >>> from driada.information.info_base import TimeSeries >>> >>> # Create experiment with a categorical feature >>> calcium_data = np.random.randn(10, 1000) >>> conditions = TimeSeries(np.repeat([0, 1, 2], [333, 333, 334]), discrete=True) >>> exp = Experiment('test', calcium_data, None, {}, ... {'fps': 30.0}, {'conditions': conditions}, verbose=False) >>> >>> # Compute RDM (will be cached) >>> rdm, labels = exp.compute_rdm('conditions') >>> >>> # Update embedding and clear cache >>> new_embedding = np.random.randn(1000, 3) >>> exp.store_embedding(new_embedding, 'pca') >>> exp.clear_rdm_cache() >>> >>> # Verify cache is empty >>> len(exp._rdm_cache) 0 """ self._rdm_cache = {}