"""
Mixed selectivity disentanglement analysis for INTENSE.
This module provides functions to analyze and disentangle mixed selectivity
in neural responses when neurons respond to multiple, potentially correlated
behavioral variables.
"""
import warnings
import numpy as np
from itertools import combinations
from joblib import Parallel, delayed
from ..information.info_base import get_mi, conditional_mi, MultiTimeSeries
from ..information.gcmi import cmi_ggg
import driada # For PARALLEL_BACKEND
from .intense_base import _parallel_executor
# Keys are tuples of component feature names; values are the aggregated
# feature name that stands for the joint (multi-dimensional) variable.
DEFAULT_MULTIFEATURE_MAP = {
    ("x", "y"): "place", # 2D spatial location multifeature
    ("x", "y", "z"): "3d-place", # 3D spatial location multifeature
}
"""Default multifeature mapping for common behavioral variable combinations.
Maps component tuples to their semantic names:
- ``("x", "y")``: mapped to ``"place"`` (2D spatial location)
- ``("x", "y", "z")``: mapped to ``"3d-place"`` (3D spatial location)
"""
# Epsilon tolerance for floating-point comparisons.
# MI values are computed via numerical estimation and rarely equal exactly 0.0,
# so "negligible MI" is tested as ``mi < MI_EPSILON`` rather than ``mi == 0``.
MI_EPSILON = 1e-6
# Ratio threshold for "dominant" feature detection in synergy cases.
# 2.0 means one feature has >2x the MI of the other, indicating strong dominance.
DOMINANCE_RATIO_THRESHOLD = 2.0
# Threshold for CMI-ratio redundancy detection.
# If conditioning on the other feature reduces MI below this fraction of the
# original, the feature is considered redundant (its information is mostly shared).
REDUNDANCY_CMI_RATIO = 0.1
# Valid disentanglement result values:
# 0 = first feature is primary, 1 = second feature is primary,
# 0.5 = both contribute (undistinguishable).
VALID_DISRES_VALUES = (0, 0.5, 1)
def _flip_decision(decision):
    """Mirror a disentanglement decision for a pair given in reversed order.

    A decision of 0 ("first feature is primary") becomes 1 and vice versa;
    the undistinguishable value 0.5 maps to itself.

    Parameters
    ----------
    decision : float
        Original decision (0, 0.5, or 1).

    Returns
    -------
    float
        Flipped decision.

    Raises
    ------
    KeyError
        If ``decision`` is not one of 0, 0.5 or 1.
    """
    mirrored = dict(zip((0, 0.5, 1), (1, 0.5, 0)))
    return mirrored[decision]
def _downsample_copnorm(data, ds):
    """Keep every ``ds``-th sample of copula-normalized data.

    Parameters
    ----------
    data : ndarray
        Copula-normalized data (1D for TimeSeries, 2D for MultiTimeSeries).
    ds : int
        Downsampling factor.

    Returns
    -------
    ndarray
        Downsampled data. 1D input is strided directly; any other input is
        strided along its second axis, leaving the first axis untouched.
    """
    every_ds = slice(None, None, ds)
    if data.ndim != 1:
        # Multi-dimensional case: first axis is kept whole, second is thinned.
        return data[:, every_ds]
    return data[every_ds]
def _lookup_cell_feat_mi(cell_feat_stats, neuron_id, feat_name):
    """Fetch a pre-computed MI(neuron, feature) value, if one exists.

    Parameters
    ----------
    cell_feat_stats : dict or None
        Nested dictionary: stats[cell_id][feat_name]["me"] = MI value.
    neuron_id : any
        Neuron identifier (cell ID).
    feat_name : str or tuple
        Feature name/identifier.

    Returns
    -------
    float or None
        The stored MI value, or None when no stats were supplied, the
        neuron/feature entry is missing, or the entry has no ``"me"`` key.
    """
    if cell_feat_stats is not None:
        try:
            pair_stats = cell_feat_stats[neuron_id][feat_name]
            return pair_stats.get("me")
        except (KeyError, TypeError):
            # Missing neuron/feature, or a non-subscriptable level: treat as
            # "not available" so the caller recomputes the value.
            pass
    return None
def _lookup_cell_feat_delay(cell_feat_stats, neuron_id, feat_name):
    """Look up optimal delay for a (neuron, feature) pair.

    Parameters
    ----------
    cell_feat_stats : dict or None
        Nested dictionary: stats[cell_id][feat_name]["opt_delay"] = delay.
    neuron_id : any
        Neuron identifier (cell ID).
    feat_name : str or tuple
        Feature name/identifier.

    Returns
    -------
    int
        The optimal delay in original frame units, or 0 if not found.
        A stored delay of ``None`` is treated as "not found" and also
        yields 0, so callers can always do integer arithmetic on the result.
    """
    if cell_feat_stats is None:
        return 0
    try:
        delay = cell_feat_stats[neuron_id][feat_name].get("opt_delay", 0)
    except (KeyError, TypeError):
        return 0
    # An explicit None entry would otherwise leak out and break the caller's
    # ``delay // ds`` conversion to downsampled frames.
    return 0 if delay is None else delay
