import time
import multiprocessing
import warnings
from contextlib import contextmanager
import numpy as np
import tqdm
from dataclasses import dataclass
from typing import Callable, Optional
from joblib import Parallel, delayed
from .stats import (
populate_nested_dict,
get_table_of_stats,
criterion1,
criterion2,
apply_stage_criterion,
get_all_nonempty_pvals,
merge_stage_stats,
merge_stage_significance,
DEFAULT_METRIC_DISTR_TYPE,
)
from ..information.info_base import (
TimeSeries,
MultiTimeSeries,
get_multi_mi,
get_sim,
)
from ..utils.data import nested_dict_to_seq_of_tables, add_names_to_nested_dict
from .io import IntenseResults
from .validation import (
validate_time_series_bunches,
validate_metric,
validate_common_parameters,
)
from .fft import (
MIN_SHUFFLES_FOR_FFT,
MIN_SHIFTS_FOR_FFT_DELAYS,
MAX_FFT_MTS_DIMENSIONS,
MAX_MTS_MTS_FFT_DIMENSIONS,
FFT_CONTINUOUS, FFT_DISCRETE, FFT_DISCRETE_DISCRETE,
FFT_MULTIVARIATE, FFT_MTS_MTS, FFT_MTS_DISCRETE,
FFT_PEARSON_CONTINUOUS, FFT_PEARSON_DISCRETE, FFT_AV_DISCRETE,
FFTCacheEntry,
get_fft_type,
_extract_fft_data,
_FFT_COMPUTE,
_get_ts_key,
_build_fft_cache,
)
from .correction import get_multicomp_correction_thr
# Import shared parallel utilities
# Note: _parallel_executor is now in utils.parallel for shared use across modules
from ..utils.parallel import parallel_executor as _parallel_executor, get_parallel_backend as _get_parallel_backend
# Default noise amplitude added to MI values for numerical stability.
# Used as the default ``noise_const`` of scan_pairs / scan_pairs_parallel /
# scan_pairs_router, which add ``U(0, 1) * noise_const`` to every true and
# shuffled metric value (presumably to break exact ties between values).
DEFAULT_NOISE_AMPLITUDE = 1e-3
@contextmanager
def _timed_section(timings, name):
"""Context manager for timing code sections. No-op if timings is None.
Parameters
----------
timings : dict or None
Dictionary to store timing results. If None, timing is skipped (no-op).
name : str
Key under which the elapsed time (in seconds) will be stored in timings.
"""
if timings is None:
yield
else:
start = time.perf_counter()
yield
timings[name] = time.perf_counter() - start
def _build_shift_valid_map(ts_bunch1, ts_bunch2, optimal_delays, ds):
    """
    Build a boolean map telling which downsampled shifts are usable per pair.

    For pair (i, j) the shuffle masks of ``ts_bunch1[i]`` and ``ts_bunch2[j]``
    are AND-ed together and rolled by ``optimal_delays[i, j]``. Every masked
    frame marks its downsampled shift bin (``frame // ds``) as invalid.

    Results are memoised per ``(key1, key2, delay)`` where keys come from
    ``_get_ts_key`` (the same stable keys used for the FFT cache), so pairs
    sharing masks and delay only pay for the computation once.

    Parameters
    ----------
    ts_bunch1 : list
        First set of time series (e.g., neurons).
    ts_bunch2 : list
        Second set of time series (e.g., features).
    optimal_delays : np.ndarray
        Optimal delays of shape (len(ts_bunch1), len(ts_bunch2)).
    ds : int
        Downsampling factor.

    Returns
    -------
    valid_map : np.ndarray, dtype=bool
        Shape (n1, n2, n_shifts) with ``n_shifts = n_frames // ds``.
        True means shift index s is valid for pair (i, j).
    needs_correction : bool
        False if every combined mask was all-True (nothing to correct).
    """
    total_frames = ts_bunch1[0].data.shape[-1]
    shift_count = total_frames // ds
    shape = (len(ts_bunch1), len(ts_bunch2), shift_count)
    valid_map = np.ones(shape, dtype=bool)

    def invalid_bins(first, second, lag):
        # None signals "all shifts valid"; otherwise the sorted unique
        # downsampled bins that touch a masked frame (clipped to shift_count).
        allowed = first.shuffle_mask & second.shuffle_mask
        if lag != 0:
            allowed = np.roll(allowed, lag)
        if np.all(allowed):
            return None
        bins = np.unique(np.where(~allowed)[0] // ds)
        return bins[bins < shift_count]

    memo = {}
    found_invalid = False
    for row, first in enumerate(ts_bunch1):
        first_key = _get_ts_key(first)
        for col, second in enumerate(ts_bunch2):
            second_key = _get_ts_key(second)
            lag = int(optimal_delays[row, col])
            memo_key = (first_key, second_key, lag)
            if memo_key not in memo:
                bins = invalid_bins(first, second, lag)
                memo[memo_key] = bins
                if bins is not None:
                    found_invalid = True
            bins = memo[memo_key]
            if bins is None:
                continue
            valid_map[row, col, bins] = False
    return valid_map, found_invalid
def _find_invalid_shifts(random_shifts, valid_map):
"""
Find shifts that land on masked (invalid) positions.
Uses advanced indexing to check all (n1 x n2 x nsh) shifts at once
against the validity map.
Parameters
----------
random_shifts : np.ndarray, shape (n1, n2, nsh)
Shift indices to check.
valid_map : np.ndarray, shape (n1, n2, n_shifts), dtype=bool
Validity map from _build_shift_valid_map.
Returns
-------
bad : np.ndarray, shape (n1, n2, nsh), dtype=bool
True where the shift is invalid.
"""
ii = np.arange(valid_map.shape[0])[:, None, None]
jj = np.arange(valid_map.shape[1])[None, :, None]
return ~valid_map[ii, jj, random_shifts]
def _generate_random_shifts_grid(ts_bunch1, ts_bunch2, optimal_delays, nsh, seed, ds=1):
    """
    Generate all random shifts upfront for all pairs.

    Uses vectorized bulk generation followed by rejection resampling to
    respect shuffle masks. Much faster than per-pair RandomState construction.

    Parameters
    ----------
    ts_bunch1 : list
        First set of time series (e.g., neurons).
    ts_bunch2 : list
        Second set of time series (e.g., features).
    optimal_delays : np.ndarray
        Optimal delays of shape (len(ts_bunch1), len(ts_bunch2)).
    nsh : int
        Number of random shifts to generate per pair.
    seed : int
        Base random seed for reproducibility.
    ds : int, default=1
        Downsampling factor.

    Returns
    -------
    random_shifts : np.ndarray, dtype=int32
        Array of shape (len(ts_bunch1), len(ts_bunch2), nsh) containing
        pre-generated shifts (in downsampled units) for each pair.

    Raises
    ------
    ValueError
        If the downsampled series has no shifts (``n_frames // ds == 0``).

    Notes
    -----
    - When ``nsh >= n_frames // ds`` the shifts are exhaustive: the full range
      of shifts is cycled to length ``nsh``, so each distinct shift appears
      at least once and the last axis is always exactly ``nsh``.
    - Uses _build_shift_valid_map + _find_invalid_shifts for mask correction.
    - Respects shuffle masks via rejection resampling (typically converges in
      a few rounds). A RuntimeWarning is emitted if masked shifts remain
      after the round limit (e.g. a pair with no valid shift at all).
    """
    max_rounds = 100

    n1, n2 = len(ts_bunch1), len(ts_bunch2)
    n_frames = ts_bunch1[0].data.shape[-1]
    n_shifts = n_frames // ds
    if n_shifts <= 0:
        raise ValueError(
            f"No shifts available: series length {n_frames} with ds={ds} "
            f"gives {n_shifts} downsampled frames"
        )
    rng = np.random.RandomState(seed)

    if nsh >= n_shifts:
        # Exhaustive enumeration. np.resize cycles the range so the result has
        # exactly nsh entries (identical to the plain range when nsh == n_shifts);
        # returning only n_shifts entries would break callers expecting nsh.
        base = np.resize(np.arange(n_shifts, dtype=np.int32), nsh)
        random_shifts = np.tile(base, (n1, n2, 1))
    else:
        random_shifts = rng.randint(0, n_shifts, size=(n1, n2, nsh)).astype(np.int32)

    # Build validity map from shuffle masks (once)
    valid_map, needs_correction = _build_shift_valid_map(
        ts_bunch1, ts_bunch2, optimal_delays, ds
    )

    # Rejection loop: redraw shifts that land on masked positions
    if needs_correction:
        for _ in range(max_rounds):
            bad = _find_invalid_shifts(random_shifts, valid_map)
            n_bad = int(bad.sum())
            if n_bad == 0:
                break
            random_shifts[bad] = rng.randint(0, n_shifts, size=n_bad).astype(np.int32)
        else:
            n_left = int(_find_invalid_shifts(random_shifts, valid_map).sum())
            if n_left:
                warnings.warn(
                    f"{n_left} shuffle shifts still fall on masked frames after "
                    f"{max_rounds} resampling rounds; some pairs may have few or "
                    f"no valid shifts.",
                    RuntimeWarning,
                )
    return random_shifts
[docs]
@dataclass
class StageConfig:
    """Configuration for a single stage of INTENSE computation.

    Encapsulates all stage-specific parameters to enable a unified
    scan_stage() function for both Stage 1 and Stage 2.

    Attributes
    ----------
    stage_num : int
        Stage number (1 or 2).
    n_shuffles : int
        Number of shuffles for this stage.
    mask : np.ndarray
        Binary mask indicating which pairs to compute.
    topk : int
        True MI should rank in top k among shuffles.
    pval_thr : float, optional
        Base p-value threshold (Stage 2 only). Defaults to None, so it must
        be set explicitly when the stage needs it.
    multicomp_correction : str, optional
        Multiple comparison correction method (Stage 2 only), e.g. 'holm'
        or 'bonferroni'. Defaults to None.
    """

    stage_num: int
    n_shuffles: int
    mask: np.ndarray
    topk: int
    pval_thr: Optional[float] = None
    multicomp_correction: Optional[str] = None
def _extract_cache_subset(cache, ts_subset, ts_bunch2):
    """Return only the FFT-cache entries a single worker will touch.

    Shipping the full cache to every parallel worker would serialize it once
    per worker; filtering to the worker's slice of ts_bunch1 keeps that
    overhead proportional to the work actually assigned.

    Parameters
    ----------
    cache : dict or None
        Full FFT cache mapping (key1, key2) -> FFTCacheEntry.
    ts_subset : list of TimeSeries
        Subset of ts_bunch1 that this worker will process.
    ts_bunch2 : list of TimeSeries
        Full set of ts_bunch2 (features).

    Returns
    -------
    dict or None
        Entries whose (key1, key2) pair exists in ``cache`` with key1 from
        ``ts_subset`` and key2 from ``ts_bunch2``; None when ``cache`` is None.
    """
    if cache is None:
        return None
    return {
        (neuron_key, feature_key): cache[(neuron_key, feature_key)]
        for neuron_key in map(_get_ts_key, ts_subset)
        for feature_key in map(_get_ts_key, ts_bunch2)
        if (neuron_key, feature_key) in cache
    }
[docs]
def get_calcium_feature_me_profile(
    exp,
    cell_id=None,
    feat_id=None,
    cbunch=None,
    fbunch=None,
    shift_window=2,
    ds=1,
    metric="mi",
    mi_estimator="gcmi",
    data_type="calcium",
) -> "dict | tuple":
    """
    Compute metric profile between neurons and behavioral features across time shifts.

    Parameters
    ----------
    exp : Experiment
        Experiment object containing neurons and behavioral features.
    cell_id : int, optional
        Index of a single neuron in exp.neurons. Deprecated - use cbunch instead.
    feat_id : str or tuple of str, optional
        Single feature name(s) to analyze. Deprecated - use fbunch instead.
    cbunch : int, iterable or None, optional
        Neuron indices. If None (default), all neurons will be analyzed.
        Takes precedence over cell_id if both provided.
    fbunch : str, iterable or None, optional
        Feature names. If None (default), all single features will be analyzed.
        Takes precedence over feat_id if both provided.
    shift_window : int, optional
        Maximum shift to test in each direction (seconds). Default: 2.
        Converted to frames internally using exp.fps.
    ds : int, optional
        Downsampling factor. Default: 1 (no downsampling).
    metric : str, optional
        Similarity metric to compute. Default: 'mi'.
        - 'mi': Mutual information
        - 'spearman': Spearman correlation
        - Other metrics supported by get_sim function
    mi_estimator : str, optional
        Mutual information estimator to use when metric='mi'. Default: 'gcmi'.
        Options: 'gcmi' or 'ksg'
    data_type : str, optional
        Type of neural data to use. Default: 'calcium'.
        - 'calcium': Use calcium imaging data (``neuron.ca``)
        - any other value: use spike data (``neuron.spikes``)

    Returns
    -------
    tuple or dict
        If only ``cell_id`` and ``feat_id`` are given (and neither ``cbunch``
        nor ``fbunch``), the backward-compatible 2-tuple
        ``(me0, shifted_me)`` for that single pair.
        Otherwise a nested dictionary
        ``{cell_id: {feat_id: {'me0': float, 'shifted_me': list}}}``.
        ``me0`` is the metric without shift; ``shifted_me`` contains the
        metric at each shift from ``-window`` to ``+window`` frames in steps
        of ``ds``, where ``window = int(shift_window * exp.fps)``.

    Raises
    ------
    ValueError
        If ``shift_window`` is not positive, a neuron index is out of range,
        or a multi-feature tuple is requested with ``metric != 'mi'``.

    Notes
    -----
    - shift_window is in seconds, converted to frames using exp.fps
    - Number of shifts tested is roughly 2 * shift_window * fps / ds + 1
    - Multi-feature analysis (tuple feat_id) only supported for metric='mi'
    - Progress bar shows computation progress

    Examples
    --------
    This function requires an Experiment object, which contains neural recordings
    and behavioral features. Here's a conceptual example:

    >>> # Pseudo-code example (requires actual Experiment object):
    >>> # exp = load_experiment()  # Load your experiment data
    >>> #
    >>> # # Single pair (backward-compatible form) returns a tuple
    >>> # me0, profile = get_calcium_feature_me_profile(exp, 0, 'speed',
    >>> #                                               shift_window=2, ds=5)
    >>> #
    >>> # # Or analyze multiple neurons and features at once
    >>> # results = get_calcium_feature_me_profile(exp, cbunch=[0, 1],
    >>> #                                          fbunch=['speed', 'direction'],
    >>> #                                          shift_window=1, ds=2)
    >>> # # Access results: results[neuron_id][feature_name]['me0']
    >>> pass  # Actual usage requires Experiment object"""
    # Validate inputs
    validate_common_parameters(ds=ds)
    validate_metric(metric)
    if shift_window <= 0:
        raise ValueError(f"shift_window must be positive, got {shift_window}")

    # Convert shift_window from seconds to frames
    window = int(shift_window * exp.fps)

    # Single cell/feature mode keeps the legacy tuple return format
    single_mode = cell_id is not None and feat_id is not None and cbunch is None and fbunch is None

    # Backward compatibility: old-style cell_id/feat_id feed into the bunches
    if cbunch is None and cell_id is not None:
        cbunch = cell_id
    if fbunch is None and feat_id is not None:
        fbunch = feat_id

    cell_ids = exp._process_cbunch(cbunch)
    feat_ids = exp._process_fbunch(fbunch, allow_multifeatures=True, mode=data_type)

    n_neurons = len(exp.neurons)
    for cid in cell_ids:
        if not (0 <= cid < n_neurons):
            raise ValueError(f"cell_id {cid} out of range [0, {n_neurons-1}]")

    # Shift grid (in downsampled units) is identical for every pair
    shifts = np.arange(-window, window + ds, ds) // ds

    results = {}
    # Context manager guarantees the progress bar is closed even on errors
    with tqdm.tqdm(
        total=len(cell_ids) * len(feat_ids), desc="Computing ME profiles"
    ) as pbar:
        for cid in cell_ids:
            cell = exp.neurons[cid]
            ts1 = cell.ca if data_type == "calcium" else cell.spikes
            results[cid] = {}
            for fid in feat_ids:
                if isinstance(fid, str):
                    # Single feature
                    ts2 = exp.dynamic_features[fid]
                    me0 = get_sim(ts1, ts2, metric, ds=ds, estimator=mi_estimator)
                    shifted_me = [
                        get_sim(ts1, ts2, metric, ds=ds, shift=shift, estimator=mi_estimator)
                        for shift in shifts
                    ]
                else:
                    # Multi-feature (tuple)
                    if metric != "mi":
                        raise ValueError(
                            f"Multi-feature analysis only supported for metric='mi', got '{metric}'"
                        )
                    feats = [exp.dynamic_features[f] for f in fid]
                    me0 = get_multi_mi(feats, ts1, ds=ds, estimator=mi_estimator)
                    shifted_me = [
                        get_multi_mi(feats, ts1, ds=ds, shift=shift, estimator=mi_estimator)
                        for shift in shifts
                    ]
                results[cid][fid] = {"me0": me0, "shifted_me": shifted_me}
                pbar.update(1)

    if single_mode:
        # Backward compatibility - return (me0, shifted_me) tuple
        pair = results[cell_ids[0]][feat_ids[0]]
        return pair["me0"], pair["shifted_me"]
    return results
[docs]
def scan_pairs(
    ts_bunch1,
    ts_bunch2,
    metric,
    nsh,
    optimal_delays,
    random_shifts=None,
    mi_estimator="gcmi",
    ds=1,
    mask=None,
    noise_const=DEFAULT_NOISE_AMPLITUDE,
    seed=None,
    enable_progressbar=True,
    engine="auto",
    fft_cache: dict = None,
    mi_estimator_kwargs=None,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Calculate similarity metric and shuffled distributions for pairs of time series.

    This function computes the similarity metric between all pairs from ts_bunch1 and
    ts_bunch2, along with shuffled distributions for significance testing.

    Parameters
    ----------
    ts_bunch1 : list of TimeSeries or MultiTimeSeries
        First set of time series (typically neural signals).
    ts_bunch2 : list of TimeSeries or MultiTimeSeries
        Second set of time series (typically behavioral variables).
    metric : str
        Similarity metric to compute. See validate_metric for supported options.
    nsh : int
        Number of shuffles for significance testing.
    optimal_delays : np.ndarray
        Optimal delays array of shape (len(ts_bunch1), len(ts_bunch2)).
        Contains best shifts in frames.
    random_shifts : np.ndarray, optional
        Pre-generated shifts (downsampled units) of shape
        (len(ts_bunch1), len(ts_bunch2), nsh). If None, they are generated
        with _generate_random_shifts_grid from ``seed``, respecting shuffle masks.
    mi_estimator : str, default='gcmi'
        Mutual information estimator to use when metric='mi'.
        Options: 'gcmi' (Gaussian copula) or 'ksg' (k-nearest neighbors).
    ds : int, default=1
        Downsampling factor. Every ds-th point is used from the time series.
    mask : np.ndarray, optional
        Binary mask array of shape (len(ts_bunch1), len(ts_bunch2)).
        0 skips calculation, 1 proceeds.
    noise_const : float, default=1e-3
        Small noise amplitude added to improve numerical stability.
    seed : int, optional
        Random seed for reproducibility (None is treated as 0).
    enable_progressbar : bool, default=True
        Whether to show progress bar during computation.
    engine : {'auto', 'fft', 'loop'}, default='auto'
        Computation engine for MI shuffles:
        - 'auto': Use FFT when applicable (univariate continuous GCMI with nsh >= 50)
        - 'fft': Force FFT (raises error if not applicable)
        - 'loop': Force per-shift loop (original behavior)
    fft_cache : dict, optional
        Pre-computed FFT cache from _build_fft_cache. Keys are (key1, key2) tuples
        using stable identifiers from _get_ts_key(). If provided, avoids redundant
        data extraction.
    mi_estimator_kwargs : dict, optional
        Additional keyword arguments passed to the MI estimator function.

    Returns
    -------
    random_shifts : np.ndarray
        Array of shape (len(ts_bunch1), len(ts_bunch2), nsh) containing
        random shifts used for shuffled distribution computation.
    me_total : np.ndarray
        Array of shape (len(ts_bunch1), len(ts_bunch2), nsh+1). Contains true metric
        values at index 0 and shuffled values at indices 1:nsh+1.

    Raises
    ------
    ValueError
        If optimal_delays or random_shifts have the wrong shape, or the time
        series lengths differ.

    Notes
    -----
    - True metric values: me_total[:,:,0]
    - Shuffled values: me_total[:,:,1:]
    - Masked-out pairs (mask == 0) stay at 0.
    - Noise is added as: value + noise_const * U(0, 1)
    - Noise RNGs are seeded from ``seed`` and the stable _get_ts_key() keys via
      CRC32 (per pair, or per neuron row on the FFT-cache path), so results do
      not depend on process hash randomization or on how rows are split
      across parallel workers.
    - FFT optimization provides ~100x speedup for univariate continuous GCMI"""
    # Validate inputs
    validate_time_series_bunches(ts_bunch1, ts_bunch2)
    validate_metric(metric)
    validate_common_parameters(ds=ds, nsh=nsh, noise_const=noise_const)

    n1 = len(ts_bunch1)
    n2 = len(ts_bunch2)
    if optimal_delays.shape != (n1, n2):
        raise ValueError(
            f"optimal_delays shape {optimal_delays.shape} doesn't match expected ({n1}, {n2})"
        )

    if seed is None:
        seed = 0

    # All series must share one full length (frames)
    lengths1 = {
        len(ts.data) if isinstance(ts, TimeSeries) else ts.data.shape[1] for ts in ts_bunch1
    }
    lengths2 = {
        len(ts.data) if isinstance(ts, TimeSeries) else ts.data.shape[1] for ts in ts_bunch2
    }
    if len(lengths1) != 1 or lengths1 != lengths2:
        raise ValueError("Lengths of TimeSeries do not match!")

    if mask is None:
        mask = np.ones((n1, n2))

    if random_shifts is None:
        # Vectorized bulk generation + mask-aware rejection resampling
        random_shifts = _generate_random_shifts_grid(
            ts_bunch1, ts_bunch2, optimal_delays, nsh, seed, ds
        )
    if random_shifts.shape != (n1, n2, nsh):
        raise ValueError(
            f"random_shifts shape {random_shifts.shape} doesn't match expected "
            f"({n1}, {n2}, {nsh})"
        )

    me_table = np.zeros((n1, n2), dtype=np.float32)
    me_table_shuffles = np.zeros((n1, n2, nsh), dtype=np.float32)
    keys2 = [_get_ts_key(ts2) for ts2 in ts_bunch2]

    for i, ts1 in tqdm.tqdm(
        enumerate(ts_bunch1),
        total=n1,
        position=0,
        leave=True,
        disable=not enable_progressbar,
    ):
        key1 = _get_ts_key(ts1)
        active_js = [j for j in range(n2) if mask[i, j] == 1]

        if fft_cache is not None:
            # Vectorized cache path: batch all cached pairs for this neuron
            cached_js, mi_rows, fallback_js = [], [], []
            for j in active_js:
                entry = fft_cache.get((key1, keys2[j]))
                if entry is None:
                    fallback_js.append(j)
                else:
                    cached_js.append(j)
                    mi_rows.append(entry.mi_all)

            if cached_js:
                js = np.array(cached_js)
                mi_stack = np.array(mi_rows)  # (n_cached, n_shifts)
                rows = np.arange(len(cached_js))
                opt_shifts = (optimal_delays[i, js] // ds).astype(int)

                # Noise for the whole row comes from one RNG seeded by the neuron's
                # stable key: independent of mask, row position and worker split.
                row_rng = np.random.RandomState(_scan_pair_seed(seed, key1))
                noise_true = row_rng.random(size=n2).astype(np.float32) * noise_const
                noise_shuffles = row_rng.random(size=(n2, nsh)).astype(np.float32) * noise_const

                me_table[i, js] = mi_stack[rows, opt_shifts] + noise_true[js]
                me_table_shuffles[i, js, :] = (
                    mi_stack[rows[:, None], random_shifts[i, js, :]] + noise_shuffles[js]
                )
        else:
            fallback_js = active_js

        for j in fallback_js:
            ts2 = ts_bunch2[j]
            pair_rng = np.random.RandomState(_scan_pair_seed(seed, key1, keys2[j]))
            # With a cache, a missing entry means the pair is not FFT-able, so go
            # straight to the loop; without a cache, decide per pair.
            fft_type = (
                None
                if fft_cache is not None
                else get_fft_type(ts1, ts2, metric, mi_estimator, nsh, engine)
            )
            if fft_type is not None:
                data1, data2 = _extract_fft_data(ts1, ts2, fft_type, ds)
                compute_fn = _FFT_COMPUTE[fft_type]
                opt_shift = optimal_delays[i, j] // ds
                me0 = compute_fn(data1, data2, np.array([opt_shift]))[0]
                shuffled = compute_fn(data1, data2, random_shifts[i, j, :])
            else:
                me0, shuffled = _scan_pair_loop(
                    ts1, ts2, metric, ds, optimal_delays[i, j],
                    random_shifts[i, j, :], mi_estimator, mi_estimator_kwargs,
                )
            me_table[i, j] = me0 + pair_rng.random() * noise_const
            me_table_shuffles[i, j, :] = shuffled + pair_rng.random(size=nsh) * noise_const

    me_total = np.dstack((me_table, me_table_shuffles))
    return random_shifts, me_total


def _scan_pair_seed(seed, *keys):
    """Derive a RandomState seed from ``seed`` and stable time-series keys.

    Built-in ``hash()`` of str/bytes is salted per interpreter process
    (PYTHONHASHSEED), so it cannot give reproducible seeds across runs or
    across joblib worker processes. CRC32 of the keys' repr is stable as
    long as the keys have a deterministic repr (e.g. names). The result is
    reduced modulo 2**32, the valid RandomState seed range.
    """
    import zlib

    digest = zlib.crc32(repr(keys).encode("utf-8"))
    return (int(seed) + digest) % (2**32)


def _scan_pair_loop(ts1, ts2, metric, ds, opt_delay, shifts, mi_estimator, mi_estimator_kwargs):
    """Compute true and shuffled metric values for one pair with per-shift get_sim calls.

    Returns ``(me0, shuffled)`` without noise; ``shuffled`` is a float64
    array aligned with ``shifts``.
    """
    me0 = get_sim(
        ts1, ts2, metric, ds=ds,
        shift=opt_delay // ds,
        estimator=mi_estimator,
        check_for_coincidence=True,
        mi_estimator_kwargs=mi_estimator_kwargs,
    )
    shuffled = np.array(
        [
            get_sim(
                ts1, ts2, metric, ds=ds, shift=shift,
                estimator=mi_estimator,
                mi_estimator_kwargs=mi_estimator_kwargs,
            )
            for shift in shifts
        ],
        dtype=np.float64,
    )
    return me0, shuffled
[docs]
def scan_pairs_parallel(
    ts_bunch1,
    ts_bunch2,
    metric,
    nsh,
    optimal_delays,
    mi_estimator="gcmi",
    ds=1,
    mask=None,
    noise_const=DEFAULT_NOISE_AMPLITUDE,
    seed=None,
    n_jobs=-1,
    engine="auto",
    fft_cache: dict = None,
    mi_estimator_kwargs=None,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Calculate metric values and shuffles for time series pairs using parallel processing.

    Parameters
    ----------
    ts_bunch1 : list of TimeSeries
        First set of time series.
    ts_bunch2 : list of TimeSeries
        Second set of time series.
    metric : str
        Similarity metric to compute:
        - 'mi': Mutual information
        - 'spearman': Spearman correlation
        - Other metrics supported by get_sim function
    nsh : int
        Number of shuffles to perform.
    optimal_delays : np.ndarray of shape (len(ts_bunch1), len(ts_bunch2))
        Pre-computed optimal delays for each pair.
    mi_estimator : str, default='gcmi'
        Mutual information estimator to use when metric='mi'.
        Options: 'gcmi' (Gaussian copula) or 'ksg' (k-nearest neighbors).
    ds : int, default=1
        Downsampling factor.
    mask : np.ndarray, optional
        Binary mask of shape (len(ts_bunch1), len(ts_bunch2)).
        0 = skip computation, 1 = compute. Default: all ones.
    noise_const : float, default=1e-3
        Small noise added to improve numerical stability.
    seed : int, optional
        Random seed for reproducibility.
    n_jobs : int, default=-1
        Number of parallel jobs. Negative values follow the joblib convention
        (-1 = all cores, -2 = all but one, ...). Always capped at
        len(ts_bunch1); 0 is invalid.
    engine : {'auto', 'fft', 'loop'}, default='auto'
        Computation engine for MI shuffles:
        - 'auto': Use FFT when applicable (univariate continuous GCMI with nsh >= 50)
        - 'fft': Force FFT (raises error if not applicable)
        - 'loop': Force per-shift loop (original behavior)
    fft_cache : dict, optional
        Pre-computed FFT cache mapping (key1, key2) tuples to FFTCacheEntry objects.
        Keys are stable identifiers from _get_ts_key(). If provided, avoids redundant
        data extraction. If None, FFT type is computed fresh for each pair.
    mi_estimator_kwargs : dict, optional
        Additional keyword arguments passed to the MI estimator function.

    Returns
    -------
    random_shifts : np.ndarray of shape (len(ts_bunch1), len(ts_bunch2), nsh)
        Random shifts used for shuffling.
    me_total : np.ndarray of shape (len(ts_bunch1), len(ts_bunch2), nsh+1)
        Metric values. [:,:,0] contains true values, [:,:,1:] contains shuffles.

    Raises
    ------
    ValueError
        If input validation fails, parameters are invalid, or n_jobs == 0.

    Notes
    -----
    - Parallelization is done by splitting ts_bunch1 across workers
    - Each worker handles a subset of ts_bunch1 against all of ts_bunch2
    - Uses threading backend if PyTorch present (checked lazily), else loky
    - Shifts are generated once for the full grid, so they do not depend on
      the worker split
    - FFT optimization provides ~100x speedup for univariate continuous GCMI

    See Also
    --------
    scan_pairs : Sequential version of this function
    scan_pairs_router : Wrapper that chooses between parallel and sequential

    Examples
    --------
    >>> # Minimal example with 2x2 pairs
    >>> import numpy as np
    >>> from driada.information.info_base import TimeSeries
    >>> np.random.seed(42)  # For reproducibility
    >>> # Small data: 2 neurons, 2 behaviors, 50 timepoints
    >>> neurons = [TimeSeries(np.random.randn(50), discrete=False) for _ in range(2)]
    >>> behaviors = [TimeSeries(np.random.randn(50), discrete=False) for _ in range(2)]
    >>> delays = np.zeros((2, 2), dtype=int)  # No delays
    >>> # Just 5 shuffles for demonstration
    >>> shifts, metrics = scan_pairs_parallel(neurons, behaviors, 'mi',
    ...                                       5, delays, n_jobs=1, seed=42)
    >>> shifts.shape
    (2, 2, 5)
    >>> metrics.shape  # Original + 5 shuffles = 6 total
    (2, 2, 6)"""
    # Validate inputs
    validate_time_series_bunches(ts_bunch1, ts_bunch2)
    validate_metric(metric)
    validate_common_parameters(ds=ds, nsh=nsh, noise_const=noise_const)

    n1 = len(ts_bunch1)
    n2 = len(ts_bunch2)
    if optimal_delays.shape != (n1, n2):
        raise ValueError(
            f"optimal_delays shape {optimal_delays.shape} doesn't match expected ({n1}, {n2})"
        )

    if n_jobs == 0:
        raise ValueError("n_jobs must be a non-zero integer, got 0")
    if n_jobs < 0:
        # joblib convention: -1 -> all CPUs, -2 -> all but one, ...
        n_jobs = min(max(1, multiprocessing.cpu_count() + 1 + n_jobs), n1)

    if mask is None:
        mask = np.ones((n1, n2))

    # Pre-generate ALL random shifts upfront for the full grid
    random_shifts = _generate_random_shifts_grid(
        ts_bunch1, ts_bunch2, optimal_delays, nsh, seed if seed is not None else 0, ds
    )

    # Limit n_jobs to number of items to avoid empty worker splits
    n_jobs_effective = min(n_jobs, n1)
    if n_jobs_effective < n_jobs:
        warnings.warn(
            f"Requested {n_jobs} parallel jobs but only {len(ts_bunch1)} items to process. "
            f"Using {n_jobs_effective} workers to avoid empty splits.",
            UserWarning
        )

    # Split work across workers. Plain lists avoid numpy trying to coerce
    # sequence-like time series objects into a nested array.
    split_inds = np.array_split(np.arange(n1), n_jobs_effective)
    split_ts_bunch1 = [[ts_bunch1[k] for k in idxs] for idxs in split_inds]

    # Each worker only gets the cache entries it needs, instead of the
    # whole (potentially very large) cache being serialized per worker
    split_caches = [
        _extract_cache_subset(fft_cache, subset, ts_bunch2)
        for subset in split_ts_bunch1
    ]

    with _parallel_executor(n_jobs_effective) as parallel:
        parallel_result = parallel(
            delayed(scan_pairs)(
                subset,
                ts_bunch2,
                metric,
                nsh,
                optimal_delays[idxs],
                random_shifts[idxs],  # Pre-generated, pre-split shifts
                mi_estimator,
                ds=ds,
                mask=mask[idxs],
                noise_const=noise_const,
                seed=seed,
                enable_progressbar=False,
                engine=engine,
                fft_cache=subset_cache,
                mi_estimator_kwargs=mi_estimator_kwargs,
            )
            for idxs, subset, subset_cache in zip(split_inds, split_ts_bunch1, split_caches)
        )

    me_total = np.zeros((n1, n2, nsh + 1), dtype=np.float32)
    for idxs, (worker_shifts, worker_me) in zip(split_inds, parallel_result):
        random_shifts[idxs, :, :] = worker_shifts
        me_total[idxs, :, :] = worker_me
    return random_shifts, me_total
[docs]
def scan_pairs_router(
    ts_bunch1,
    ts_bunch2,
    metric,
    nsh,
    optimal_delays,
    mi_estimator="gcmi",
    ds=1,
    mask=None,
    noise_const=DEFAULT_NOISE_AMPLITUDE,
    seed=None,
    enable_parallelization=True,
    n_jobs=-1,
    engine="auto",
    fft_cache: dict = None,
    mi_estimator_kwargs=None,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Dispatch pairwise metric/shuffle computation to the sequential or parallel engine.

    Parameters
    ----------
    ts_bunch1 : list of TimeSeries
        First set of time series.
    ts_bunch2 : list of TimeSeries
        Second set of time series.
    metric : str
        Similarity metric to compute:
        - 'mi': Mutual information
        - 'spearman': Spearman correlation
        - Other metrics supported by get_sim function
    nsh : int
        Number of shuffles to perform.
    optimal_delays : np.ndarray of shape (len(ts_bunch1), len(ts_bunch2))
        Pre-computed optimal delays for each pair.
    mi_estimator : str, default='gcmi'
        Mutual information estimator to use when metric='mi'.
        Options: 'gcmi' (Gaussian copula) or 'ksg' (k-nearest neighbors).
    ds : int, default=1
        Downsampling factor.
    mask : np.ndarray, optional
        Binary mask of shape (len(ts_bunch1), len(ts_bunch2)).
        0 = skip computation, 1 = compute. Default: all ones.
    noise_const : float, default=1e-3
        Small noise added to improve numerical stability.
    seed : int, optional
        Random seed for reproducibility.
    enable_parallelization : bool, default=True
        True routes to scan_pairs_parallel, False to scan_pairs.
    n_jobs : int, default=-1
        Number of parallel jobs (only used when parallelization is enabled).
    engine : {'auto', 'fft', 'loop'}, default='auto'
        Computation engine for MI shuffles:
        - 'auto': Use FFT when applicable (univariate continuous GCMI with nsh >= 50)
        - 'fft': Force FFT (raises error if not applicable)
        - 'loop': Force per-shift loop (original behavior)
    fft_cache : dict, optional
        Pre-computed FFT cache mapping (key1, key2) tuples - stable identifiers
        from _get_ts_key() - to FFTCacheEntry objects. If None, FFT type is
        computed fresh for each pair. Use _build_fft_cache() to create it.
    mi_estimator_kwargs : dict, optional
        Additional keyword arguments passed to the MI estimator function.

    Returns
    -------
    random_shifts : np.ndarray of shape (len(ts_bunch1), len(ts_bunch2), nsh)
        Random shifts used for shuffling.
    me_total : np.ndarray of shape (len(ts_bunch1), len(ts_bunch2), nsh+1)
        Metric values. [:,:,0] contains true values, [:,:,1:] contains shuffles.

    Notes
    -----
    This is the recommended entry point for scan_pairs functionality; the
    sequential path lets scan_pairs generate its own shifts.
    FFT optimization provides ~100x speedup for univariate continuous GCMI.

    See Also
    --------
    scan_pairs : Sequential implementation
    scan_pairs_parallel : Parallel implementation

    Examples
    --------
    >>> # Router example - chooses sequential or parallel execution
    >>> import numpy as np
    >>> from driada.information.info_base import TimeSeries
    >>> np.random.seed(42)
    >>> # Minimal data for fast execution
    >>> neurons = [TimeSeries(np.random.randn(30), discrete=False) for _ in range(2)]
    >>> behaviors = [TimeSeries(np.random.randn(30), discrete=False) for _ in range(2)]
    >>> delays = np.zeros((2, 2), dtype=int)  # No delays
    >>> # Use sequential mode (enable_parallelization=False)
    >>> shifts, metrics = scan_pairs_router(neurons, behaviors, 'mi',
    ...                                     3, delays, enable_parallelization=False, seed=42)
    >>> metrics.shape  # 1 original + 3 shuffles = 4 total
    (2, 2, 4)
    >>> # First slice contains actual MI values
    >>> metrics[:, :, 0].shape
    (2, 2)"""
    shared_kwargs = dict(
        mi_estimator=mi_estimator,
        ds=ds,
        mask=mask,
        noise_const=noise_const,
        seed=seed,
        engine=engine,
        fft_cache=fft_cache,
        mi_estimator_kwargs=mi_estimator_kwargs,
    )

    if not enable_parallelization:
        shifts, metrics = scan_pairs(
            ts_bunch1,
            ts_bunch2,
            metric,
            nsh,
            optimal_delays,
            random_shifts=None,  # scan_pairs draws its own shifts
            **shared_kwargs,
        )
    else:
        shifts, metrics = scan_pairs_parallel(
            ts_bunch1,
            ts_bunch2,
            metric,
            nsh,
            optimal_delays,
            n_jobs=n_jobs,
            **shared_kwargs,
        )
    return shifts, metrics
[docs]
def scan_stage(
    ts_bunch1: list,
    ts_bunch2: list,
    config: StageConfig,
    optimal_delays: np.ndarray,
    metric: str,
    mi_estimator: str,
    metric_distr_type: str,
    noise_const: float,
    ds: int,
    seed: int,
    enable_parallelization: bool,
    n_jobs: int,
    engine: str,
    fft_cache: dict = None,
    verbose: bool = True,
    mi_estimator_kwargs=None,
) -> tuple[dict, dict, dict]:
    """
    Run one INTENSE stage: scan pairs, tabulate statistics, apply the criterion.

    The same three steps are shared by Stage 1 and Stage 2:

    1. ``scan_pairs_router`` computes the true metric value and the shuffled
       values for every pair allowed by ``config.mask``.
    2. ``get_table_of_stats`` turns those values into per-pair statistics.
    3. ``apply_stage_criterion`` decides which pairs pass. Stage 1 uses
       rank-based filtering with ``config.topk``. Any other stage number is
       handled as Stage 2: a multiple-comparison threshold is derived from
       the stage p-values (``config.pval_thr``, ``config.multicomp_correction``,
       with the number of hypotheses equal to ``sum(config.mask)``) and passed
       to the criterion.

    Parameters
    ----------
    ts_bunch1 : list of TimeSeries or MultiTimeSeries
        First set of time series (typically neural signals).
    ts_bunch2 : list of TimeSeries or MultiTimeSeries
        Second set of time series (typically behavioral variables).
    config : StageConfig
        Stage settings (stage_num, n_shuffles, mask, topk, and for Stage 2
        pval_thr and multicomp_correction).
    optimal_delays : np.ndarray
        Delays array of shape (len(ts_bunch1), len(ts_bunch2)).
    metric : str
        Similarity metric to compute.
    mi_estimator : str
        Mutual information estimator ('gcmi' or 'ksg').
    metric_distr_type : str
        Distribution used to fit the shuffled metric values.
    noise_const : float
        Small noise amplitude added for numerical stability.
    ds : int
        Downsampling factor.
    seed : int
        Random seed for reproducibility.
    enable_parallelization : bool
        Whether to scan pairs in parallel.
    n_jobs : int
        Number of parallel jobs when parallelization is enabled.
    engine : str
        Computation engine ('auto', 'fft', 'loop').
    fft_cache : dict, optional
        Pre-computed FFT cache forwarded to the pair scan.
    verbose : bool, default=True
        Print a one-line stage header before scanning.
    mi_estimator_kwargs : dict, optional
        Extra keyword arguments forwarded to the MI estimator.

    Returns
    -------
    stage_stats : dict
        Per-pair statistics from get_table_of_stats.
    stage_significance : dict
        Per-pair significance from apply_stage_criterion.
    stage_info : dict
        Keys 'random_shifts', 'me_total', 'pass_mask' and 'multicorr_thr'
        (the latter is None for Stage 1).
    """
    n1, n2 = len(ts_bunch1), len(ts_bunch2)

    if verbose:
        print(f"Stage {config.stage_num}: {config.n_shuffles} shuffles")

    # Step 1: true metric values ([:, :, 0]) plus the shuffled null values.
    random_shifts, me_total = scan_pairs_router(
        ts_bunch1,
        ts_bunch2,
        metric,
        config.n_shuffles,
        optimal_delays,
        mi_estimator,
        ds=ds,
        mask=config.mask,
        noise_const=noise_const,
        seed=seed,
        enable_parallelization=enable_parallelization,
        n_jobs=n_jobs,
        engine=engine,
        fft_cache=fft_cache,
        mi_estimator_kwargs=mi_estimator_kwargs,
    )

    # Step 2: per-pair statistics from the true and shuffled values.
    stage_stats = get_table_of_stats(
        me_total,
        optimal_delays,
        precomputed_mask=config.mask,
        metric_distr_type=metric_distr_type,
        nsh=config.n_shuffles,
        stage=config.stage_num,
    )

    # Step 3: significance criterion. Arguments shared by both stages go into
    # one dict; Stage 2 additionally contributes the corrected threshold.
    first_stage = config.stage_num == 1
    criterion_args = {
        "stage_num": 1 if first_stage else 2,
        "n1": n1,
        "n2": n2,
        "n_shuffles": config.n_shuffles,
        "topk": config.topk,
    }
    multicorr_thr = None
    if not first_stage:
        hypothesis_count = int(np.sum(config.mask))
        observed_pvals = get_all_nonempty_pvals(stage_stats, range(n1), range(n2))
        multicorr_thr = get_multicomp_correction_thr(
            config.pval_thr,
            mode=config.multicomp_correction,
            all_pvals=observed_pvals,
            nhyp=hypothesis_count,
        )
        criterion_args["multicorr_thr"] = multicorr_thr

    stage_significance, pass_mask = apply_stage_criterion(stage_stats, **criterion_args)

    stage_info = dict(
        random_shifts=random_shifts,
        me_total=me_total,
        pass_mask=pass_mask,
        multicorr_thr=multicorr_thr,
    )
    return stage_stats, stage_significance, stage_info
[docs]
def _check_duplicate_time_series(ts_bunch, bunch_label, duplicate_behavior):
    """Detect series in ``ts_bunch`` that share one underlying data object.

    Identity is ``id(ts.data)`` when the object has a ``data`` attribute,
    otherwise ``id(ts)``.

    Parameters
    ----------
    ts_bunch : list
        Time series to check.
    bunch_label : str
        Bunch name used in the message ('ts_bunch1' or 'ts_bunch2').
    duplicate_behavior : {'ignore', 'raise', 'warn'}
        'ignore' skips the check, 'raise' raises, 'warn' prints a warning.

    Raises
    ------
    ValueError
        If duplicates are found and ``duplicate_behavior == 'raise'``.
    """
    if duplicate_behavior not in ("raise", "warn"):
        return
    ids = [id(ts.data) if hasattr(ts, "data") else id(ts) for ts in ts_bunch]
    if len(set(ids)) < len(ids):
        msg = f"Duplicate TimeSeries objects found in {bunch_label}"
        if duplicate_behavior == "raise":
            raise ValueError(msg)
        print(f"Warning: {msg}")


def _prepare_pair_mask(mask, n1, n2, mask_name):
    """Return a private ndarray copy of a pair mask, or all ones if ``mask`` is None.

    The caller's mask is never modified. Converting with ``np.array`` (which
    copies by default) also accepts array-likes such as nested lists, which
    ``np.fill_diagonal`` would otherwise reject.

    Raises
    ------
    ValueError
        If the mask shape is not ``(n1, n2)``.
    """
    if mask is None:
        return np.ones((n1, n2))
    prepared = np.array(mask)
    if prepared.shape != (n1, n2):
        raise ValueError(
            f"{mask_name} must have shape ({n1}, {n2}), got {prepared.shape}"
        )
    return prepared


def _attach_timings(timings, accumulated_info):
    """Add a 'total' entry to ``timings`` and store it in ``accumulated_info``.

    Only numeric entries are summed, so non-timing data such as
    'fft_type_counts' is excluded from the total. No-op if ``timings`` is None.
    """
    if timings is None:
        return
    timings["total"] = sum(v for v in timings.values() if isinstance(v, (int, float)))
    accumulated_info["timings"] = timings


def compute_me_stats(
    ts_bunch1,
    ts_bunch2,
    names1=None,
    names2=None,
    mode="two_stage",
    metric="mi",
    mi_estimator="gcmi",
    mi_estimator_kwargs=None,
    precomputed_mask_stage1=None,
    precomputed_mask_stage2=None,
    n_shuffles_stage1=100,
    n_shuffles_stage2=10000,
    metric_distr_type=DEFAULT_METRIC_DISTR_TYPE,
    noise_ampl=DEFAULT_NOISE_AMPLITUDE,
    ds=1,
    topk1=1,
    topk2=5,
    multicomp_correction="holm",
    pval_thr=0.01,
    find_optimal_delays=False,
    skip_delays=None,
    shift_window=100,
    verbose=True,
    seed=None,
    enable_parallelization=True,
    n_jobs=-1,
    duplicate_behavior="ignore",
    engine="auto",
    store_random_shifts=False,
    profile=False,
):
    """
    Calculates similarity metric statistics for TimeSeries or MultiTimeSeries pairs

    Parameters
    ----------
    ts_bunch1 : list of TimeSeries objects
        First set of time series
    ts_bunch2 : list of TimeSeries objects
        Second set of time series
    names1 : list of str, optional
        names that will be given to time series from ts_bunch1 in final results
    names2 : list of str, optional
        names that will be given to time series from ts_bunch2 in final results
    mode : str, default='two_stage'
        Computation mode. Options:
        - ``'stage1'``: preliminary scanning with n_shuffles_stage1 shuffles only.
          Rejects strictly non-significant pairs, does not give definite results
          about significance of the others.
        - ``'stage2'``: skip stage 1, perform full-scale scanning (n_shuffles_stage2
          shuffles) of all pairs. Gives definite results but can be very
          time-consuming. Also reduces statistical power of multiple comparison
          tests since the number of hypotheses is very high.
        - ``'two_stage'``: prune non-significant pairs during stage 1 then
          perform thorough testing for the rest during stage 2. Recommended.
    metric : str, default='mi'
        similarity metric between TimeSeries
    mi_estimator : str, default='gcmi'
        Mutual information estimator to use when metric='mi'. Options: 'gcmi' or 'ksg'
    mi_estimator_kwargs : dict, optional
        Additional keyword arguments passed to the MI estimator function.
    precomputed_mask_stage1 : array-like, optional
        precomputed mask for skipping some of possible pairs in stage 1.
        Shape: (len(ts_bunch1), len(ts_bunch2))
        0 in mask values means calculation will be skipped.
        1 in mask values means calculation will proceed.
    precomputed_mask_stage2 : array-like, optional
        precomputed mask for skipping some of possible pairs in stage 2.
        Shape: (len(ts_bunch1), len(ts_bunch2))
        0 in mask values means calculation will be skipped.
        1 in mask values means calculation will proceed.
    n_shuffles_stage1 : int, default=100
        number of shuffles for first stage
    n_shuffles_stage2 : int, default=10000
        number of shuffles for second stage
    metric_distr_type : str, default="gamma_zi"
        Distribution type for shuffled metric null distribution. Options:
        - 'gamma_zi' (default): Zero-inflated gamma distribution. Explicitly models the probability
          mass at zero that commonly occurs in MI null distributions. Provides superior goodness-of-fit
          and accurate parameter estimation without requiring artificial noise.
        - 'gamma': Standard gamma distribution with small noise added (noise_ampl) to handle zeros.
          Provided for backward compatibility. Less statistically principled than 'gamma_zi'.
        - Other scipy.stats distributions: 'lognorm', 'norm', etc. are supported but not recommended
          for MI distributions.
    noise_ampl : float, default=1e-3
        Small noise amplitude, which is added to metrics to improve numerical fit.
        Ignored (set to 0) when metric_distr_type='gamma_zi'.
    ds : int, default=1
        Downsampling constant. Every "ds" point will be taken from the data time series.
    topk1 : int, default=1
        true MI for stage 1 should be among topk1 MI shuffles
    topk2 : int, default=5
        true MI for stage 2 should be among topk2 MI shuffles
    multicomp_correction : str or None, default='holm'
        type of multiple comparisons correction. Supported types are None (no correction),
        "bonferroni", "holm", and "fdr_bh".
    pval_thr : float, default=0.01
        pvalue threshold. if multicomp_correction=None, this is a p-value for a single pair.
        For FWER methods (bonferroni, holm), this is the family-wise error rate.
        For FDR methods (fdr_bh), this is the false discovery rate.
    find_optimal_delays : bool, default=False
        Allows slight shifting (not more than +- shift_window) of time series,
        selects a shift with the highest MI as default.
    skip_delays : list of int, optional
        Indices into ts_bunch2 for which delays are not applied (kept at 0).
        None (default) is equivalent to an empty list.
        Has no effect if find_optimal_delays = False
    shift_window : int, default=100
        Window for optimal shift search (frames). Optimal shift will lie in the range
        -shift_window <= opt_shift <= shift_window
    verbose : bool, default=True
        whether to print intermediate information
    seed : int, optional
        random seed for reproducibility
    enable_parallelization : bool, default=True
        whether to use parallel processing for computations
    n_jobs : int, default=-1
        number of parallel jobs to use. -1 means use all available processors
    duplicate_behavior : str, default='ignore'
        How to handle duplicate TimeSeries in ts_bunch1 or ts_bunch2.
        Duplicates are series sharing the same ``data`` object.
        - 'ignore': Process duplicates normally (default)
        - 'raise': Raise an error if duplicates are found
        - 'warn': Print a warning but continue processing
    engine : {'auto', 'fft', 'loop'}, default='auto'
        Computation engine for MI shuffles:
        - 'auto': Use FFT when applicable (univariate continuous GCMI with nsh >= 50)
        - 'fft': Force FFT (raises error if not applicable)
        - 'loop': Force per-shift loop (original behavior)
        FFT optimization provides ~100x speedup for Stage 2.
    store_random_shifts : bool, default=False
        Whether to store the random shift indices used during shuffle computation.
        When False (default), random_shifts1 and random_shifts2 arrays are not stored
        in accumulated_info, saving significant memory (e.g., ~400MB for typical datasets).
        Set to True if you need the shift indices for debugging or reproducibility analysis.
    profile : bool, default=False
        Whether to collect internal timing information. When True, accumulated_info
        will include a 'timings' dict with execution times (in seconds) for:
        - 'fft_cache_building': building the shared FFT cache
        - 'stage1_delay_optimization': delay optimization section (recorded even
          when find_optimal_delays=False)
        - 'stage1_pair_scanning': stage 1 pair scanning (stage1/two_stage modes)
        - 'stage2_pair_scanning': stage 2 pair scanning (if applicable)
        - 'total': sum of all numeric timing entries
        plus 'fft_type_counts' (counts returned by the FFT cache builder, if any).

    Returns
    -------
    stats : dict of dict of dicts
        Outer dict keys: indices of ts_bunch1 or names1, if given
        Inner dict keys: indices of ts_bunch2 or names2, if given
        Last dict: dictionary of stats variables.
        Can be easily converted to pandas DataFrame by pd.DataFrame(stats)
    significance : dict of dict of dicts
        Outer dict keys: indices of ts_bunch1 or names1, if given
        Inner dict keys: indices of ts_bunch2 or names2, if given
        Last dict: dictionary of significance-related variables.
        Can be easily converted to pandas DataFrame by pd.DataFrame(significance)
    accumulated_info : dict
        Data collected during computation: 'optimal_delays'; for stage 1
        'stage_1_significance', 'stage_1_stats', 'me_total1',
        'n_significant_stage1' (and 'random_shifts1' if store_random_shifts);
        for stage 2 'stage_2_significance', 'stage_2_stats', 'me_total2',
        'corrected_pval_thr', 'group_pval_thr' (and 'random_shifts2' if
        store_random_shifts); 'timings' if profile=True.

    Raises
    ------
    ValueError
        If mode is not 'stage1', 'stage2', or 'two_stage'.
        If multicomp_correction is not None, 'bonferroni', 'holm', or 'fdr_bh'.
        If pval_thr is not between 0 and 1.
        If duplicate_behavior is not 'ignore', 'raise', or 'warn'.
        If duplicate TimeSeries found and duplicate_behavior='raise'.
        If a precomputed mask does not have shape (len(ts_bunch1), len(ts_bunch2)).
        If skip_delays contains indices outside [0, len(ts_bunch2) - 1].

    Notes
    -----
    - When comparing the same bunch (ts_bunch1 is ts_bunch2), the diagonal
      of masks is automatically set to 0 to avoid self-comparisons.
    - In 'stage2' mode, dummy stage1 structures are created with placeholder values
      to maintain consistency in the return format.
    - For stage2, the final mask combines stage1 results with precomputed_mask_stage2
      using logical AND.
    - n_shuffles_stage1/n_shuffles_stage2 are capped at n_frames // ds, the
      number of available circular shifts.
    - Input masks are never modified; copies are created when needed.
    - Unnamed time series temporarily receive names (needed as FFT cache keys);
      original names are restored before returning, even on error."""
    # FUTURE: add automatic min_shifts from autocorrelation time
    # Validate inputs
    validate_time_series_bunches(ts_bunch1, ts_bunch2)
    validate_metric(metric)
    validate_common_parameters(shift_window=shift_window, ds=ds, noise_const=noise_ampl)
    if mode not in ("stage1", "stage2", "two_stage"):
        raise ValueError(f"mode must be 'stage1', 'stage2', or 'two_stage', got '{mode}'")
    if multicomp_correction not in (None, "bonferroni", "holm", "fdr_bh"):
        raise ValueError(f"Unknown multiple comparison correction method: '{multicomp_correction}'")
    if not 0 < pval_thr < 1:
        raise ValueError(f"pval_thr must be between 0 and 1, got {pval_thr}")
    # Reject unknown values instead of silently treating them as 'ignore'.
    if duplicate_behavior not in ("ignore", "raise", "warn"):
        raise ValueError(
            "duplicate_behavior must be 'ignore', 'raise', or 'warn', "
            f"got '{duplicate_behavior}'"
        )
    validate_common_parameters(nsh=n_shuffles_stage1)
    validate_common_parameters(nsh=n_shuffles_stage2)

    # None replaces the former mutable default; list() also makes array input
    # safe for the truthiness and membership checks below.
    skip_delays = [] if skip_delays is None else list(skip_delays)

    # Cap n_shuffles to available unique circular shifts
    n_frames = ts_bunch1[0].data.shape[-1]
    n_shifts = n_frames // ds
    if n_shifts > 0:
        n_shuffles_stage1 = min(n_shuffles_stage1, n_shifts)
        n_shuffles_stage2 = min(n_shuffles_stage2, n_shifts)

    accumulated_info = dict()
    timings = {} if profile else None

    # Capture BOTH bunches' original names before assigning any temporary
    # name, so objects shared between the bunches are restored correctly.
    original_names1 = [getattr(ts, "name", None) for ts in ts_bunch1]
    original_names2 = [getattr(ts, "name", None) for ts in ts_bunch2]

    # Temporary names ensure every series has a key for the FFT cache.
    # `is not None` (rather than truthiness) keeps numpy-array names working.
    for i, ts in enumerate(ts_bunch1):
        if getattr(ts, "name", None) is None or ts.name == "":
            ts.name = str(names1[i]) if names1 is not None and i < len(names1) else f"_ts1_{i}"
    for j, ts in enumerate(ts_bunch2):
        if getattr(ts, "name", None) is None or ts.name == "":
            ts.name = str(names2[j]) if names2 is not None and j < len(names2) else f"_ts2_{j}"

    fft_cache = None
    try:
        # Check if we're comparing the same bunch with itself
        same_data_bunch = ts_bunch1 is ts_bunch2
        n1 = len(ts_bunch1)
        n2 = len(ts_bunch2)

        precomputed_mask_stage1 = _prepare_pair_mask(
            precomputed_mask_stage1, n1, n2, "precomputed_mask_stage1"
        )
        precomputed_mask_stage2 = _prepare_pair_mask(
            precomputed_mask_stage2, n1, n2, "precomputed_mask_stage2"
        )
        # Mask out self-pairs when a bunch is compared with itself
        if same_data_bunch:
            np.fill_diagonal(precomputed_mask_stage1, 0)
            np.fill_diagonal(precomputed_mask_stage2, 0)

        _check_duplicate_time_series(ts_bunch1, "ts_bunch1", duplicate_behavior)
        _check_duplicate_time_series(ts_bunch2, "ts_bunch2", duplicate_behavior)

        optimal_delays = np.zeros((n1, n2), dtype=int)

        # Validate skip_delays indices before use
        invalid_indices = [i for i in skip_delays if i < 0 or i >= n2]
        if invalid_indices:
            raise ValueError(
                f"skip_delays contains invalid indices {invalid_indices}. "
                f"Valid range: [0, {len(ts_bunch2)-1}] for {len(ts_bunch2)} features."
            )
        skipped = set(skip_delays)
        ts_with_delays = [ts for j, ts in enumerate(ts_bunch2) if j not in skipped]
        ts_with_delays_inds = np.array([j for j in range(n2) if j not in skipped])

        # Combine masks: cache any pair needed by either stage
        cache_mask = np.maximum(precomputed_mask_stage1, precomputed_mask_stage2)

        # Build FFT cache once at the start for reuse across delays + stages
        if verbose:
            n_to_cache = int(np.sum(cache_mask))
            n_total = n1 * n2
            print(f"Building FFT cache for {n_to_cache}/{n_total} pairs (engine={engine})...")
        with _timed_section(timings, 'fft_cache_building'):
            fft_cache, fft_type_counts = _build_fft_cache(
                ts_bunch1, ts_bunch2, metric, mi_estimator, ds, engine,
                n_jobs=n_jobs if enable_parallelization else 1,
                pair_mask=cache_mask,
            )
        # Store FFT type counts for profiling
        if timings is not None and fft_type_counts:
            timings['fft_type_counts'] = fft_type_counts

        with _timed_section(timings, 'stage1_delay_optimization'):
            if find_optimal_delays and len(ts_with_delays) > 0:
                # Local import to avoid circular dependency (delay.py imports from this module)
                from .delay import calculate_optimal_delays, calculate_optimal_delays_parallel

                # The shared fft_cache uses stable keys, so ts_with_delays are already cached
                if enable_parallelization:
                    optimal_delays_res = calculate_optimal_delays_parallel(
                        ts_bunch1,
                        ts_with_delays,
                        metric,
                        shift_window,
                        ds,
                        verbose=verbose,
                        n_jobs=n_jobs,
                        mi_estimator=mi_estimator,
                        engine=engine,
                        fft_cache=fft_cache,
                        mi_estimator_kwargs=mi_estimator_kwargs,
                    )
                else:
                    optimal_delays_res = calculate_optimal_delays(
                        ts_bunch1,
                        ts_with_delays,
                        metric,
                        shift_window,
                        ds,
                        verbose=verbose,
                        mi_estimator=mi_estimator,
                        engine=engine,
                        fft_cache=fft_cache,
                        mi_estimator_kwargs=mi_estimator_kwargs,
                    )
                optimal_delays[:, ts_with_delays_inds] = optimal_delays_res

        accumulated_info["optimal_delays"] = optimal_delays

        # Initialize masks based on mode; stage2-only assumes every pair passed stage 1
        if mode == "stage2":
            mask_from_stage1 = np.ones((n1, n2))
        else:
            mask_from_stage1 = np.zeros((n1, n2))
        mask_from_stage2 = np.zeros((n1, n2))
        nhyp = n1 * n2

        # ZIG models zeros explicitly, so no stabilizing noise is needed for it
        noise_const = 0 if metric_distr_type == "gamma_zi" else noise_ampl

        if mode in ("two_stage", "stage1"):
            npairs_to_check1 = int(np.sum(precomputed_mask_stage1))
            if verbose:
                print(f"Starting stage 1 scanning for {npairs_to_check1}/{nhyp} possible pairs")

            with _timed_section(timings, 'stage1_pair_scanning'):
                # STAGE 1 - primary scanning
                config_stage1 = StageConfig(
                    stage_num=1,
                    n_shuffles=n_shuffles_stage1,
                    mask=precomputed_mask_stage1,
                    topk=topk1,
                )
                stage_1_stats, stage_1_significance, stage_1_info = scan_stage(
                    ts_bunch1,
                    ts_bunch2,
                    config_stage1,
                    optimal_delays,
                    metric=metric,
                    mi_estimator=mi_estimator,
                    metric_distr_type=metric_distr_type,
                    noise_const=noise_const,
                    ds=ds,
                    seed=seed,
                    enable_parallelization=enable_parallelization,
                    n_jobs=n_jobs,
                    engine=engine,
                    fft_cache=fft_cache,
                    verbose=False,  # verbose output is handled here
                    mi_estimator_kwargs=mi_estimator_kwargs,
                )

            random_shifts1 = stage_1_info["random_shifts"]
            me_total1 = stage_1_info["me_total"]
            mask_from_stage1 = stage_1_info["pass_mask"]

            # Convert to per-quantity tables for accumulated_info
            stage_1_stats_per_quantity = nested_dict_to_seq_of_tables(
                stage_1_stats, ordered_names1=range(n1), ordered_names2=range(n2)
            )
            stage_1_significance_per_quantity = nested_dict_to_seq_of_tables(
                stage_1_significance, ordered_names1=range(n1), ordered_names2=range(n2)
            )
            stage1_info = {
                "stage_1_significance": stage_1_significance_per_quantity,
                "stage_1_stats": stage_1_stats_per_quantity,
                "me_total1": me_total1,
            }
            if store_random_shifts:
                stage1_info["random_shifts1"] = random_shifts1
            accumulated_info.update(stage1_info)

            nhyp = int(np.sum(mask_from_stage1))  # hypotheses for further statistical testing
            accumulated_info['n_significant_stage1'] = nhyp
            if verbose:
                print("Stage 1 results:")
                print(
                    f"{nhyp/n1/n2*100:.2f}% ({nhyp}/{n1*n2}) of possible pairs identified as candidates"
                )

            if mode == "stage1" or nhyp == 0:
                final_stats = add_names_to_nested_dict(stage_1_stats, names1, names2)
                final_significance = add_names_to_nested_dict(stage_1_significance, names1, names2)
                _attach_timings(timings, accumulated_info)
                return final_stats, final_significance, accumulated_info

        elif mode == "stage2":
            # Placeholder stage 1 structures: every pair "passes" stage 1
            stage_1_stats = populate_nested_dict(dict(), range(n1), range(n2))
            stage_1_significance = populate_nested_dict(dict(), range(n1), range(n2))
            for i in range(n1):
                for j in range(n2):
                    stage_1_stats[i][j] = {"pre_rval": None, "pre_pval": None}
                    stage_1_significance[i][j]["stage1"] = True

        if mode in ("two_stage", "stage2"):
            # STAGE 2 - full-scale scanning of stage-1 survivors allowed by mask 2
            combined_mask_for_stage_2 = np.ones((n1, n2))
            combined_mask_for_stage_2[np.where(mask_from_stage1 == 0)] = 0
            combined_mask_for_stage_2[np.where(precomputed_mask_stage2 == 0)] = 0

            npairs_to_check2 = int(np.sum(combined_mask_for_stage_2))
            if verbose:
                print(f"Starting stage 2 scanning for {npairs_to_check2}/{nhyp} possible pairs")

            with _timed_section(timings, 'stage2_pair_scanning'):
                config_stage2 = StageConfig(
                    stage_num=2,
                    n_shuffles=n_shuffles_stage2,
                    mask=combined_mask_for_stage_2,
                    topk=topk2,
                    pval_thr=pval_thr,
                    multicomp_correction=multicomp_correction,
                )
                stage_2_stats, stage_2_significance, stage_2_info = scan_stage(
                    ts_bunch1,
                    ts_bunch2,
                    config_stage2,
                    optimal_delays,
                    metric=metric,
                    mi_estimator=mi_estimator,
                    metric_distr_type=metric_distr_type,
                    noise_const=noise_const,
                    ds=ds,
                    seed=seed,
                    enable_parallelization=enable_parallelization,
                    n_jobs=n_jobs,
                    engine=engine,
                    fft_cache=fft_cache,
                    verbose=False,  # verbose output is handled here
                    mi_estimator_kwargs=mi_estimator_kwargs,
                )

            random_shifts2 = stage_2_info["random_shifts"]
            me_total2 = stage_2_info["me_total"]
            mask_from_stage2 = stage_2_info["pass_mask"]
            multicorr_thr = stage_2_info["multicorr_thr"]

            # Convert to per-quantity tables for accumulated_info
            stage_2_stats_per_quantity = nested_dict_to_seq_of_tables(
                stage_2_stats, ordered_names1=range(n1), ordered_names2=range(n2)
            )
            stage_2_significance_per_quantity = nested_dict_to_seq_of_tables(
                stage_2_significance, ordered_names1=range(n1), ordered_names2=range(n2)
            )
            stage2_info = {
                "stage_2_significance": stage_2_significance_per_quantity,
                "stage_2_stats": stage_2_stats_per_quantity,
                "me_total2": me_total2,
                "corrected_pval_thr": multicorr_thr,
                "group_pval_thr": pval_thr,
            }
            if store_random_shifts:
                stage2_info["random_shifts2"] = random_shifts2
            accumulated_info.update(stage2_info)

            num2 = int(np.sum(mask_from_stage2))
            if verbose:
                print("Stage 2 results:")
                print(
                    f"{num2/n1/n2*100:.2f}% ({num2}/{n1*n2}) of possible pairs identified as significant"
                )

        # Always merge stats for consistency
        merged_stats = merge_stage_stats(stage_1_stats, stage_2_stats)
        merged_significance = merge_stage_significance(stage_1_significance, stage_2_significance)
        final_stats = add_names_to_nested_dict(merged_stats, names1, names2)
        final_significance = add_names_to_nested_dict(merged_significance, names1, names2)

        _attach_timings(timings, accumulated_info)
        return final_stats, final_significance, accumulated_info

    finally:
        # Free FFT cache memory explicitly to prevent accumulation across calls
        if fft_cache is not None:
            fft_cache.clear()
        # Restore original names to leave caller objects unchanged
        for ts, name in zip(ts_bunch1, original_names1):
            ts.name = name
        for ts, name in zip(ts_bunch2, original_names2):
            ts.name = name