Source code for driada.intense.pipelines

import time

import numpy as np

from .stats import stats_not_empty, DEFAULT_METRIC_DISTR_TYPE
from .intense_base import compute_me_stats, IntenseResults
from ..information.info_base import TimeSeries, MultiTimeSeries, calc_signal_ratio
from .disentanglement import disentangle_all_selectivities, DEFAULT_MULTIFEATURE_MAP
from ..experiment.exp_base import DEFAULT_STATS


def substitute_circular_with_2d(feat_ids, exp, verbose=False):
    """Substitute circular features with their _2d (cos, sin) counterparts.

    For features detected as circular that have a corresponding `{name}_2d`
    entry in ``exp.dynamic_features``, replaces the feature ID with the _2d
    version. A feature is treated as circular only when its entry in
    ``exp.dynamic_features`` is a ``TimeSeries`` whose ``type_info`` is truthy
    and reports ``is_circular``.

    Parameters
    ----------
    feat_ids : iterable
        Feature IDs (strings or tuples for multi-features). Any iterable is
        accepted, including one-shot iterators; it is materialised once
        because the IDs are traversed twice.
    exp : Experiment
        Experiment object containing dynamic_features.
    verbose : bool, default=False
        If True, print substitution information.

    Returns
    -------
    tuple
        (new_feat_ids, substitutions) where substitutions is a list of
        (original, substituted) tuples. If the `_2d` name is already present
        in ``feat_ids``, the original is dropped without adding a second
        copy, but the pair is still recorded in ``substitutions``.

    Examples
    --------
    >>> # Assuming exp has circular feature 'headdirection' with _2d version
    >>> feat_ids = ['headdirection', 'speed']  # doctest: +SKIP
    >>> new_ids, subs = substitute_circular_with_2d(feat_ids, exp)  # doctest: +SKIP
    >>> new_ids  # doctest: +SKIP
    ['headdirection_2d', 'speed']
    """
    # Materialise once: the IDs are walked twice (membership set + main loop),
    # so a generator/iterator would otherwise be exhausted before the loop.
    feat_ids = list(feat_ids)

    substituted = []
    new_feat_ids = []
    feat_id_set = set(feat_ids)

    for feat_id in feat_ids:
        # Only plain string IDs that are not already a _2d name are candidates;
        # tuples (multi-features) and existing _2d names pass through untouched.
        if isinstance(feat_id, str) and not feat_id.endswith("_2d"):
            name_2d = f"{feat_id}_2d"
            # Check if _2d version exists in experiment
            if name_2d in exp.dynamic_features:
                # Verify original is circular
                orig_ts = exp.dynamic_features.get(feat_id)
                if (
                    isinstance(orig_ts, TimeSeries)
                    and orig_ts.type_info
                    and orig_ts.type_info.is_circular
                ):
                    # Drop the original. Add the _2d name only if the caller
                    # did not already list it, so it is not duplicated.
                    if name_2d not in feat_id_set:
                        new_feat_ids.append(name_2d)
                    substituted.append((feat_id, name_2d))
                    continue
        new_feat_ids.append(feat_id)

    if verbose and substituted:
        print("Circular features substituted with _2d versions:")
        for orig, sub in substituted:
            print(f" '{orig}' -> '{sub}'")

    return new_feat_ids, substituted