def _lookup_feat_feat_mi(feat_feat_similarity, feat_names, feat1_name, feat2_name):
    """Fetch a pre-computed MI(feature1, feature2) from the similarity matrix.

    Parameters
    ----------
    feat_feat_similarity : ndarray or None
        Symmetric matrix where [i, j] = MI(feat_i, feat_j).
    feat_names : list
        List of feature names corresponding to matrix indices.
    feat1_name : str or tuple
        First feature name.
    feat2_name : str or tuple
        Second feature name.

    Returns
    -------
    float or None
        The stored MI value, or None when the matrix or name list is absent,
        a name is not in ``feat_names``, or the indices fall outside the matrix.
    """
    have_inputs = feat_feat_similarity is not None and feat_names is not None
    if not have_inputs:
        return None
    try:
        row = feat_names.index(feat1_name)
        col = feat_names.index(feat2_name)
        value = feat_feat_similarity[row, col]
    except (ValueError, IndexError):
        # Unknown feature name or matrix smaller than the name list.
        return None
    return value
def _disentangle_pair_with_precomputed(
    ts1, ts2, ts3,
    mi12=None,
    mi13=None,
    mi23=None,
    ts1_copnorm=None,
    ts2_copnorm=None,
    ts3_copnorm=None,
    verbose=False,
    ds=1,
    delay_feat1=0,
    delay_feat2=0,
):
    """Disentangle with optional pre-computed MI values and copula data.

    Internal function that performs disentanglement analysis, optionally
    using pre-computed pairwise MI values and copula-normalized data to
    avoid redundant computation.

    Parameters
    ----------
    ts1 : TimeSeries
        Neural activity time series.
    ts2 : TimeSeries
        First behavioral variable.
    ts3 : TimeSeries
        Second behavioral variable.
    mi12 : float or None, optional
        Pre-computed MI(ts1, ts2). Computed if None.
    mi13 : float or None, optional
        Pre-computed MI(ts1, ts3). Computed if None.
    mi23 : float or None, optional
        Pre-computed MI(ts2, ts3). It is only reported in the verbose
        output and does not enter the decision, so when it is None it is
        computed only if ``verbose`` is True.
    ts1_copnorm : ndarray or None, optional
        Pre-computed copula-normalized data for ts1 (downsampled).
    ts2_copnorm : ndarray or None, optional
        Pre-computed copula-normalized data for ts2 (downsampled).
    ts3_copnorm : ndarray or None, optional
        Pre-computed copula-normalized data for ts3 (downsampled).
    verbose : bool, optional
        If True, print detailed analysis results. Default: False.
    ds : int, optional
        Downsampling factor. Default: 1.
    delay_feat1 : int, optional
        Optimal delay for feature 1, in downsampled frame units.
        Applied as ``np.roll`` shift to align with neural activity. Default: 0.
    delay_feat2 : int, optional
        Optimal delay for feature 2, in downsampled frame units.
        Applied as ``np.roll`` shift to align with neural activity. Default: 0.

    Returns
    -------
    float
        Disentanglement result (0, 0.5, or 1):
        0 means ts2 is primary, 1 means ts3 is primary, 0.5 means both
        contribute (undistinguishable).
    """
    # Compute only the missing pairwise MI values that feed the decision.
    if mi12 is None:
        mi12 = get_mi(ts1, ts2, ds=ds, shift=delay_feat1)
    if mi13 is None:
        mi13 = get_mi(ts1, ts3, ds=ds, shift=delay_feat2)

    # Conditional MI: use pre-computed copula data if all three are available.
    have_all_copnorm = (
        ts1_copnorm is not None
        and ts2_copnorm is not None
        and ts3_copnorm is not None
    )
    if have_all_copnorm:
        # Apply optimal delays to align features with neural activity.
        ts2_cp = np.roll(ts2_copnorm, delay_feat1, axis=-1) if delay_feat1 else ts2_copnorm
        ts3_cp = np.roll(ts3_copnorm, delay_feat2, axis=-1) if delay_feat2 else ts3_copnorm
        # Direct cmi_ggg call with pre-cached copula data (faster).
        cmi123 = cmi_ggg(ts1_copnorm, ts2_cp, ts3_cp, biascorrect=True, demeaned=True)
        cmi132 = cmi_ggg(ts1_copnorm, ts3_cp, ts2_cp, biascorrect=True, demeaned=True)
    else:
        # Fallback for mixed discrete/continuous cases.
        # NOTE(review): delay_feat1/delay_feat2 are not applied on this path.
        cmi123 = conditional_mi(ts1, ts2, ts3, ds=ds)  # MI(neuron, behavior1 | behavior2)
        cmi132 = conditional_mi(ts1, ts3, ts2, ds=ds)  # MI(neuron, behavior2 | behavior1)

    # Interaction information (average of two equivalent formulas).
    I_av = np.mean([cmi123 - mi12, cmi132 - mi13])

    if verbose:
        # MI(X, Y) is diagnostic only, so it is estimated lazily here instead
        # of on every (non-verbose) call.
        if mi23 is None:
            mi23 = get_mi(ts2, ts3, ds=ds)
        print()
        print("MI(A,X):", mi12)
        print("MI(A,Y):", mi13)
        print("MI(X,Y):", mi23)
        print()
        print("MI(A,X|Y):", cmi123)
        print("MI(A,Y|X):", cmi132)
        print()
        print("MI(A,X|Y) / MI(A,X):", np.round(cmi123 / mi12, 3) if mi12 > 0 else "N/A")
        print("MI(A,Y|X) / MI(A,Y):", np.round(cmi132 / mi13, 3) if mi13 > 0 else "N/A")
        print()
        print("I(A,X,Y) 1:", cmi123 - mi12)
        print("I(A,X,Y) 2:", cmi132 - mi13)
        print("I(A,X,Y) av:", I_av)
        print()
        print("Analysis (X=behavior1, Y=behavior2):")
        print(f" Redundancy detected: {I_av < 0}")
        print(f" MI(A,X) < |II|: {mi12 < np.abs(I_av)}")
        print(f" MI(A,Y) < |II|: {mi13 < np.abs(I_av)}")

    if I_av < 0:  # Negative interaction information (redundancy)
        abs_ii = np.abs(I_av)
        criterion1 = mi12 < abs_ii and not (cmi132 < abs_ii)
        criterion2 = mi13 < abs_ii and not (cmi123 < abs_ii)
        if criterion1 and not criterion2:
            return 1  # ts2 is redundant, ts3 is primary
        elif criterion2 and not criterion1:
            return 0  # ts3 is redundant, ts2 is primary

        # Fallback: use CMI ratios when the strict criteria can't pick a winner.
        # A feature is redundant if conditioning on the other removes most of
        # its MI (residual ratio below threshold).
        if mi12 > MI_EPSILON and mi13 > MI_EPSILON:
            ratio_ts2 = cmi123 / mi12  # ts2's residual fraction
            ratio_ts3 = cmi132 / mi13  # ts3's residual fraction
            if ratio_ts2 < REDUNDANCY_CMI_RATIO and ratio_ts3 >= REDUNDANCY_CMI_RATIO:
                return 1  # ts2 is redundant
            if ratio_ts3 < REDUNDANCY_CMI_RATIO and ratio_ts2 >= REDUNDANCY_CMI_RATIO:
                return 0  # ts3 is redundant
            # Both highly redundant - remove the one with lower MI.
            if ratio_ts2 < REDUNDANCY_CMI_RATIO and ratio_ts3 < REDUNDANCY_CMI_RATIO:
                if mi12 > mi13:
                    return 0  # ts2 has higher MI, keep it
                elif mi13 > mi12:
                    return 1  # ts3 has higher MI, keep it
        return 0.5  # Both contribute - undistinguishable
    else:  # Non-negative interaction information (synergy)
        # Use epsilon tolerance for near-zero MI comparisons.
        if mi13 < MI_EPSILON and cmi123 > cmi132:
            return 0  # ts2 is primary (ts3 has negligible MI)
        if mi12 < MI_EPSILON and cmi132 > cmi123:
            return 1  # ts3 is primary (ts2 has negligible MI)
        if mi13 >= MI_EPSILON and mi12 / mi13 > DOMINANCE_RATIO_THRESHOLD and cmi123 > cmi132:
            return 0  # ts2 is strongly dominant
        if mi12 >= MI_EPSILON and mi13 / mi12 > DOMINANCE_RATIO_THRESHOLD and cmi132 > cmi123:
            return 1  # ts3 is strongly dominant
        return 0.5  # Both contribute - undistinguishable
[docs]
def disentangle_pair(ts1, ts2, ts3, verbose=False, ds=1):
    """Decide which of two behavioral variables primarily drives a neuron.

    Given neural activity (ts1) and two correlated behavioral variables
    (ts2, ts3), uses interaction information and conditional mutual
    information to determine which variable carries the primary information
    about the neural activity.

    Parameters
    ----------
    ts1 : TimeSeries
        Neural activity time series (e.g., calcium signal or spike train).
    ts2 : TimeSeries
        First behavioral variable.
    ts3 : TimeSeries
        Second behavioral variable.
    verbose : bool, optional
        If True, print detailed analysis results. Default: False.
    ds : int, optional
        Downsampling factor. Default: 1.

    Returns
    -------
    float
        Disentanglement result:

        - 0: ts2 is the primary variable (ts3 is redundant)
        - 1: ts3 is the primary variable (ts2 is redundant)
        - 0.5: Both variables contribute - undistinguishable

    Notes
    -----
    The method uses interaction information to detect redundancy/synergy:

    - If II < 0 (redundancy), identifies the "weakest link" using criteria
      based on pairwise MI and conditional MI values
    - If II > 0 (synergy), uses different criteria for special cases

    See docs/intense_mathematical_framework.md for theoretical background.
    """
    # Nothing is pre-computed here: every MI value is estimated on demand and
    # the copula caches / delays keep their defaults.
    decision = _disentangle_pair_with_precomputed(
        ts1,
        ts2,
        ts3,
        verbose=verbose,
        ds=ds,
    )
    return decision