def compute_cell_feat_significance(
    exp,
    cell_bunch=None,
    feat_bunch=None,
    data_type="calcium",
    metric="mi",
    mi_estimator="gcmi",
    mi_estimator_kwargs=None,
    mode="two_stage",
    n_shuffles_stage1=100,
    n_shuffles_stage2=10000,
    metric_distr_type=DEFAULT_METRIC_DISTR_TYPE,
    noise_ampl=1e-3,
    ds=1,
    use_precomputed_stats=True,
    save_computed_stats=True,
    force_update=False,
    topk1=1,
    topk2=5,
    multicomp_correction="holm",
    pval_thr=0.01,
    find_optimal_delays=True,
    skip_delays=[],
    shift_window=2,
    verbose=True,
    enable_parallelization=True,
    n_jobs=-1,
    seed=42,
    with_disentanglement=False,
    feat_feat_pval_thr=0.01,
    multifeature_map=None,
    duplicate_behavior="ignore",
    engine="auto",
    store_random_shifts=False,
    profile=False,
    pre_filter_func=None,
    post_filter_func=None,
    filter_kwargs=None,
    remove_anti_selective=True,
    use_circular_2d=True,
) -> tuple:
    """
    Calculates significant neuron-feature pairs

    Parameters
    ----------
    exp : Experiment
        Experiment object to read and write data from
    cell_bunch : int, iterable or None, optional
        Neuron indices. By default, (cell_bunch=None), all neurons will be taken
    feat_bunch : str, iterable or None, optional
        Feature names. By default, (feat_bunch=None), all single features will be taken
    data_type : str, optional
        Data type used for INTENSE computations. Can be 'calcium' or 'spikes'.
        Default is 'calcium'
    metric : str, optional
        Similarity metric between TimeSeries. Default is 'mi'
    mi_estimator : str, optional
        Mutual information estimator to use when metric='mi'.
        Options: 'gcmi' or 'ksg'. Default is 'gcmi'
    mi_estimator_kwargs : dict, optional
        Additional keyword arguments passed to the MI estimator function.
    mode : str, optional
        Computation mode. 3 modes are available:

        - ``'stage1'``: perform preliminary scanning with "n_shuffles_stage1" shuffles only.
          Rejects strictly non-significant neuron-feature pairs, does not give definite
          results about significance of the others.
        - ``'stage2'``: skip stage 1 and perform full-scale scanning ("n_shuffles_stage2"
          shuffles) of all neuron-feature pairs. Gives definite results, but can be very
          time-consuming. Also reduces statistical power of multiple comparison tests,
          since the number of hypotheses is very high.
        - ``'two_stage'``: prune non-significant pairs during stage 1 and perform thorough
          testing for the rest during stage 2. Recommended mode.

        Any other value raises ``ValueError``. Default is 'two_stage'
    n_shuffles_stage1 : int, optional
        Number of shuffles for first stage. Default is 100
    n_shuffles_stage2 : int, optional
        Number of shuffles for second stage. Default is 10000
    metric_distr_type : str, optional
        Distribution type for shuffled metric null distribution. Options:

        - 'gamma_zi' (default): Zero-inflated gamma distribution. Explicitly models the
          probability mass at zero that commonly occurs in MI null distributions.
          Provides superior goodness-of-fit and accurate parameter estimation without
          requiring artificial noise. Recommended for all analyses.
        - 'gamma': Standard gamma distribution with small noise added (noise_ampl) to
          handle zeros. Provided for backward compatibility. Less statistically
          principled than 'gamma_zi'.
        - Other scipy.stats distributions: 'lognorm', 'norm', etc. are supported but
          not recommended for MI distributions.

        **Recommendation**: Use 'gamma_zi' (default) for new analyses. It achieves
        equivalent detection performance while providing statistically correct
        goodness-of-fit and accurate parameter recovery.
        Default: 'gamma_zi'
    noise_ampl : float, optional
        Small noise amplitude added to MI values for numerical stability (only used
        with metric_distr_type='gamma'). When using 'gamma_zi', this parameter is
        automatically set to 0 since zero-inflated gamma handles zeros explicitly
        without requiring artificial noise. Default: 1e-3
    ds : int, optional
        Downsampling constant. Every "ds" point will be taken from the data time series.
        Reduces the computational load, but needs caution since with large "ds" some
        important information may be lost. Experiment class performs an internal check
        for this effect. Default is 1
    use_precomputed_stats : bool, optional
        Whether to use stats saved in Experiment instance. Stats are accumulated
        separately for stage1 and stage2.
        Notes on stats data rewriting (if save_computed_stats=True):
        If you want to recalculate stage1 results only, use "use_precomputed_stats=False"
        and "mode='stage1'". Stage 2 stats will be erased since they will become irrelevant.
        If you want to recalculate stage2 results only, use "use_precomputed_stats=True"
        and "mode='stage2'" or "mode='two_stage'".
        If you want to recalculate everything, use "use_precomputed_stats=False"
        and "mode='two_stage'". Default is True
    save_computed_stats : bool, optional
        Whether to save computed stats to Experiment instance. Default is True
    force_update : bool, optional
        Whether to force saved statistics data update in case the collision between
        actual data hashes and saved stats data hashes is found (for example, if
        neuronal or behavior data has been changed externally). Default is False
    topk1 : int, optional
        True MI for stage 1 should be among topk1 MI shuffles. Default is 1
    topk2 : int, optional
        True MI for stage 2 should be among topk2 MI shuffles. Default is 5
    multicomp_correction : str or None, optional
        Type of multiple comparison correction. Supported types are None (no correction),
        "bonferroni", "holm", and "fdr_bh" (Benjamini-Hochberg FDR). Default is 'holm'
    pval_thr : float, optional
        P-value threshold. If multicomp_correction=None, this is a p-value for a single
        pair. Otherwise it is a FWER significance level. Default is 0.01
    find_optimal_delays : bool, optional
        Allows slight shifting (not more than +- shift_window) of time series,
        selects a shift with the highest MI as default. Default is True
    skip_delays : list, optional
        List of features for which delays are not applied (set to 0).
        Only features that exist in feat_bunch will be processed (entries are
        translated to their positions in the processed feature list; an empty list
        is forwarded as None). Has no effect if find_optimal_delays = False.
        Default is []
    shift_window : int, optional
        Window for optimal shift search (seconds). Optimal shift (in frames) will lie
        in the range -shift_window*fps <= opt_shift <= shift_window*fps.
        Has no effect if find_optimal_delays = False. Default is 2
    verbose : bool, optional
        Whether to print progress messages. Default is True
    enable_parallelization : bool, optional
        Whether to enable parallel processing. Default is True
    n_jobs : int, optional
        Number of parallel jobs. -1 means use all processors. Default is -1
    seed : int, optional
        Random seed for reproducibility. Default is 42
    with_disentanglement : bool, optional
        If True, performs a full INTENSE pipeline with mixed selectivity analysis:
        1. Computes behavioral feature-feature significance
        2. Computes neuron-feature significance
        3. Disentangles mixed selectivities using behavioral correlations.
        Default is False
    feat_feat_pval_thr : float, optional
        P-value threshold for feature-feature significance testing during
        disentanglement. Separate from cell-feat ``pval_thr`` because the number of
        feature pairs (~100-200) is much smaller than neuron-feature pairs (thousands),
        so a stricter threshold is unnecessary. Only used when
        ``with_disentanglement=True``. Default is 0.01
    multifeature_map : dict or None, optional
        Mapping from multifeature tuples to aggregated names for disentanglement.
        If None, uses DEFAULT_MULTIFEATURE_MAP from disentanglement module.
        Only used when with_disentanglement=True. Default is None
    duplicate_behavior : str, optional
        How to handle duplicate TimeSeries in neuron or feature bunches.

        - 'ignore': Process duplicates normally (default)
        - 'raise': Raise an error if duplicates are found
        - 'warn': Print a warning but continue processing.

        Default is 'ignore'
    engine : {'auto', 'fft', 'loop'}, optional
        Computation engine for MI shuffles:

        - 'auto': Use FFT when applicable (univariate continuous GCMI with nsh >= 50)
        - 'fft': Force FFT (raises error if not applicable)
        - 'loop': Force per-shift loop (original behavior)

        FFT provides ~100x speedup for Stage 2. Default is 'auto'
    store_random_shifts : bool, optional
        Whether to store the random shift indices used during shuffle computation.
        When False (default), random_shifts1 and random_shifts2 arrays are not stored,
        saving significant memory (~400MB for typical datasets with N=500, M=20).
        Set to True if you need the shift indices for debugging or reproducibility
        analysis. Default is False
    profile : bool, optional
        Whether to collect internal timing information. When True, info['timings']
        will contain execution times (in seconds) for:

        - 'stage1_delay_optimization': delay optimization (if find_optimal_delays=True)
        - 'stage1_pair_scanning': stage 1 pair scanning
        - 'stage2_pair_scanning': stage 2 pair scanning (if applicable)
        - 'fft_type_counts': Dictionary of FFT type usage counts
        - 'disentanglement': disentanglement analysis (if with_disentanglement=True)
        - 'disentanglement_feat_feat', 'disentanglement_analysis': the two
          sub-steps of disentanglement (if with_disentanglement=True)
        - 'fft_type_counts_disentanglement': FFT type counts reported by the
          feature-feature step, when it provides them
        - 'total': sum of all timing sections

        Default is False
    pre_filter_func : callable or None, optional
        Population-level filter function (or composed filter) to run BEFORE
        disentanglement parallel processing. Only used when with_disentanglement=True.
        The filter mutates neuron selectivities and pre-computes pair decisions.
        Signature::

            def pre_filter_func(
                neuron_selectivities,    # dict: {neuron_id: [feat1, feat2, ...]} - MUTATE
                pair_decisions,          # dict: {neuron_id: {(f1, f2): 0/0.5/1}} - MUTATE
                renames,                 # dict: {neuron_id: {new_name: (old1, old2)}} - MUTATE
                cell_feat_stats,         # Pre-computed MI values (READ ONLY)
                feat_feat_significance,  # Binary matrix (READ ONLY)
                feat_names,              # List of feature names (READ ONLY)
                **kwargs,                # User-provided extra arguments from filter_kwargs
            ): ...

        Default: None (no filtering).
    post_filter_func : callable or None, optional
        Population-level filter function to run AFTER disentanglement parallel
        processing. Can modify pair results (e.g., tie-breaking).
        Only used when with_disentanglement=True.
        Signature::

            def post_filter_func(
                per_neuron_disent,  # dict: {nid: {'pairs': {...}, ...}} - MUTATE
                cell_feat_stats,    # Pre-computed MI values (READ ONLY)
                feat_names,         # List of feature names (READ ONLY)
                **kwargs,           # User-provided extra arguments
            ): ...

        Default: None (no post-filtering).
    filter_kwargs : dict or None, optional
        Dictionary of keyword arguments to pass to pre_filter_func and
        post_filter_func. Can include pre-extracted data like calcium_data,
        feature_data, thresholds, etc. Only used when with_disentanglement=True.
        Default: None.
    remove_anti_selective : bool, default=True
        If True, every pair whose computed ``signal_ratio`` is not None and is
        ``<= 1.0`` has its ``'stage2'`` significance flag reset to False (only
        pairs that were flagged are counted as removed). When
        ``save_computed_stats=True`` the modified significance is also written
        back to the experiment. ``signal_ratio`` is computed with
        ``calc_signal_ratio`` for features exposing a two-valued ``int_data`` and
        for ``TimeSeries`` whose ``type_info.subtype == 'linear'`` (binarised at
        the median); it is None for all other features.
    use_circular_2d : bool, default=True
        If True, automatically substitute circular features with their `_2d`
        counterparts (cos, sin representation) for MI computation. This improves
        MI estimation accuracy for circular variables like head direction.
        Requires that `create_circular_2d=True` was used during experiment loading.

    Returns
    -------
    stats: dict of dict of dicts
        Outer dict: cells, inner dict: dynamic features, last dict: stats.
        Can be easily converted to pandas DataFrame by pd.DataFrame(stats)
    significance: dict of dict of bools
        Significance results for each neuron-feature pair
    info: dict
        Additional information from compute_me_stats
    intense_res: IntenseResults
        Complete results object
    disentanglement_results : dict or None
        Always returned as the fifth element; it is None when
        with_disentanglement=False. Otherwise contains:

        - 'feat_feat_similarity': Feature-feature similarity matrix
        - 'feat_feat_significance': Feature-feature significance matrix
        - 'disent_matrix': Disentanglement results matrix
        - 'count_matrix': Count matrix from disentanglement
        - 'per_neuron_disent': Per-neuron detailed results dict mapping neuron_id
          to 'pairs', 'renames', and 'final_sels' sub-dicts.
        - 'feature_names': List of feature names
        - 'summary': Summary statistics from disentanglement

    Raises
    ------
    ValueError
        If data_type is not 'calcium' or 'spikes'
        If features are not found in experiment
        If a feature ID is neither a str nor a tuple
        If mode is not 'stage1', 'stage2' or 'two_stage'
        If use_precomputed_stats=False and a feature has no entry in the
        experiment's data hashes

    Notes
    -----
    - shift_window is converted from seconds to frames using exp.fps
    - Updates exp.optimal_nf_delays as a side effect
    - Relative MI values are computed using appropriate neural data entropy
    - Signals and features without a name get a temporary one for the duration of
      the call; the ``name`` attribute is reset to its saved value (None if it
      was missing) in a ``finally`` clause, even if the computation fails.

    Examples
    --------
    >>> from driada.experiment.synthetic import generate_synthetic_exp
    >>> import numpy as np
    >>>
    >>> # Create small test experiment
    >>> exp = generate_synthetic_exp(n_dfeats=2, n_cfeats=1, nneurons=3,
    ...                              duration=60, fps=10, seed=42, verbose=False)
    >>>
    >>> # Basic neuron-feature analysis (stage1 for speed)
    >>> stats, sig, info, res, _ = compute_cell_feat_significance(
    ...     exp,
    ...     cell_bunch=[0, 1],
    ...     feat_bunch=['d_feat_0'],
    ...     mode='stage1',
    ...     n_shuffles_stage1=10,
    ...     verbose=False
    ... )  # doctest: +ELLIPSIS
    ...
    >>> len(stats)  # Number of neurons analyzed
    2
    >>> 'd_feat_0' in stats[0]  # Feature present in results
    True
    >>>
    >>> # With disentanglement analysis
    >>> result = compute_cell_feat_significance(
    ...     exp,
    ...     cell_bunch=[0, 1],
    ...     mode='stage1',
    ...     n_shuffles_stage1=10,
    ...     with_disentanglement=True,
    ...     verbose=False
    ... )  # doctest: +ELLIPSIS
    ...
    >>> len(result)  # Returns 5 values with disentanglement
    5
    >>> stats, sig, info, res, disent = result
    >>> 'disent_matrix' in disent
    True"""
    exp.check_ds(ds)

    # Normalise the user-supplied bunches into explicit ID lists.
    cell_ids = exp._process_cbunch(cell_bunch)
    feat_ids = exp._process_fbunch(feat_bunch, allow_multifeatures=True, mode=data_type)

    # Substitute circular features with _2d counterparts for better MI estimation
    if use_circular_2d:
        feat_ids, _ = substitute_circular_with_2d(feat_ids, exp, verbose=verbose)

    cells = [exp.neurons[cell_id] for cell_id in cell_ids]

    # Pick the neural trace matching the requested data type.
    if data_type == "calcium":
        signals = [cell.ca for cell in cells]
    elif data_type == "spikes":
        signals = [cell.sp for cell in cells]
    else:
        raise ValueError('"data_type" can be either "calcium" or "spikes"')

    # min_shifts = [int(cell.get_t_off() * MIN_CA_SHIFT) for cell in cells]
    # Resolve every feature ID to a time-series object: strings map directly to
    # exp.dynamic_features, tuples are bundled into a MultiTimeSeries.
    feats = []
    for feat_id in feat_ids:
        if isinstance(feat_id, str):
            if feat_id not in exp.dynamic_features:
                raise ValueError(
                    f"Feature '{feat_id}' not found in experiment. Available features: {list(exp.dynamic_features.keys())}"
                )
            ts = exp.dynamic_features[feat_id]
            feats.append(ts)
        elif isinstance(feat_id, tuple):
            for f in feat_id:
                if f not in exp.dynamic_features:
                    raise ValueError(
                        f"Feature '{f}' not found in experiment. Available features: {list(exp.dynamic_features.keys())}"
                    )
            parts = [exp.dynamic_features[f] for f in feat_id]
            # Create MultiTimeSeries with name from tuple
            mts_name = "_".join(str(f) for f in feat_id)
            mts = MultiTimeSeries(parts, name=mts_name)
            feats.append(mts)
        else:
            raise ValueError("Unknown feature id type")

    # n = number of neurons, f = number of features; t (frame count) is unpacked
    # but not referenced again in this function.
    n, t, f = len(cells), exp.n_frames, len(feats)

    # Masks start as all ones, i.e. "compute everything".
    precomputed_mask_stage1 = np.ones((n, f))
    precomputed_mask_stage2 = np.ones((n, f))

    if not exp.selectivity_tables_initialized:
        exp._set_selectivity_tables(data_type, cbunch=cell_ids, fbunch=feat_ids)

    if use_precomputed_stats:
        if verbose:
            print("Retrieving saved stats data...")
        # 0 in mask values means precomputed results are found, calculation will be skipped.
        # 1 in mask values means precomputed results are not found or incomplete, calculation will proceed.
        for i, cell_id in enumerate(cell_ids):
            for j, feat_id in enumerate(feat_ids):
                try:
                    pair_stats = exp.get_neuron_feature_pair_stats(cell_id, feat_id, mode=data_type)
                except (ValueError, KeyError):
                    if isinstance(feat_id, str):
                        raise ValueError(
                            f"Unknown single feature in feat_bunch: {feat_id}. Check initial data"
                        )
                    else:
                        # A multifeature seen for the first time: register it in the
                        # experiment and start from the default (empty) stats.
                        exp._add_multifeature_to_data_hashes(feat_id, mode=data_type)
                        exp._add_multifeature_to_stats(feat_id, mode=data_type)
                        pair_stats = DEFAULT_STATS.copy()

                current_data_hash = exp._data_hashes[data_type][feat_id][cell_id]

                if stats_not_empty(pair_stats, current_data_hash, stage=1):
                    precomputed_mask_stage1[i, j] = 0
                if stats_not_empty(pair_stats, current_data_hash, stage=2):
                    precomputed_mask_stage2[i, j] = 0

    # combined mask: 1 marks pairs whose stats are written back to the experiment
    # below. For stage2 / two_stage both stages must already be stored for a pair
    # to be skipped; for stage1 only the stage-1 result is required.
    combined_precomputed_mask = np.ones((n, f))
    if mode in ["stage2", "two_stage"]:
        combined_precomputed_mask[
            np.where((precomputed_mask_stage1 == 0) & (precomputed_mask_stage2 == 0))
        ] = 0
    elif mode == "stage1":
        combined_precomputed_mask[np.where(precomputed_mask_stage1 == 0)] = 0
    else:
        raise ValueError("Wrong mode!")

    # Temporary naming: Save original names and assign temporary names if missing
    # This ensures all TimeSeries have names for FFT cache keys
    original_signal_names = [sig.name if hasattr(sig, 'name') else None for sig in signals]
    original_feat_names = [feat.name if hasattr(feat, 'name') else None for feat in feats]

    # Assign temporary names to signals if missing
    for i, sig in enumerate(signals):
        if not hasattr(sig, 'name') or sig.name is None or sig.name == '':
            sig.name = f"_neuron_{cell_ids[i]}"

    # Assign temporary names to feats if missing
    for j, feat in enumerate(feats):
        if not hasattr(feat, 'name') or feat.name is None or feat.name == '':
            if isinstance(feat_ids[j], tuple):
                feat.name = f"_mts_{'_'.join(str(f) for f in feat_ids[j])}"
            else:
                feat.name = f"_feat_{feat_ids[j]}"

    # Everything from here on runs under try/finally so the temporary names are
    # rolled back on every exit path (normal return or exception).
    try:
        computed_stats, computed_significance, info = compute_me_stats(
            signals,
            feats,
            mode=mode,
            names1=cell_ids,
            names2=feat_ids,
            metric=metric,
            mi_estimator=mi_estimator,
            mi_estimator_kwargs=mi_estimator_kwargs,
            precomputed_mask_stage1=precomputed_mask_stage1,
            precomputed_mask_stage2=precomputed_mask_stage2,
            n_shuffles_stage1=n_shuffles_stage1,
            n_shuffles_stage2=n_shuffles_stage2,
            metric_distr_type=metric_distr_type,
            noise_ampl=noise_ampl,
            ds=ds,
            topk1=topk1,
            topk2=topk2,
            multicomp_correction=multicomp_correction,
            pval_thr=pval_thr,
            find_optimal_delays=find_optimal_delays,
            # Feature names are translated to column indices; names not present in
            # feat_ids are ignored and an empty list is forwarded as None.
            skip_delays=[feat_ids.index(f) for f in skip_delays if f in feat_ids] if skip_delays else None,
            # seconds -> frames
            shift_window=int(shift_window * exp.fps),
            verbose=verbose,
            enable_parallelization=enable_parallelization,
            n_jobs=n_jobs,
            seed=seed,
            duplicate_behavior=duplicate_behavior,
            engine=engine,
            store_random_shifts=store_random_shifts,
            profile=profile,
        )

        exp.optimal_nf_delays = info["optimal_delays"]
        # add hash data and update Experiment saved statistics and significance if needed
        for i, cell_id in enumerate(cell_ids):
            for j, feat_id in enumerate(feat_ids):
                # Check for non-existing feature if use_precomputed_stats==False
                if not use_precomputed_stats:
                    if feat_id not in exp._data_hashes[data_type]:
                        raise ValueError(
                            f"Feature '{feat_id}' not found in data hashes. This may indicate the feature was not properly initialized."
                        )

                computed_stats[cell_id][feat_id]["data_hash"] = exp._data_hashes[data_type][feat_id][
                    cell_id
                ]

                # Relative MI: the raw value divided by the feature entropy and by
                # the entropy of the neural trace, respectively.
                me_val = computed_stats[cell_id][feat_id].get("me")
                if me_val is not None and metric == "mi":
                    feat_entropy = exp.get_feature_entropy(feat_id, ds=ds)
                    # Get entropy from appropriate data type
                    if data_type == "calcium":
                        neural_entropy = exp.neurons[int(cell_id)].ca.get_entropy(ds=ds)
                    elif data_type == "spikes":
                        neural_entropy = exp.neurons[int(cell_id)].sp.get_entropy(ds=ds)

                    computed_stats[cell_id][feat_id]["rel_me_beh"] = me_val / feat_entropy
                    computed_stats[cell_id][feat_id]["rel_me_ca"] = me_val / neural_entropy

                # Compute signal_ratio for binary discrete features
                # NOTE(review): placed at loop level (independent of the MI branch
                # above); nesting was recovered from a flattened listing - confirm.
                feat_ts = feats[j]
                if (hasattr(feat_ts, 'int_data') and feat_ts.int_data is not None
                        and len(np.unique(feat_ts.int_data)) == 2):
                    opt_delay = int(info["optimal_delays"][i, j])
                    # "ca_data" is the selected neural trace (calcium or spikes).
                    ca_data = signals[i].data
                    # Circularly shift the feature by the optimal delay, then trim
                    # it to the length of the neural trace.
                    feat_binary = np.roll(feat_ts.int_data, opt_delay)[:len(ca_data)]
                    # Normalize binary feature to 0/1
                    # (np.unique returns sorted values, so the larger one maps to 1)
                    unique_vals = np.unique(feat_binary)
                    feat_01 = (feat_binary == unique_vals[1]).astype(float)
                    computed_stats[cell_id][feat_id]["signal_ratio"] = calc_signal_ratio(
                        feat_01, ca_data
                    )
                elif (isinstance(feat_ts, TimeSeries)
                        and getattr(feat_ts, 'type_info', None) is not None
                        and feat_ts.type_info.subtype == 'linear'):
                    # Linear continuous feature: binarise at the median of the
                    # delay-shifted data (strictly above median -> 1).
                    opt_delay = int(info["optimal_delays"][i, j])
                    ca_data = signals[i].data
                    feat_data = np.roll(feat_ts.data, opt_delay)[:len(ca_data)]
                    median_val = np.median(feat_data)
                    feat_binary = (feat_data > median_val).astype(float)
                    computed_stats[cell_id][feat_id]["signal_ratio"] = calc_signal_ratio(
                        feat_binary, ca_data
                    )
                else:
                    computed_stats[cell_id][feat_id]["signal_ratio"] = None

                if save_computed_stats:
                    stage2_only = True if mode == "stage2" else False
                    # Only pairs that were not fully precomputed are written back.
                    # NOTE(review): the significance write-back is assumed to share
                    # this mask guard with the stats write-back - confirm.
                    if combined_precomputed_mask[i, j]:
                        exp.update_neuron_feature_pair_stats(
                            computed_stats[cell_id][feat_id],
                            cell_id,
                            feat_id,
                            mode=data_type,
                            force_update=force_update,
                            stage2_only=stage2_only,
                        )

                        sig = computed_significance[cell_id][feat_id]
                        exp.update_neuron_feature_pair_significance(
                            sig, cell_id, feat_id, mode=data_type
                        )

        # Remove anti-selective pairs from significance (binary + linear continuous)
        if remove_anti_selective:
            n_removed = 0
            for i, cell_id in enumerate(cell_ids):
                for j, feat_id in enumerate(feat_ids):
                    sr = computed_stats[cell_id][feat_id].get("signal_ratio")
                    # A ratio <= 1.0 is treated as anti-selective here; the exact
                    # definition lives in calc_signal_ratio (not shown in this file).
                    if sr is not None and sr <= 1.0:
                        if computed_significance[cell_id][feat_id].get("stage2", False):
                            computed_significance[cell_id][feat_id]["stage2"] = False
                            n_removed += 1
                            if save_computed_stats:
                                exp.update_neuron_feature_pair_significance(
                                    computed_significance[cell_id][feat_id],
                                    cell_id,
                                    feat_id,
                                    mode=data_type,
                                )
            if verbose and n_removed > 0:
                print(f" Anti-selective pairs removed: {n_removed}")

        # save all results to a single object
        intense_params = {
            "neurons": {i: cell_ids[i] for i in range(len(cell_ids))},
            "feat_bunch": {i: feat_ids[i] for i in range(len(feat_ids))},
            "data_type": data_type,
            "mode": mode,
            "metric": metric,
            "n_shuffles_stage1": n_shuffles_stage1,
            "n_shuffles_stage2": n_shuffles_stage2,
            "metric_distr_type": metric_distr_type,
            "noise_ampl": noise_ampl,
            "ds": ds,
            "topk1": topk1,
            "topk2": topk2,
            "multicomp_correction": multicomp_correction,
            "pval_thr": pval_thr,
            "find_optimal_delays": find_optimal_delays,
            "shift_window": shift_window,
        }

        intense_res = IntenseResults()
        intense_res.update("stats", computed_stats)
        intense_res.update("significance", computed_significance)
        intense_res.update("info", info)
        intense_res.update("intense_params", intense_params)

        # Perform disentanglement analysis if requested
        if with_disentanglement:
            if verbose:
                print("\nPerforming mixed selectivity disentanglement analysis...")

            if profile:
                disentangle_start = time.perf_counter()

            # Step 1: Compute feature-feature significance
            # Capture similarity_matrix (MI values) for disentanglement optimization
            if verbose:
                print(" Step 1: Computing feature-feature significance...")

            # The third return value (p-value matrix) is not needed here.
            feat_feat_similarity, feat_feat_significance, _, feat_names, disent_info = compute_feat_feat_significance(
                exp,
                feat_bunch=feat_ids,
                metric=metric,
                mode=mode,
                n_shuffles_stage1=n_shuffles_stage1,
                n_shuffles_stage2=n_shuffles_stage2,
                metric_distr_type=metric_distr_type,
                noise_ampl=noise_ampl,
                ds=ds,
                topk1=topk1,
                topk2=topk2,
                multicomp_correction=multicomp_correction,
                pval_thr=feat_feat_pval_thr,
                verbose=verbose,
                enable_parallelization=enable_parallelization,
                n_jobs=n_jobs,
                seed=seed,
                engine=engine,
                profile=profile,
            )

            if profile:
                info['timings']['disentanglement_feat_feat'] = time.perf_counter() - disentangle_start
                disent_analysis_start = time.perf_counter()

            # Step 2: Use default multifeature map if not provided
            if multifeature_map is None:
                multifeature_map = DEFAULT_MULTIFEATURE_MAP

            # Step 3: Run disentanglement analysis
            # Pass pre-computed MI values to avoid redundant computation:
            # - cell_feat_stats: MI(neuron, feature) from INTENSE analysis
            # - feat_feat_similarity: MI(feature1, feature2) from feat-feat analysis
            if verbose:
                print(" Step 2: Computing neuron-pair disentanglement...")

            disent_results = disentangle_all_selectivities(
                exp,
                feat_names,
                ds=ds,
                multifeature_map=multifeature_map,
                feat_feat_significance=feat_feat_significance,
                cell_bunch=cell_ids,
                cell_feat_stats=computed_stats,
                feat_feat_similarity=feat_feat_similarity,
                n_jobs=n_jobs,
                pre_filter_func=pre_filter_func,
                post_filter_func=post_filter_func,
                filter_kwargs=filter_kwargs,
            )
            disent_matrix = disent_results['disent_matrix']
            count_matrix = disent_results['count_matrix']
            per_neuron_disent = disent_results['per_neuron_disent']

            if profile:
                info['timings']['disentanglement_analysis'] = time.perf_counter() - disent_analysis_start

            # Step 4: Get summary statistics
            from .disentanglement import get_disentanglement_summary

            summary = get_disentanglement_summary(
                disent_matrix,
                count_matrix,
                feat_names,
                feat_feat_significance,
                per_neuron_disent=per_neuron_disent,
            )

            # Package disentanglement results
            disentanglement_results = {
                "feat_feat_similarity": feat_feat_similarity,
                "feat_feat_significance": feat_feat_significance,
                "disent_matrix": disent_matrix,
                "count_matrix": count_matrix,
                "per_neuron_disent": per_neuron_disent,
                "feature_names": feat_names,
                "summary": summary,
            }

            # Add to IntenseResults
            intense_res.update("disentanglement", disentanglement_results)

            if verbose:
                print("\nDisentanglement analysis complete!")
                if summary.get("overall_stats"):
                    print(
                        f"Total mixed selectivity pairs analyzed: {summary['overall_stats']['total_neuron_pairs']}"
                    )
                    if "redundancy_rate" in summary["overall_stats"]:
                        print(f"Redundancy rate: {summary['overall_stats']['redundancy_rate']:.1f}%")
                    if "independence_rate" in summary["overall_stats"]:
                        print(
                            f"Independence rate: {summary['overall_stats']['independence_rate']:.1f}%"
                        )
                    if "true_mixed_selectivity_rate" in summary["overall_stats"]:
                        print(
                            f"True mixed selectivity rate: {summary['overall_stats']['true_mixed_selectivity_rate']:.1f}%"
                        )
                else:
                    print("No mixed selectivity pairs found in the selected neurons.")

            if profile:
                # Fold the whole disentanglement wall time into the running total.
                info['timings']['disentanglement'] = time.perf_counter() - disentangle_start
                info['timings']['total'] += info['timings']['disentanglement']
                # Aggregate disentanglement FFT type counts
                if 'timings' in disent_info and 'fft_type_counts' in disent_info['timings']:
                    disent_counts = disent_info['timings']['fft_type_counts']
                    info['timings']['fft_type_counts_disentanglement'] = disent_counts

            # Return with disentanglement results
            return (
                computed_stats,
                computed_significance,
                info,
                intense_res,
                disentanglement_results,
            )

        # No disentanglement requested: keep the 5-tuple shape, last slot is None.
        return computed_stats, computed_significance, info, intense_res, None

    finally:
        # Restore original names to leave objects unchanged
        for i, sig in enumerate(signals):
            sig.name = original_signal_names[i]
        for j, feat in enumerate(feats):
            feat.name = original_feat_names[j]