def _process_neuron_disentanglement(
    neuron_id,
    sels,
    neur_ts,
    feat_names,
    multifeature_map,
    multifeature_ts,
    feature_ts_dict,
    ds,
    feat_feat_significance,
    cell_feat_stats,
    feat_feat_similarity,
    pre_decisions=None,
    pre_renames=None,
    feat_copnorm_cache=None,
):
    """Process disentanglement for a single neuron.

    Worker function for parallel disentanglement processing. Analyzes all
    feature pairs for a single neuron and returns partial result matrices.

    Parameters
    ----------
    neuron_id : any
        Neuron identifier.
    sels : list
        List of feature selectivities for this neuron (already filtered by pre-filter).
        Entries may be plain feature names or multifeature tuples that are
        keys of ``multifeature_map``.
    neur_ts : TimeSeries
        Neural activity time series.
    feat_names : list
        List of all feature names.
    multifeature_map : dict
        Mapping from multifeature tuples to aggregated names.
    multifeature_ts : dict
        Pre-built MultiTimeSeries objects for multifeatures.
    feature_ts_dict : dict
        Pre-extracted feature TimeSeries, keyed by feature name.
    ds : int
        Downsampling factor.
    feat_feat_significance : ndarray or None
        Binary significance matrix for feature pairs.
    cell_feat_stats : dict or None
        Pre-computed neuron-feature MI values.
    feat_feat_similarity : ndarray or None
        Pre-computed feature-feature MI values.
    pre_decisions : dict or None, optional
        Pre-computed pair decisions from filter chain: {(feat_i, feat_j): 0/0.5/1}.
        0 = feat_i is primary, 1 = feat_j is primary, 0.5 = keep both.
        Default: None.
    pre_renames : dict or None, optional
        Pre-computed feature renames from filter chain: {new_name: (old1, old2)}.
        Default: None.
    feat_copnorm_cache : dict or None, optional
        Pre-computed downsampled copula-normalized data for features.
        Maps feature name to downsampled copula data. Pre-computed once
        and shared across all neurons to avoid redundant computation.
        Default: None.

    Returns
    -------
    neuron_id : any
        The neuron identifier (passed through for result aggregation).
    partial_disent : ndarray
        Partial disentanglement matrix for this neuron.
    partial_count : ndarray
        Partial count matrix for this neuron.
    neuron_info : dict
        Per-neuron details containing:

        - 'pairs': {(feat_i, feat_j): {'result': 0/0.5/1, 'source': str}},
          keyed by the names from ``feat_names`` (aggregated names for
          multifeatures)
        - 'renames': {new_name: (old1, old2)} from pre_renames
        - 'final_sels': list of final selectivities after filtering
        - 'errors': list of (neuron_id, sel_comb, error_msg) tuples for failed pairs
    """
    pre_decisions = pre_decisions or {}
    pre_renames = pre_renames or {}
    # Use pre-computed feature copula cache (passed from caller).
    feat_copnorm_cache = feat_copnorm_cache or {}

    n_features = len(feat_names)
    partial_disent = np.zeros((n_features, n_features))
    partial_count = np.zeros((n_features, n_features))

    # Track per-neuron pair results and errors.
    neuron_pairs_dict = {}
    errors = []

    # Renamed (combined) features are skipped in the pair loop.
    renamed_features = set(pre_renames.keys())

    # Pre-cache downsampled copula-normalized data for faster CMI computation.
    # Only for continuous time series (discrete use different code paths).
    neur_copnorm = (
        None
        if neur_ts.discrete
        else _downsample_copnorm(neur_ts.copula_normal_data, ds)
    )

    # Test all pairs of features this neuron responds to.
    for sel_comb in combinations(sels, 2):
        try:
            sel_comb = list(sel_comb)
            # Skip pairs involving renamed combined features.
            if sel_comb[0] in renamed_features or sel_comb[1] in renamed_features:
                continue

            # Resolve each selectivity to its time series and matrix index.
            feat_ts = []
            finds = []
            for fname in sel_comb:
                if isinstance(fname, tuple) and fname in multifeature_map:
                    # Multifeature tuple: addressed through its aggregated name.
                    agg_name = multifeature_map[fname]
                    if agg_name not in feat_names:
                        raise ValueError(f"Aggregated name '{agg_name}' not in feat_names")
                    feat_ts.append(multifeature_ts[agg_name])
                    finds.append(feat_names.index(agg_name))
                elif fname in feature_ts_dict:
                    # Regular single feature - use pre-extracted dict.
                    feat_ts.append(feature_ts_dict[fname])
                    finds.append(feat_names.index(fname))
                else:
                    raise ValueError(f"Feature '{fname}' not found in experiment")
            ind1, ind2 = finds

            # Decide the pair: filter-chain decision first, then significance
            # shortcut, then the full information-theoretic analysis.
            pair_key = (sel_comb[0], sel_comb[1])
            reverse_key = (sel_comb[1], sel_comb[0])
            if pair_key in pre_decisions:
                disres = pre_decisions[pair_key]
                source = 'pre_filter'
            elif reverse_key in pre_decisions:
                # Flip the result for reversed pair (0<->1, 0.5 unchanged).
                disres = _flip_decision(pre_decisions[reverse_key])
                source = 'pre_filter'
            elif (
                feat_feat_significance is not None
                and feat_feat_significance[ind1, ind2] == 0
            ):
                # Features are not significantly correlated: this is true
                # mixed selectivity, so no disentanglement is attempted.
                disres = 0.5
                source = 'not_significant'
            else:
                # Look up pre-computed MI values (if available).
                mi12 = _lookup_cell_feat_mi(cell_feat_stats, neuron_id, sel_comb[0])
                mi13 = _lookup_cell_feat_mi(cell_feat_stats, neuron_id, sel_comb[1])
                mi23 = _lookup_feat_feat_mi(
                    feat_feat_similarity, feat_names,
                    feat_names[ind1], feat_names[ind2]
                )
                # Look up optimal delays (original frame units -> downsampled).
                delay1 = _lookup_cell_feat_delay(cell_feat_stats, neuron_id, sel_comb[0])
                delay2 = _lookup_cell_feat_delay(cell_feat_stats, neuron_id, sel_comb[1])
                # Pass pre-computed copula data for faster CMI computation.
                disres = _disentangle_pair_with_precomputed(
                    neur_ts, feat_ts[0], feat_ts[1],
                    mi12=mi12, mi13=mi13, mi23=mi23,
                    ts1_copnorm=neur_copnorm,
                    ts2_copnorm=feat_copnorm_cache.get(sel_comb[0]),
                    ts3_copnorm=feat_copnorm_cache.get(sel_comb[1]),
                    ds=ds, verbose=False,
                    delay_feat1=delay1 // ds if ds > 1 else delay1,
                    delay_feat2=delay2 // ds if ds > 1 else delay2,
                )
                source = 'standard'

            # Validate disres value (also guards user-supplied pre_decisions).
            if disres not in VALID_DISRES_VALUES:
                raise ValueError(f"Unexpected disres value: {disres} (expected 0, 0.5, or 1)")

            # Update matrices.
            partial_count[ind1, ind2] += 1
            partial_count[ind2, ind1] += 1
            if disres == 0:
                partial_disent[ind1, ind2] += 1  # Feature 1 is primary
            elif disres == 1:
                partial_disent[ind2, ind1] += 1  # Feature 2 is primary
            else:  # disres == 0.5 (validated above)
                partial_disent[ind1, ind2] += 0.5  # Both contribute
                partial_disent[ind2, ind1] += 0.5

            # Record the pair result.
            neuron_pairs_dict[(feat_names[ind1], feat_names[ind2])] = {
                'result': disres,
                'source': source,
            }
        except (ValueError, AttributeError, KeyError) as e:
            # Accumulate errors for reporting (don't just print and lose them).
            errors.append((neuron_id, sel_comb, str(e)))
            continue

    # Compute final_sels by applying pair decisions:
    #   result=0: feat_i is primary (remove feat_j)
    #   result=1: feat_j is primary (remove feat_i)
    #   result=0.5: keep both
    features_to_remove = set()
    for (feat_i, feat_j), info in neuron_pairs_dict.items():
        result = info.get('result', 0.5)
        if result == 0:
            features_to_remove.add(feat_j)
        elif result == 1:
            features_to_remove.add(feat_i)

    def _is_removed(sel):
        # Pair results are keyed by feat_names entries, i.e. by the aggregated
        # name for multifeatures, while ``sels`` holds the multifeature tuple.
        # Resolve the tuple so a redundant multifeature is actually dropped.
        if sel in features_to_remove:
            return True
        return isinstance(sel, tuple) and multifeature_map.get(sel) in features_to_remove

    final_sels = [f for f in sels if not _is_removed(f)]

    neuron_info = {
        'pairs': neuron_pairs_dict,
        'renames': pre_renames,
        'final_sels': final_sels,
        'errors': errors,  # List of (neuron_id, sel_comb, error_msg) tuples
    }
    return neuron_id, partial_disent, partial_count, neuron_info
[docs]
def disentangle_all_selectivities(
    exp,
    feat_names,
    ds=1,
    multifeature_map=None,
    feat_feat_significance=None,
    cell_bunch=None,
    cell_feat_stats=None,
    feat_feat_similarity=None,
    n_jobs=-1,
    pre_filter_func=None,
    post_filter_func=None,
    filter_kwargs=None,
):
    """Analyze mixed selectivity across all significant neuron-feature pairs.

    For each neuron that responds to multiple features, determines which
    features provide primary vs redundant information using disentanglement
    analysis. Only analyzes feature pairs that show significant correlation
    in the behavioral data.

    Parameters
    ----------
    exp : Experiment
        Experiment object containing neural and behavioral data.
    feat_names : list of str
        List of feature names to analyze. Should match features in experiment
        and any aggregated names from multifeature_map.
    ds : int, optional
        Downsampling factor. Default: 1.
    multifeature_map : dict, optional
        Mapping from multifeature tuples to aggregated names and their
        corresponding MultiTimeSeries. If None, uses DEFAULT_MULTIFEATURE_MAP.
        Example: ``{('x', 'y'): 'place', ('speed', 'head_direction'): 'locomotion'}``.
    feat_feat_significance : ndarray, optional
        Binary significance matrix from compute_feat_feat_significance.
        If provided, only feature pairs marked as significant (value=1)
        will be analyzed for disentanglement. Non-significant pairs are
        assumed to represent true mixed selectivity.
    cell_bunch : list or None, optional
        List of cell IDs to analyze. If None, analyzes all cells.
        Default: None.
    cell_feat_stats : dict or None, optional
        Pre-computed neuron-feature statistics from INTENSE analysis.
        Structure: stats[cell_id][feat_name]["me"] = MI value.
        If provided, MI(neuron, feature) values will be looked up instead
        of recomputed, significantly speeding up disentanglement.
        Default: None.
    feat_feat_similarity : ndarray or None, optional
        Pre-computed feature-feature similarity matrix from
        compute_feat_feat_significance. Matrix where [i, j] = MI(feat_i, feat_j).
        If provided, MI(feature1, feature2) values will be looked up.
        Default: None.
    n_jobs : int, optional
        Number of parallel jobs for processing neurons. -1 means use all
        available processors. Default: -1.
    pre_filter_func : callable or None, optional
        Population-level filter function (or composed filter) to run BEFORE
        the parallel processing loop. The filter mutates neuron selectivities
        and pre-computes pair decisions for all neurons at once.
        Signature::

            def pre_filter_func(
                neuron_selectivities,    # dict: {neuron_id: [feat1, feat2, ...]} - MUTATE
                pair_decisions,          # dict: {neuron_id: {(f1, f2): 0/0.5/1}} - MUTATE
                renames,                 # dict: {neuron_id: {new_name: (old1, old2)}} - MUTATE
                cell_feat_stats,         # Pre-computed MI values (READ ONLY)
                feat_feat_significance,  # Binary matrix (READ ONLY)
                feat_names,              # List of feature names (READ ONLY)
                **kwargs,                # User-provided extra arguments
            ):
                ...

        Default: None (no filtering).
    post_filter_func : callable or None, optional
        Population-level filter function to run AFTER parallel disentanglement.
        Can modify pair results (e.g., tie-breaking). Mutates per_neuron_disent
        and recalculates final_sels.
        Signature::

            def post_filter_func(
                per_neuron_disent,  # dict: {nid: {'pairs': {...}, ...}} - MUTATE
                cell_feat_stats,    # Pre-computed MI values (READ ONLY)
                feat_names,         # List of feature names (READ ONLY)
                **kwargs,           # User-provided extra arguments
            ):
                ...

        Default: None (no post-filtering).
    filter_kwargs : dict or None, optional
        Dictionary of keyword arguments to pass to pre_filter_func and post_filter_func.
        Can include pre-extracted data like calcium_data, feature_data,
        thresholds, etc. Default: None.

    Returns
    -------
    dict
        Dictionary containing:

        - 'disent_matrix': ndarray where element [i,j] indicates how many times
          feature i was primary when paired with feature j across all neurons.
        - 'count_matrix': ndarray where element [i,j] indicates how many
          neuron-feature pairs were tested for features i and j.
        - 'per_neuron_disent': dict mapping neuron_id to detailed results
          with keys 'pairs', 'renames', 'final_sels', and 'errors'.

    Notes
    -----
    The analysis is performed only on neurons with significant selectivity
    to at least 2 features. If feat_feat_significance is provided, only
    behaviorally correlated feature pairs are analyzed for redundancy.
    Non-significant pairs indicate true mixed selectivity.

    When cell_feat_stats and feat_feat_similarity are provided, 3 out of 5
    pairwise MI computations per disentangle_pair call are skipped by using
    lookups, providing ~60% reduction in MI computation overhead.

    The neuron loop is parallelized using joblib, providing significant speedup
    when analyzing many neurons. Each neuron is processed independently and
    results are merged at the end.

    Feature names in ``feat_names`` must use ``_2d``-substituted names for circular
    features (e.g., ``headdirection_2d`` not ``headdirection``), matching the names
    used in ``cell_feat_stats``.

    If any per-neuron disentanglement encounters errors (e.g., missing feature data),
    a summary warning is emitted via ``warnings.warn`` after all neurons are processed.

    ``disent_matrix`` and ``count_matrix`` are accumulated from the worker
    results; they are not recomputed after ``post_filter_func`` runs.

    **Filter chain execution:**

    1. Filters run at population level BEFORE the parallel loop
    2. Filters mutate neuron_selectivities, pair_decisions, and renames in place
    3. Workers receive pre-computed decisions (lightweight serialization)
    4. Each filter in the chain can override decisions from earlier filters

    Raises
    ------
    ValueError
        If a feature name is not found in the experiment or feat_names.
    AttributeError
        If required attributes are missing from the experiment.
    KeyError
        If expected keys are missing from data structures.
    """
    # Use default multifeature mapping if none provided.
    if multifeature_map is None:
        multifeature_map = DEFAULT_MULTIFEATURE_MAP.copy()

    # Initialize result matrices.
    n_features = len(feat_names)
    disent_matrix = np.zeros((n_features, n_features))
    count_matrix = np.zeros((n_features, n_features))

    # Create MultiTimeSeries for each multifeature that is actually requested.
    multifeature_ts = {}
    for mf_tuple, agg_name in multifeature_map.items():
        if agg_name not in feat_names:
            continue
        component_ts = []
        for component in mf_tuple:
            if not hasattr(exp, component):
                raise ValueError(f"Component '{component}' not found in experiment")
            component_ts.append(getattr(exp, component))
        # Allow zero columns since behavioral features might be constant.
        multifeature_ts[agg_name] = MultiTimeSeries(component_ts, allow_zero_columns=True)

    # Get neurons with significant selectivity to multiple features.
    sneur = exp.get_significant_neurons(min_nspec=2, cbunch=cell_bunch)

    # Pre-extract feature TimeSeries to avoid serializing the entire exp
    # object; this is critical for parallel performance with joblib.
    feature_ts_dict = {
        fname: getattr(exp, fname) for fname in feat_names if hasattr(exp, fname)
    }
    # Pre-extract neuron TimeSeries.
    neuron_ts_dict = {neuron: exp.neurons[neuron].ca for neuron in sneur.keys()}

    # ============================================================
    # PHASE 1: Build filter state and run filter chain (BEFORE parallel loop)
    # ============================================================
    neuron_selectivities = {nid: list(sels) for nid, sels in sneur.items()}
    pair_decisions = {nid: {} for nid in sneur}
    renames = {nid: {} for nid in sneur}

    if pre_filter_func is not None:
        pre_filter_func(
            neuron_selectivities=neuron_selectivities,
            pair_decisions=pair_decisions,
            renames=renames,
            cell_feat_stats=cell_feat_stats,
            feat_feat_significance=feat_feat_significance,
            feat_names=feat_names,
            **(filter_kwargs or {}),
        )

    # ============================================================
    # PHASE 2: Parallel processing (WORKERS)
    # ============================================================
    per_neuron_disent = {}

    # Pre-compute feature copula cache ONCE (shared across all neurons).
    # This avoids redundant copula normalization in each worker.
    feat_copnorm_cache = {}
    for fname in feat_names:
        # Explicit None checks: never rely on the truthiness of a time-series
        # object to decide whether it exists.
        ts = feature_ts_dict.get(fname)
        if ts is None:
            ts = multifeature_ts.get(fname)
        if ts is not None and not getattr(ts, 'discrete', True):
            feat_copnorm_cache[fname] = _downsample_copnorm(ts.copula_normal_data, ds)

    # Process neurons in parallel with backend-specific config.
    if len(neuron_selectivities) > 0:
        with _parallel_executor(n_jobs) as parallel:
            results = parallel(
                delayed(_process_neuron_disentanglement)(
                    neuron_id=neuron,
                    sels=neuron_selectivities[neuron],  # Already filtered
                    neur_ts=neuron_ts_dict[neuron],
                    feat_names=feat_names,
                    multifeature_map=multifeature_map,
                    multifeature_ts=multifeature_ts,
                    feature_ts_dict=feature_ts_dict,
                    ds=ds,
                    feat_feat_significance=feat_feat_significance,
                    cell_feat_stats=cell_feat_stats,
                    feat_feat_similarity=feat_feat_similarity,
                    pre_decisions=pair_decisions[neuron],  # Pre-computed
                    pre_renames=renames[neuron],  # Pre-computed
                    feat_copnorm_cache=feat_copnorm_cache,  # Pre-computed
                )
                for neuron in neuron_selectivities.keys()
            )

        # Merge partial results from all workers.
        for neuron_id, partial_disent, partial_count, neuron_info in results:
            disent_matrix += partial_disent
            count_matrix += partial_count
            # Store per-neuron info if it has any content (including errors).
            if neuron_info['pairs'] or neuron_info['renames'] or neuron_info['errors']:
                per_neuron_disent[neuron_id] = neuron_info

    # Surface any accumulated errors from per-neuron processing.
    all_errors = [
        err
        for info in per_neuron_disent.values()
        for err in info.get('errors', [])
    ]
    if all_errors:
        neurons_with_errors = len({e[0] for e in all_errors})
        examples = "; ".join(
            f"neuron {e[0]}, pair {e[1]}: {e[2]}" for e in all_errors[:5]
        )
        suffix = f" (showing 5 of {len(all_errors)})" if len(all_errors) > 5 else ""
        warnings.warn(
            f"Disentanglement encountered {len(all_errors)} errors across "
            f"{neurons_with_errors} neurons{suffix}: {examples}",
            stacklevel=2,
        )

    # ============================================================
    # PHASE 3: Post-filter (AFTER parallel loop)
    # ============================================================
    if post_filter_func is not None:
        post_filter_func(
            per_neuron_disent=per_neuron_disent,
            cell_feat_stats=cell_feat_stats,
            feat_names=feat_names,
            **(filter_kwargs or {}),
        )

    # (Re)derive final_sels from the stored pair decisions. Without a
    # post-filter this reproduces the worker's decisions; with one it picks up
    # whatever the filter changed.
    for nid, neuron_info in per_neuron_disent.items():
        features_to_remove = set()
        for (feat_i, feat_j), info in neuron_info['pairs'].items():
            result = info.get('result', 0.5)
            if result == 0:
                features_to_remove.add(feat_j)
            elif result == 1:
                features_to_remove.add(feat_i)
        # Pair keys use aggregated names for multifeatures while the
        # selectivity lists hold the multifeature tuples, so a tuple is
        # dropped when the aggregated name it maps to was judged redundant.
        neuron_info['final_sels'] = [
            f for f in neuron_selectivities[nid]
            if f not in features_to_remove
            and not (
                isinstance(f, tuple)
                and multifeature_map.get(f) in features_to_remove
            )
        ]

    return {
        'disent_matrix': disent_matrix,
        'count_matrix': count_matrix,
        'per_neuron_disent': per_neuron_disent,
    }