[docs] def compute_feat_feat_significance( exp, feat_bunch="all", metric="mi", mi_estimator="gcmi", mi_estimator_kwargs=None, mode="two_stage", n_shuffles_stage1=100, n_shuffles_stage2=1000, metric_distr_type=DEFAULT_METRIC_DISTR_TYPE, noise_ampl=1e-3, ds=1, topk1=1, topk2=5, multicomp_correction="holm", pval_thr=0.01, verbose=True, enable_parallelization=True, n_jobs=-1, seed=42, duplicate_behavior="ignore", engine="auto", profile=False, # FUTURE: Add save_computed_stats=True, use_precomputed_stats=True parameters # to enable caching of feat-feat results in experiment object similar to cell-feat ) -> tuple: """ Compute pairwise significance between all behavioral features. This function calculates pairwise similarity (e.g., mutual information) between all behavioral features using the two-stage INTENSE approach. The diagonal elements are set to zero as self-similarity is prevented by the check_for_coincidence mechanism in get_mi. Parameters ---------- exp : Experiment Experiment object containing behavioral data. feat_bunch : str, list or None Feature names to analyze. Default: 'all' (all features including multifeatures). Can be a list of specific feature names. metric : str, optional Similarity metric to use. Default: 'mi' (mutual information). mi_estimator : str, optional Mutual information estimator to use when metric='mi'. Default: 'gcmi'. Options: 'gcmi' or 'ksg' mi_estimator_kwargs : dict, optional Additional keyword arguments passed to the MI estimator function. mode : str, optional Computation mode: 'two_stage', 'stage1', or 'stage2'. Default: 'two_stage'. n_shuffles_stage1 : int, optional Number of shuffles for stage 1. Default: 100. n_shuffles_stage2 : int, optional Number of shuffles for stage 2. Default: 1000. metric_distr_type : str, optional Distribution type for metric null distribution ('gamma_zi', 'gamma', etc.). Default: 'gamma_zi'. noise_ampl : float, optional Small noise amplitude for numerical stability. Default: 1e-3. 
ds : int, optional Downsampling factor. Default: 1. topk1 : int, optional Top-k criterion for stage 1. Default: 1. topk2 : int, optional Top-k criterion for stage 2. Default: 5. multicomp_correction : str or None, optional Multiple comparison correction method. Default: 'holm'. pval_thr : float, optional P-value threshold for significance. Default: 0.01. verbose : bool, optional Whether to print progress information. Default: True. enable_parallelization : bool, optional Whether to use parallel processing. Default: True. n_jobs : int, optional Number of parallel jobs. -1 means use all processors. Default: -1. seed : int, optional Random seed for reproducibility. Default: 42. duplicate_behavior : str, optional How to handle duplicate TimeSeries in ts_bunch1 or ts_bunch2. - 'ignore': Process duplicates normally (default) - 'raise': Raise an error if duplicates are found - 'warn': Print a warning but continue processing Default: 'ignore'. engine : str, optional Computation engine for MI calculation: - 'auto': Automatically select FFT when beneficial (default) - 'fft': Force FFT-based computation - 'loop': Force loop-based computation (useful for comparison/debugging) Default: 'auto'. profile : bool, optional If True, collect timing and FFT type information. Default: False. When enabled, info['timings'] will contain: - 'stage1_pair_scanning': Stage 1 scanning time (seconds) - 'stage2_pair_scanning': Stage 2 scanning time if applicable (seconds) - 'fft_type_counts': Dictionary of FFT type usage counts - 'matrix_construction': Symmetric matrix construction time (seconds) - 'total': Total execution time (seconds) Returns ------- similarity_matrix : ndarray Matrix of similarity values between features. Element [i,j] contains the similarity between feature i and feature j. Diagonal is zero. significance_matrix : ndarray Matrix of binary significance values. 1 indicates significant similarity. p_value_matrix : ndarray Matrix of p-values for each comparison. 
feature_names : list List of feature names corresponding to matrix indices. May include tuples for multifeatures (e.g., ('x', 'y')). info : dict Dictionary containing additional information from compute_me_stats. Notes ----- - Uses the two-stage INTENSE approach for efficient significance testing - Diagonal elements are zero (self-similarity check prevents computation) - The function handles both discrete and continuous variables - Supports MultiTimeSeries (e.g., place fields from x,y coordinates) - For mutual information, values are in bits - No optimal delay search is performed (delays are set to 0) Examples -------- >>> from driada.experiment.synthetic import generate_synthetic_exp >>> >>> # Create test experiment >>> exp = generate_synthetic_exp(n_dfeats=2, n_cfeats=2, nneurons=3, ... duration=60, fps=10, seed=42, verbose=False) >>> >>> # Compute feature-feature correlations >>> sim_mat, sig_mat, pval_mat, features, info = compute_feat_feat_significance( ... exp, ... mode='stage1', ... n_shuffles_stage1=10, ... verbose=False ... ) >>> sim_mat.shape == (4, 4) # 2 discrete + 2 continuous features True >>> np.allclose(np.diag(sim_mat), 0) # Diagonal is zero True >>> >>> # Analyze specific features only >>> sim_mat2, sig_mat2, pval_mat2, features2, info2 = compute_feat_feat_significance( ... exp, ... feat_bunch=['d_feat_0', 'd_feat_1'], ... mode='stage1', ... n_shuffles_stage1=10, ... verbose=False ... 
) >>> sim_mat2.shape == (2, 2) True Raises ------ ValueError If features are not found in experiment Notes ----- - Only upper triangle is computed for efficiency (matrix is symmetric) - Diagonal elements are always zero (self-similarity prevented) - No delay optimization is performed between features - Supports both discrete and continuous features - Multifeatures are created using aggregate_multiple_ts - When called with circular features, ``feat_bunch`` should contain ``_2d``-substituted names (e.g., ``headdirection_2d``) to match the experiment's dynamic features after circular substitution.""" import numpy as np # Process feature bunch - default is all features if feat_bunch == "all": feat_bunch = None # None means all features in _process_fbunch feat_ids = exp._process_fbunch(feat_bunch, allow_multifeatures=True, mode="calcium") n_features = len(feat_ids) # Handle empty feature list case if n_features == 0: if verbose: print("No features to analyze - returning empty results") return ( np.array([]).reshape(0, 0), # similarity_matrix np.array([]).reshape(0, 0), # significance_matrix np.array([]).reshape(0, 0), # p_value_matrix [], # feature_names {}, # info ) if verbose: print(f"Computing behavioral similarity matrix for {n_features} features...") print(f"Features: {feat_ids}") # Note: computing only unique pairs (upper triangle), not all n² n_unique_pairs = n_features * (n_features - 1) // 2 print(f"Unique pairs to compute: {n_unique_pairs} (avoiding redundancy)") # Get TimeSeries/MultiTimeSeries objects for all features from ..information.info_base import aggregate_multiple_ts feature_ts = [] for feat_id in feat_ids: if isinstance(feat_id, tuple): # Create MultiTimeSeries for tuples using aggregate_multiple_ts ts_list = [exp.dynamic_features[f] for f in feat_id] ts = aggregate_multiple_ts(*ts_list) else: ts = exp.dynamic_features[feat_id] feature_ts.append(ts) # Create masks that exclude diagonal (self-comparisons) AND lower triangle # This ensures we only 
compute the upper triangle for symmetric results precomputed_mask_stage1 = np.triu(np.ones((n_features, n_features)), k=1) precomputed_mask_stage2 = np.triu(np.ones((n_features, n_features)), k=1) # Temporary naming: Save original names and assign temporary names if missing # This ensures all TimeSeries have names for FFT cache keys original_feat_names = [feat.name if hasattr(feat, 'name') else None for feat in feature_ts] # Assign temporary names to features if missing for j, feat in enumerate(feature_ts): if not hasattr(feat, 'name') or feat.name is None or feat.name == '': if isinstance(feat_ids[j], tuple): feat.name = f"_feat_{'_'.join(str(f) for f in feat_ids[j])}" else: feat.name = f"_feat_{feat_ids[j]}" try: # Call compute_me_stats with features against themselves # Note: optimal delays are disabled (set to False) stats, significance, info = compute_me_stats( feature_ts, feature_ts, names1=feat_ids, names2=feat_ids, metric=metric, mi_estimator=mi_estimator, mi_estimator_kwargs=mi_estimator_kwargs, mode=mode, precomputed_mask_stage1=precomputed_mask_stage1, precomputed_mask_stage2=precomputed_mask_stage2, n_shuffles_stage1=n_shuffles_stage1, n_shuffles_stage2=n_shuffles_stage2, metric_distr_type=metric_distr_type, noise_ampl=noise_ampl, ds=ds, topk1=topk1, topk2=topk2, multicomp_correction=multicomp_correction, pval_thr=pval_thr, find_optimal_delays=False, # No delay optimization shift_window=0, # No shift window needed verbose=verbose, enable_parallelization=enable_parallelization, n_jobs=n_jobs, seed=seed, duplicate_behavior="ignore", # Default behavior for feature-feature comparison engine=engine, profile=profile, ) # Extract matrices from results if profile: matrix_start = time.perf_counter() similarity_matrix = np.zeros((n_features, n_features)) significance_matrix = np.zeros((n_features, n_features)) p_value_matrix = np.ones((n_features, n_features)) # Fill matrices from stats and significance dictionaries # Since we only computed upper triangle, we 
need to fill both upper and lower for i, feat1 in enumerate(feat_ids): for j, feat2 in enumerate(feat_ids): if i == j: # Diagonal is already 0 continue # Convert tuples to strings for dictionary keys if needed key1 = str(feat1) if isinstance(feat1, tuple) else feat1 key2 = str(feat2) if isinstance(feat2, tuple) else feat2 # We computed only upper triangle, so check if this pair was computed if i < j: # Upper triangle - get from stats if key1 in stats and key2 in stats[key1]: stats_dict = stats[key1][key2] if stats_dict: # Check if dict is not empty similarity_matrix[i, j] = stats_dict.get("me", 0) p_value_matrix[i, j] = stats_dict.get("pval", 1) sig_dict = significance.get(key1, {}).get(key2, {}) if sig_dict.get("stage2") is not None: significance_matrix[i, j] = float(sig_dict["stage2"]) elif sig_dict.get("stage1") is not None: significance_matrix[i, j] = float(sig_dict["stage1"]) else: # Lower triangle - copy from upper triangle for symmetry similarity_matrix[i, j] = similarity_matrix[j, i] p_value_matrix[i, j] = p_value_matrix[j, i] significance_matrix[i, j] = significance_matrix[j, i] # Ensure diagonal is zero (should already be due to coincidence check) np.fill_diagonal(similarity_matrix, 0) np.fill_diagonal(significance_matrix, 0) np.fill_diagonal(p_value_matrix, 1) if profile: info['timings']['matrix_construction'] = time.perf_counter() - matrix_start info['timings']['total'] += info['timings']['matrix_construction'] if verbose: print("\nBehavioral similarity matrix computation complete!") print(f"Feature pairs analyzed: {n_features * n_features}") print(f"Significant pairs (stage 1): {info.get('n_significant_stage1', 0)}") print(f"Significant pairs (final): {np.sum(significance_matrix)}") # Count unique significant pairs (upper triangle only) unique_sig = np.sum(np.triu(significance_matrix, k=1)) print(f"Unique significant pairs: {unique_sig}") return similarity_matrix, significance_matrix, p_value_matrix, feat_ids, info finally: # Restore original names to 
leave objects unchanged for j, feat in enumerate(feature_ts): feat.name = original_feat_names[j]
def compute_cell_cell_significance(
    exp,
    cell_bunch=None,
    data_type="calcium",
    metric="mi",
    mi_estimator="gcmi",
    mi_estimator_kwargs=None,
    mode="two_stage",
    n_shuffles_stage1=100,
    n_shuffles_stage2=1000,
    metric_distr_type=DEFAULT_METRIC_DISTR_TYPE,
    noise_ampl=1e-3,
    ds=1,
    topk1=1,
    topk2=5,
    multicomp_correction="holm",
    pval_thr=0.01,
    verbose=True,
    enable_parallelization=True,
    n_jobs=-1,
    seed=42,
    duplicate_behavior="ignore",
    profile=False,
    engine="auto",
    # FUTURE: Add save_computed_stats=True, use_precomputed_stats=True parameters
    # to enable caching of cell-cell results in experiment object similar to cell-feat
) -> tuple:
    """
    Compute pairwise functional correlations between neurons using INTENSE.

    The selected neural signals are passed to ``compute_me_stats`` as both
    bunches, with masks restricting the computation to the strict upper
    triangle. The per-pair results are then expanded into symmetric
    ``(n_cells, n_cells)`` matrices.

    Parameters
    ----------
    exp : Experiment
        Experiment object containing neural data.
    cell_bunch : int, list or None, optional
        Neuron indices to analyze. Default: None (all neurons).
    data_type : str, optional
        Type of neural data: 'calcium' (uses ``neuron.ca``) or 'spikes'
        (uses ``neuron.sp``). Default: 'calcium'.
    metric : str, optional
        Similarity metric to use. Default: 'mi' (mutual information).
    mi_estimator : str, optional
        Mutual information estimator to use when metric='mi'.
        Options: 'gcmi' or 'ksg'. Default: 'gcmi'.
    mi_estimator_kwargs : dict, optional
        Additional keyword arguments passed to the MI estimator function.
    mode : str, optional
        Computation mode: 'two_stage', 'stage1', or 'stage2'.
        Default: 'two_stage'.
    n_shuffles_stage1 : int, optional
        Number of shuffles for stage 1. Default: 100.
    n_shuffles_stage2 : int, optional
        Number of shuffles for stage 2. Default: 1000.
    metric_distr_type : str, optional
        Distribution type for the metric null distribution
        ('gamma_zi', 'gamma', etc.). Default: ``DEFAULT_METRIC_DISTR_TYPE``.
    noise_ampl : float, optional
        Small noise amplitude for numerical stability. Default: 1e-3.
    ds : int, optional
        Downsampling factor; validated with ``exp.check_ds``. Default: 1.
    topk1 : int, optional
        Top-k criterion for stage 1. Default: 1.
    topk2 : int, optional
        Top-k criterion for stage 2. Default: 5.
    multicomp_correction : str or None, optional
        Multiple comparison correction method. Default: 'holm'.
    pval_thr : float, optional
        P-value threshold for significance. Default: 0.01.
    verbose : bool, optional
        Whether to print progress information and basic network statistics.
        Default: True.
    enable_parallelization : bool, optional
        Whether to use parallel processing. Default: True.
    n_jobs : int, optional
        Number of parallel jobs. -1 means use all processors. Default: -1.
    seed : int, optional
        Random seed for reproducibility. Default: 42.
    duplicate_behavior : str, optional
        How to handle duplicate TimeSeries in ts_bunch1 or ts_bunch2.
        Forwarded to ``compute_me_stats``.
        - 'ignore': Process duplicates normally (default)
        - 'raise': Raise an error if duplicates are found
        - 'warn': Print a warning but continue processing
    profile : bool, optional
        If True, collect timing and FFT type information. Default: False.
        When enabled, info['timings'] will contain:
        - 'stage1_pair_scanning': Stage 1 scanning time (seconds)
        - 'stage2_pair_scanning': Stage 2 scanning time if applicable (seconds)
        - 'fft_type_counts': Dictionary of FFT type usage counts
        - 'matrix_construction': Symmetric matrix construction time (seconds)
        - 'total': Total execution time (seconds)
    engine : str, optional
        Computation engine forwarded to ``compute_me_stats``
        ('auto', 'fft' or 'loop'). Default: 'auto'.

    Returns
    -------
    similarity_matrix : ndarray
        Matrix of similarity values between neurons. Element [i,j] contains
        the similarity between neuron i and neuron j. Diagonal is zero.
    significance_matrix : ndarray
        Matrix of binary significance values. 1 indicates significant
        similarity. The stage 2 decision is used when present, otherwise
        the stage 1 decision.
    p_value_matrix : ndarray
        Matrix of p-values for each comparison (1 where nothing was computed).
    cell_ids : list
        List of cell IDs corresponding to matrix indices.
    info : dict
        Dictionary containing additional information from compute_me_stats.
        Empty when there are no neurons to analyze.

    Raises
    ------
    ValueError
        If data_type is not 'calcium' or 'spikes'
        If spike data is missing for requested neurons

    Notes
    -----
    - Uses the two-stage INTENSE approach for efficient significance testing
    - Only the upper triangle is computed; the lower triangle is mirrored
    - Diagonal elements are zero (self-comparisons are masked out)
    - Warns if all neurons have identical spike data
    - Prints network statistics when verbose=True
    - No optimal delay search is performed (synchronous activity assumed)
    - Signals without a name receive a temporary ``_neuron_<id>`` name for the
      duration of the call; names are restored afterwards
    - Can identify functional assemblies through graph analysis of
      significant pairs

    Examples
    --------
    >>> from driada.experiment.synthetic import generate_synthetic_exp
    >>> from driada.information.info_base import TimeSeries
    >>> import numpy as np
    >>>
    >>> # Create experiment with correlated neurons
    >>> exp = generate_synthetic_exp(n_dfeats=1, n_cfeats=1, nneurons=3,
    ...                              duration=60, fps=10, seed=42, verbose=False)
    >>>
    >>> # Make neurons 0 and 1 correlated
    >>> noise = np.random.RandomState(42).randn(len(exp.neurons[0].ca.data)) * 0.1
    >>> exp.neurons[1].ca = TimeSeries(
    ...     exp.neurons[0].ca.data + noise, discrete=False
    ... )
    >>>
    >>> # Compute neuron-neuron correlations
    >>> sim_mat, sig_mat, pval_mat, cells, info = compute_cell_cell_significance(
    ...     exp,
    ...     cell_bunch=[0, 1, 2],
    ...     mode='stage1',
    ...     n_shuffles_stage1=10,
    ...     verbose=False
    ... )
    >>> sim_mat.shape == (3, 3)
    True
    >>> np.allclose(np.diag(sim_mat), 0)  # Self-correlation is zero
    True
    >>> sim_mat[0, 1] > sim_mat[0, 2]  # Neurons 0,1 more correlated than 0,2
    True
    """
    # Check downsampling
    exp.check_ds(ds)

    if data_type not in ("calcium", "spikes"):
        raise ValueError('"data_type" can be either "calcium" or "spikes"')

    # Process cell bunch
    cell_ids = exp._process_cbunch(cell_bunch)
    n_cells = len(cell_ids)

    # Nothing to compare: return empty results, mirroring the feature-feature pipeline
    if n_cells == 0:
        if verbose:
            print("No neurons to analyze - returning empty results")
        return (
            np.zeros((0, 0)),  # similarity_matrix
            np.zeros((0, 0)),  # significance_matrix
            np.ones((0, 0)),  # p_value_matrix
            cell_ids,
            {},  # info
        )

    cells = [exp.neurons[cell_id] for cell_id in cell_ids]

    if verbose:
        print(f"Computing neuronal similarity matrix for {n_cells} neurons...")
        print(f"Data type: {data_type}")
        # Note: computing only unique pairs (upper triangle), not all n²
        n_unique_pairs = n_cells * (n_cells - 1) // 2
        print(f"Unique pairs to compute: {n_unique_pairs} (avoiding redundancy)")

    # Get neural signals based on data type
    if data_type == "calcium":
        signals = [cell.ca for cell in cells]
    else:  # "spikes" (validated above)
        signals = [cell.sp for cell in cells]
        # Check if spike data exists and is non-degenerate
        if any(sig is None for sig in signals):
            raise ValueError(
                "Some neurons have no spike data. Use reconstruct_spikes or provide spike data."
            )
        # Check if all spike data is identical (e.g., all zeros)
        if len(signals) > 1:
            first_data = signals[0].data
            if all(np.array_equal(sig.data, first_data) for sig in signals[1:]):
                import warnings

                warnings.warn(
                    "All neurons have identical spike data. This may lead to degenerate results."
                )

    # Masks exclude the diagonal (self-comparisons) AND the lower triangle,
    # so only the strict upper triangle is actually computed.
    precomputed_mask_stage1 = np.triu(np.ones((n_cells, n_cells)), k=1)
    precomputed_mask_stage2 = np.triu(np.ones((n_cells, n_cells)), k=1)

    # Temporary naming: remember the current names and fill in missing ones,
    # so every TimeSeries has a name usable as an FFT cache key.
    original_signal_names = [getattr(sig, "name", None) for sig in signals]
    for cell_id, sig, current_name in zip(cell_ids, signals, original_signal_names):
        if current_name is None or current_name == "":
            sig.name = f"_neuron_{cell_id}"

    try:
        # Neurons are compared against themselves; delay optimization is
        # disabled because synchronous activity is assumed.
        stats, significance, info = compute_me_stats(
            signals,
            signals,
            names1=cell_ids,
            names2=cell_ids,
            metric=metric,
            mi_estimator=mi_estimator,
            mi_estimator_kwargs=mi_estimator_kwargs,
            mode=mode,
            precomputed_mask_stage1=precomputed_mask_stage1,
            precomputed_mask_stage2=precomputed_mask_stage2,
            n_shuffles_stage1=n_shuffles_stage1,
            n_shuffles_stage2=n_shuffles_stage2,
            metric_distr_type=metric_distr_type,
            noise_ampl=noise_ampl,
            ds=ds,
            topk1=topk1,
            topk2=topk2,
            multicomp_correction=multicomp_correction,
            pval_thr=pval_thr,
            find_optimal_delays=False,  # Assume synchronous activity
            shift_window=0,  # No shift window needed
            verbose=verbose,
            enable_parallelization=enable_parallelization,
            n_jobs=n_jobs,
            seed=seed,
            duplicate_behavior=duplicate_behavior,  # Honor the caller's choice
            engine=engine,
            profile=profile,
        )

        # Extract matrices from results
        if profile:
            matrix_start = time.perf_counter()

        # Defaults: zero similarity, not significant, p-value of 1.
        # The diagonal is never written, so it keeps these defaults.
        similarity_matrix = np.zeros((n_cells, n_cells))
        significance_matrix = np.zeros((n_cells, n_cells))
        p_value_matrix = np.ones((n_cells, n_cells))

        # Walk the computed upper triangle only and mirror each entry.
        for i in range(n_cells):
            cell1 = cell_ids[i]
            for j in range(i + 1, n_cells):
                cell2 = cell_ids[j]

                if cell1 in stats and cell2 in stats[cell1]:
                    stats_dict = stats[cell1][cell2]
                    if stats_dict:  # Skip pairs with no recorded statistics
                        similarity_matrix[i, j] = stats_dict.get("me", 0)
                        p_value_matrix[i, j] = stats_dict.get("pval", 1)

                # Prefer the stage 2 decision; fall back to stage 1
                sig_dict = significance.get(cell1, {}).get(cell2, {})
                if sig_dict.get("stage2") is not None:
                    significance_matrix[i, j] = float(sig_dict["stage2"])
                elif sig_dict.get("stage1") is not None:
                    significance_matrix[i, j] = float(sig_dict["stage1"])

                similarity_matrix[j, i] = similarity_matrix[i, j]
                p_value_matrix[j, i] = p_value_matrix[i, j]
                significance_matrix[j, i] = significance_matrix[i, j]

        if profile:
            info['timings']['matrix_construction'] = time.perf_counter() - matrix_start
            info['timings']['total'] += info['timings']['matrix_construction']

        if verbose:
            print("\nNeuronal similarity matrix computation complete!")
            print(f"Neuron pairs analyzed: {n_cells * n_cells}")
            print(f"Significant pairs (stage 1): {info.get('n_significant_stage1', 0)}")
            print(f"Significant pairs (final): {np.sum(significance_matrix)}")
            # Count unique significant pairs (upper triangle only)
            unique_sig = np.sum(np.triu(significance_matrix, k=1))
            print(f"Unique significant pairs: {unique_sig}")

            # Basic network statistics
            if unique_sig > 0:
                avg_connections = np.sum(significance_matrix) / n_cells
                print(f"Average connections per neuron: {avg_connections:.2f}")
                max_connections = np.max(np.sum(significance_matrix, axis=1))
                print(f"Maximum connections for a single neuron: {int(max_connections)}")

        return similarity_matrix, significance_matrix, p_value_matrix, cell_ids, info

    finally:
        # Restore original names to leave objects unchanged
        for sig, original_name in zip(signals, original_signal_names):
            sig.name = original_name
def compute_embedding_selectivity(
    exp,
    embedding_methods=None,
    cell_bunch=None,
    data_type="calcium",
    metric="mi",
    mi_estimator="gcmi",
    mi_estimator_kwargs=None,
    mode="two_stage",
    n_shuffles_stage1=100,
    n_shuffles_stage2=10000,
    metric_distr_type=DEFAULT_METRIC_DISTR_TYPE,
    noise_ampl=1e-3,
    ds=1,
    use_precomputed_stats=True,
    save_computed_stats=True,
    force_update=False,
    topk1=1,
    topk2=5,
    multicomp_correction="holm",
    pval_thr=0.01,
    find_optimal_delays=True,
    shift_window=2,
    remove_anti_selective=True,
    verbose=True,
    enable_parallelization=True,
    n_jobs=-1,
    seed=42,
) -> dict:
    """
    Compute INTENSE selectivity between neurons and dimensionality reduction embeddings.

    Each embedding component is wrapped in a continuous ``TimeSeries`` and
    temporarily registered as a dynamic feature of the experiment, after which
    ``compute_cell_feat_significance`` is run against those components. This
    reveals how individual neurons contribute to the population-level
    manifold structure.

    Parameters
    ----------
    exp : Experiment
        Experiment object with stored embeddings
    embedding_methods : str, list or None
        Names of embedding methods to analyze. If None, analyzes all
        embeddings stored for ``data_type``.
    cell_bunch : int, iterable or None
        Neuron indices. By default (None), all neurons will be taken
    data_type : str
        Data type used for embeddings and INTENSE ('calcium' or 'spikes')
    metric : str
        Similarity metric between TimeSeries (default: 'mi')
    mi_estimator : str
        Mutual information estimator to use when metric='mi'.
        Options: 'gcmi' or 'ksg' (default: 'gcmi')
    mi_estimator_kwargs : dict, optional
        Additional keyword arguments passed to the MI estimator function.
    mode : str
        Computation mode: 'stage1', 'stage2', or 'two_stage' (default)
    n_shuffles_stage1 : int
        Number of shuffles for first stage (default: 100)
    n_shuffles_stage2 : int
        Number of shuffles for second stage (default: 10000)
    metric_distr_type : str
        Distribution type for shuffled metric distribution fit
        (default: ``DEFAULT_METRIC_DISTR_TYPE``)
    noise_ampl : float
        Small noise amplitude added to improve numerical fit (default: 1e-3)
    ds : int
        Downsampling constant (default: 1)
    use_precomputed_stats : bool
        Accepted for signature compatibility; the underlying analysis always
        runs with ``use_precomputed_stats=False`` because the embedding
        features are temporary (default: True)
    save_computed_stats : bool
        Only controls whether missing stats tables are initialized for
        ``data_type``; statistics of the temporary embedding features are
        never saved (default: True)
    force_update : bool
        Force update saved statistics if data hash collision found (default: False)
    topk1 : int
        True MI for stage 1 should be among topk1 MI shuffles (default: 1)
    topk2 : int
        True MI for stage 2 should be among topk2 MI shuffles (default: 5)
    multicomp_correction : str or None
        Multiple comparison correction type: None, 'bonferroni', or 'holm' (default)
    pval_thr : float
        P-value threshold (default: 0.01)
    find_optimal_delays : bool
        Find optimal temporal delays between neural activity and embeddings
        (default: True)
    shift_window : int
        Window for optimal shift search in seconds (default: 2)
    remove_anti_selective : bool
        Forwarded unchanged to ``compute_cell_feat_significance``
        (default: True)
    verbose : bool
        Print progress information (default: True)
    enable_parallelization : bool
        Enable parallel computation (default: True)
    n_jobs : int
        Number of parallel jobs, -1 for all cores (default: -1)
    seed : int
        Random seed (default: 42)

    Returns
    -------
    results : dict
        Dictionary with keys as embedding method names, each containing:
        - 'stats': Statistics for each neuron-component pair
        - 'significance': Significance results
        - 'info': Additional information from compute_me_stats
        - 'intense_results': Full IntenseResults object from INTENSE computation
        - 'significant_neurons': Dict mapping neuron id to the list of
          embedding component names it is significantly selective to
        - 'n_components': Number of embedding components
        - 'component_selectivity': For each component index, list of
          selective neurons
        - 'embedding_metadata': The ``"metadata"`` entry of the stored
          embedding (empty dict if absent)

    Raises
    ------
    ValueError
        If no embeddings found for specified data_type
        If embedding method not found

    Notes
    -----
    - Temporarily adds embedding components as dynamic features
    - Forces use_precomputed_stats=False for temporary features
    - Component names follow pattern "{method}_comp{index}"
    - Embeddings whose length differs from ``exp.n_frames`` are linearly
      interpolated to ``exp.n_frames`` samples
    - Cleanup in finally block restores ``exp.dynamic_features`` and the
      temporary attributes, even when the setup or the analysis fails
    - Only stage2 significance is considered for results

    Examples
    --------
    >>> from driada.experiment.synthetic import generate_synthetic_exp
    >>> from sklearn.decomposition import PCA
    >>> import numpy as np
    >>>
    >>> # Create experiment
    >>> exp = generate_synthetic_exp(n_dfeats=1, n_cfeats=1, nneurons=5,
    ...                              duration=60, fps=10, seed=42, verbose=False)
    >>>
    >>> # Create and store PCA embedding
    >>> neural_data = np.array([exp.neurons[i].ca.data for i in range(5)]).T
    >>> pca = PCA(n_components=2, random_state=42)
    >>> embedding = pca.fit_transform(neural_data)
    >>> exp.store_embedding(embedding, method_name='pca', data_type='calcium')
    >>>
    >>> # Compute embedding selectivity
    >>> results = compute_embedding_selectivity(
    ...     exp,
    ...     embedding_methods=['pca'],
    ...     cell_bunch=[0, 1, 2],
    ...     mode='stage1',
    ...     n_shuffles_stage1=10,
    ...     verbose=False
    ... )  # doctest: +ELLIPSIS
    ...
    >>>
    >>> 'pca' in results
    True
    >>> results['pca']['n_components']
    2
    >>> 'component_selectivity' in results['pca']
    True

    See Also
    --------
    compute_cell_feat_significance : Compute selectivity for behavioral features
    ~driada.integration.manifold_analysis.get_functional_organization : Analyze organization in embeddings
    ~driada.integration.manifold_analysis.compare_embeddings : Compare multiple embedding methods
    """
    # Get list of embedding methods to analyze
    if embedding_methods is None:
        embedding_methods = list(exp.embeddings[data_type].keys())
    elif isinstance(embedding_methods, str):
        embedding_methods = [embedding_methods]

    if not embedding_methods:
        raise ValueError(
            f"No embeddings found for data_type '{data_type}'. "
            "Use exp.store_embedding() to add embeddings first."
        )

    results = {}
    _absent = object()  # Marks attributes that did not exist before the call

    # Process each embedding method
    for method_name in embedding_methods:
        if verbose:
            print(f"\n{'='*60}")
            print(f"Computing selectivity for embedding: {method_name}")
            print(f"{'='*60}")

        # Get embedding data
        embedding_dict = exp.get_embedding(method_name, data_type)
        embedding_data = embedding_dict["data"]
        n_components = embedding_data.shape[1]

        # Handle downsampled embeddings: interpolate to match calcium length
        n_frames = exp.n_frames
        if embedding_data.shape[0] != n_frames:
            from scipy.interpolate import interp1d

            x_ds = np.linspace(0, 1, embedding_data.shape[0])
            x_full = np.linspace(0, 1, n_frames)
            embedding_data = interp1d(x_ds, embedding_data, axis=0, kind="linear")(x_full)

        # Create TimeSeries for each embedding component and remember which
        # component index each feature name stands for.
        embedding_features = {}
        component_index = {}
        for comp_idx in range(n_components):
            feat_name = f"{method_name}_comp{comp_idx}"
            embedding_features[feat_name] = TimeSeries(
                embedding_data[:, comp_idx], discrete=False, name=feat_name
            )
            component_index[feat_name] = comp_idx

        # Snapshot the state that is about to be modified so it can be restored
        original_features = exp.dynamic_features.copy()
        previous_attrs = {
            feat_name: getattr(exp, feat_name, _absent) for feat_name in embedding_features
        }

        try:
            # Temporarily add embedding components to dynamic features.
            # This happens inside the try so a failure during setup is rolled back too.
            exp.dynamic_features.update(embedding_features)

            # Also update internal experiment attributes for the new features
            for feat_name, feat_ts in embedding_features.items():
                setattr(exp, feat_name, feat_ts)

            # Rebuild data hashes to include new features
            exp._build_data_hashes(mode=data_type)

            # Initialize stats tables if not already done
            if save_computed_stats and data_type not in exp.stats_tables:
                exp._set_selectivity_tables(data_type)

            # Run INTENSE analysis
            stats, significance, info, intense_res, _ = compute_cell_feat_significance(
                exp,
                cell_bunch=cell_bunch,
                feat_bunch=list(embedding_features.keys()),
                data_type=data_type,
                metric=metric,
                mi_estimator=mi_estimator,
                mi_estimator_kwargs=mi_estimator_kwargs,
                mode=mode,
                n_shuffles_stage1=n_shuffles_stage1,
                n_shuffles_stage2=n_shuffles_stage2,
                metric_distr_type=metric_distr_type,
                noise_ampl=noise_ampl,
                ds=ds,
                use_precomputed_stats=False,  # Must be False for new dynamic features
                save_computed_stats=False,  # Don't save stats for temporary embedding features
                force_update=force_update,
                topk1=topk1,
                topk2=topk2,
                multicomp_correction=multicomp_correction,
                pval_thr=pval_thr,
                find_optimal_delays=find_optimal_delays,
                shift_window=shift_window,
                remove_anti_selective=remove_anti_selective,
                verbose=verbose,
                enable_parallelization=enable_parallelization,
                n_jobs=n_jobs,
                seed=seed,
            )

            # Extract significant neurons from the significance results
            # Note: significance structure is significance[neuron_id][feat_name]
            significant_neurons = {}
            for neuron_id, neuron_significance in significance.items():
                for feat_name in embedding_features:
                    if feat_name not in neuron_significance:
                        continue
                    # Only the stage 2 decision counts as significant here
                    if neuron_significance[feat_name].get("stage2", False):
                        significant_neurons.setdefault(neuron_id, []).append(feat_name)

            # Organize component selectivity
            component_selectivity = {comp_idx: [] for comp_idx in range(n_components)}
            for neuron_id, features in significant_neurons.items():
                for feat in features:
                    component_selectivity[component_index[feat]].append(neuron_id)

            # Store results
            results[method_name] = {
                "stats": stats,
                "significance": significance,
                "info": info,
                "intense_results": intense_res,  # Include the full IntenseResults object
                "significant_neurons": significant_neurons,
                "n_components": n_components,
                "component_selectivity": component_selectivity,
                "embedding_metadata": embedding_dict.get("metadata", {}),
            }

            if verbose:
                n_sig_neurons = len(significant_neurons)
                n_total_neurons = len(exp._process_cbunch(cell_bunch))
                # Avoid dividing by zero when no neurons were selected
                sig_percent = 100 * n_sig_neurons / n_total_neurons if n_total_neurons else 0.0
                print(f"\nResults for {method_name}:")
                print(f"  Embedding dimensions: {n_components}")
                print(
                    f"  Significant neurons: {n_sig_neurons}/{n_total_neurons} ({sig_percent:.1f}%)"
                )

                # Component-wise summary
                for comp_idx in range(n_components):
                    n_selective = len(component_selectivity[comp_idx])
                    if n_selective > 0:
                        print(f"    Component {comp_idx}: {n_selective} selective neurons")

        finally:
            # Restore original features
            exp.dynamic_features = original_features

            # Remove temporary attributes, putting back any value that
            # existed under the same name before this call.
            for feat_name, previous in previous_attrs.items():
                if previous is not _absent:
                    setattr(exp, feat_name, previous)
                elif hasattr(exp, feat_name):
                    delattr(exp, feat_name)
            # NOTE(review): data hashes were rebuilt with the temporary features
            # and are not rebuilt here - confirm whether that matters for callers.

    return results