[docs]
def create_multifeature_map(exp, mapping_dict):
    """Build a multifeature mapping after checking it against an experiment.

    Parameters
    ----------
    exp : Experiment
        Experiment object to validate feature existence.
    mapping_dict : dict
        Dictionary mapping tuples of features to aggregated names.
        Example: {('x', 'y'): 'place', ('speed', 'head_direction'): 'locomotion'}

    Returns
    -------
    dict
        Validated multifeature mapping. Each key is the original component
        tuple with its components sorted.

    Raises
    ------
    ValueError
        If any component features don't exist in the experiment.
    """
    checked = {}
    for components, aggregated in mapping_dict.items():
        # Every component must be an attribute of the experiment.
        for component in components:
            if hasattr(exp, component):
                continue
            raise ValueError(
                f"Component '{component}' in multifeature {components} "
                f"not found in experiment"
            )
        # Sorted key gives one canonical spelling per component set.
        checked[tuple(sorted(components))] = aggregated
    return checked
[docs]
def get_disentanglement_summary(
    disent_matrix, count_matrix, feat_names, feat_feat_significance=None,
    per_neuron_disent=None,
):
    """Generate a summary of disentanglement results.

    Parameters
    ----------
    disent_matrix : ndarray
        Disentanglement result matrix from disentangle_all_selectivities.
    count_matrix : ndarray
        Count matrix from disentangle_all_selectivities.
    feat_names : list of str
        Feature names corresponding to matrix indices.
    feat_feat_significance : ndarray, optional
        Binary significance matrix indicating which feature pairs
        were analyzed for disentanglement.
    per_neuron_disent : dict, optional
        Per-neuron disentanglement results from disentangle_all_selectivities.
        When provided, redundancy/undistinguishable counts are computed exactly
        from per-neuron pair decisions instead of the aggregate matrices.
        A pair may be stored as ``(a, b)`` for some neurons and ``(b, a)``
        for others; both orientations are counted.

    Returns
    -------
    dict
        Summary statistics including:

        - Primary feature percentages for each pair
        - Total counts for each pair
        - Overall redundancy vs independence rates
        - Breakdown by significant vs non-significant feature pairs

    Notes
    -----
    The calculation distinguishes between:

    - Redundant cases: One feature is primary (disentangle result 0 or 1)
    - Undistinguishable cases: Both features contribute (disentangle result 0.5)
    """
    summary = {"feature_pairs": {}, "overall_stats": {}}
    n_features = len(feat_names)

    # Build exact per-pair counts from per_neuron_disent when available.
    # The aggregate matrices lose information when an even number of neurons
    # all get result=0.5, making it indistinguishable from actual wins.
    pair_counts = None
    if per_neuron_disent is not None:
        # pair_counts[(fi, fj)] = {0: n_fi_primary, 1: n_fj_primary, 0.5: n_undist}
        pair_counts = {}
        for neuron_info in per_neuron_disent.values():
            for (fi, fj), pinfo in neuron_info.get('pairs', {}).items():
                result = pinfo.get('result', 0.5)
                counts = pair_counts.setdefault((fi, fj), {0: 0, 0.5: 0, 1: 0})
                counts[result] += 1

    total_redundant = 0
    total_undistinguishable = 0
    total_pairs = 0
    for i in range(n_features):
        for j in range(i + 1, n_features):
            if not count_matrix[i, j] > 0:
                continue
            n_total = count_matrix[i, j]
            n_i_primary = disent_matrix[i, j]
            n_j_primary = disent_matrix[j, i]

            # Count undistinguishable cases.
            if pair_counts is not None:
                # Exact counts from per-neuron results. The key orientation
                # follows each neuron's own selectivity order, so the same
                # feature pair can appear under both (fi, fj) and (fj, fi);
                # the 0.5 outcome is symmetric, so the two are simply summed.
                fi, fj = feat_names[i], feat_names[j]
                forward = pair_counts.get((fi, fj))
                backward = pair_counts.get((fj, fi))
                if forward is None and backward is None:
                    n_undistinguishable = int(n_total)
                else:
                    n_undistinguishable = sum(
                        pc[0.5] for pc in (forward, backward) if pc is not None
                    )
            elif feat_feat_significance is not None and feat_feat_significance[i, j] == 0:
                n_undistinguishable = int(n_total)
            else:
                # Fallback: estimate from aggregate matrices (may be inaccurate
                # when an even number of neurons all have result=0.5).
                frac_i = n_i_primary - int(n_i_primary)
                if frac_i > 0.25:
                    n_undistinguishable = round(frac_i * 2)
                else:
                    n_undistinguishable = int(
                        n_total - int(n_i_primary) - int(n_j_primary)
                    )
            n_redundant = int(n_total) - n_undistinguishable

            pair_key = f"{feat_names[i]}_vs_{feat_names[j]}"
            summary["feature_pairs"][pair_key] = {
                "total_neurons": int(n_total),
                f"{feat_names[i]}_primary": n_i_primary / n_total * 100,
                f"{feat_names[j]}_primary": n_j_primary / n_total * 100,
                "undistinguishable_pct": n_undistinguishable / n_total * 100,
                "redundant_pct": n_redundant / n_total * 100,
            }
            total_redundant += n_redundant
            total_undistinguishable += n_undistinguishable
            total_pairs += n_total

    if total_pairs > 0:
        summary["overall_stats"] = {
            "total_neuron_pairs": int(total_pairs),
            "redundancy_rate": total_redundant / total_pairs * 100,
            "undistinguishable_rate": total_undistinguishable / total_pairs * 100,
        }
        # Add breakdown by behavioral significance if provided.
        if feat_feat_significance is not None:
            sig_pairs = 0
            nonsig_pairs = 0
            for i in range(n_features):
                for j in range(i + 1, n_features):
                    if count_matrix[i, j] > 0:
                        if feat_feat_significance[i, j] == 1:
                            sig_pairs += count_matrix[i, j]
                        else:
                            nonsig_pairs += count_matrix[i, j]
            summary["overall_stats"]["significant_behavior_pairs"] = int(sig_pairs)
            summary["overall_stats"]["nonsignificant_behavior_pairs"] = int(nonsig_pairs)
            summary["overall_stats"]["true_mixed_selectivity_rate"] = (
                nonsig_pairs / total_pairs * 100 if total_pairs > 0 else 0
            )
    return summary