Source code for driada.intense.delay

"""
Delay optimization for INTENSE analysis.

Provides functions to find optimal temporal delays between pairs of time series
(e.g. neural signals and behavioral variables) by maximizing a similarity metric
across a range of shifts.
"""

import multiprocessing
import warnings

import numpy as np
import tqdm
from joblib import delayed

from ..information.info_base import get_sim
from ..utils.parallel import parallel_executor as _parallel_executor
from .validation import (
    validate_common_parameters,
    validate_metric,
    validate_time_series_bunches,
)
from .fft import (
    MIN_SHIFTS_FOR_FFT_DELAYS,
    _extract_fft_data,
    _FFT_COMPUTE,
    _get_ts_key,
    get_fft_type,
)
from .intense_base import _extract_cache_subset


def _circular_shift_indices(shifts, n):
    """Map signed shifts onto non-negative indices of a length-``n`` circular axis.

    Non-negative shifts are kept as they are; a negative shift ``s`` becomes
    ``n + s``, i.e. it wraps around to the end of the axis.
    """
    return np.where(shifts >= 0, shifts, n + shifts).astype(int)


def calculate_optimal_delays(
    ts_bunch1,
    ts_bunch2,
    metric,
    shift_window,
    ds,
    verbose=True,
    enable_progressbar=True,
    mi_estimator="gcmi",
    engine="auto",
    fft_cache: dict = None,
    mi_estimator_kwargs=None,
) -> np.ndarray:
    """
    Calculate optimal temporal delays between pairs of time series.

    Finds the delay that maximizes the similarity metric between each pair of
    time series from ts_bunch1 and ts_bunch2. This accounts for temporal
    offsets in neural responses relative to behavioral variables.

    Parameters
    ----------
    ts_bunch1 : list of TimeSeries
        First set of time series (typically neural signals).
    ts_bunch2 : list of TimeSeries
        Second set of time series (typically behavioral variables).
    metric : str
        Similarity metric to maximize. See validate_metric for supported options.
    shift_window : int
        Maximum shift to test in each direction (frames). Shifts are tested
        on the downsampled grid (see Notes).
    ds : int
        Downsampling factor. Every ds-th point is used from the time series.
    verbose : bool, default=True
        Whether to print progress information.
    enable_progressbar : bool, default=True
        Whether to show progress bar.
    mi_estimator : str, default='gcmi'
        MI estimator to use when metric='mi'. Options: 'gcmi' or 'ksg'.
    engine : {'auto', 'fft', 'loop'}, default='auto'
        Computation engine for delay optimization:

        - 'auto': Use FFT when applicable (univariate continuous GCMI with >= 20 shifts)
        - 'fft': Force FFT (raises error if not applicable)
        - 'loop': Force per-shift loop (original behavior)
    fft_cache : dict, optional
        Pre-computed FFT cache from _build_fft_cache. Keys are (key1, key2)
        tuples using stable identifiers from _get_ts_key(); each entry exposes
        an ``mi_all`` array indexed by circular shift. When a cache is passed,
        FFT types are not computed, so a pair missing from the cache is
        handled by the per-shift loop.
    mi_estimator_kwargs : dict, optional
        Additional keyword arguments passed to the MI estimator function.

    Returns
    -------
    optimal_delays : np.ndarray of shape (len(ts_bunch1), len(ts_bunch2))
        Optimal delay (in frames) for each pair. Positive values indicate that
        ts2 leads ts1, negative values indicate ts1 leads ts2.

    Notes
    -----
    - With FFT engine: O(n1 * n2 * n log n) where n is downsampled time series length
    - With loop engine: O(n1 * n2 * shifts * n) where shifts = 2 * shift_window / ds
    - FFT provides ~10-20x speedup for typical delay windows (100-200 shifts)
    - The optimal delay is found by exhaustive search over all tested shifts
    - Memory efficient: only stores final optimal delays, not all tested values
    - The tested grid is ``np.arange(-shift_window, shift_window + ds, ds) // ds``
      (in downsampled frames). When ``shift_window`` is a multiple of ``ds``
      this is symmetric and covers ``-shift_window .. +shift_window``. When it
      is not, the floor division makes the grid asymmetric and it can reach
      one downsampled step beyond ``-shift_window`` (e.g. ``shift_window=5,
      ds=2`` tests delays ``-6, -4, ..., 4``).
    - Without a cache, the FFT type of each column is decided from
      ``ts_bunch1[0]`` only, i.e. all members of ``ts_bunch1`` are assumed to
      be of the same kind.

    Examples
    --------
    >>> # Create minimal time series with known phase shift
    >>> import numpy as np
    >>> from driada.information.info_base import TimeSeries
    >>> # 100 points is enough to demonstrate functionality
    >>> t = np.linspace(0, 2*np.pi, 100)
    >>> # Create two sine waves phase-shifted by pi/4 (an eighth of a period)
    >>> data1 = np.sin(t)
    >>> data2 = np.sin(t + np.pi/4)  # phase shifted signal
    >>> ts1 = TimeSeries(data1, discrete=False)
    >>> ts2 = TimeSeries(data2, discrete=False)
    >>> # Find optimal delay with small window for speed
    >>> delays = calculate_optimal_delays([ts1], [ts2], 'mi',
    ...                                   shift_window=5, ds=1, verbose=False)
    >>> delays.shape
    (1, 1)
    >>> # The delay captures the phase relationship
    >>> -5 <= delays[0, 0] <= 5
    True
    """
    # Validate inputs
    validate_time_series_bunches(ts_bunch1, ts_bunch2)
    validate_metric(metric)
    validate_common_parameters(shift_window=shift_window, ds=ds)

    if verbose:
        print("Calculating optimal delays:")

    optimal_delays = np.zeros((len(ts_bunch1), len(ts_bunch2)), dtype=int)
    # Shifts expressed in downsampled frames; multiplied back by ds on output.
    shifts = np.arange(-shift_window, shift_window + ds, ds) // ds

    # Pre-compute FFT types for each feature (only if no cache provided).
    # The type is derived from the first element of ts_bunch1 only.
    if fft_cache is None and len(ts_bunch1) > 0:
        fft_types = [
            get_fft_type(
                ts_bunch1[0], ts2, metric, mi_estimator, len(shifts), engine,
                for_delays=True,
            )
            for ts2 in ts_bunch2
        ]
    else:
        fft_types = None

    # Cache keys depend on a single series each, so compute them once per
    # series instead of once per (i, j) pair. An empty cache is treated the
    # same as no lookup at all.
    use_cache = bool(fft_cache)
    keys2 = [_get_ts_key(ts2) for ts2 in ts_bunch2] if use_cache else None

    for i, ts1 in tqdm.tqdm(
        enumerate(ts_bunch1), total=len(ts_bunch1), disable=not enable_progressbar
    ):
        key1 = _get_ts_key(ts1) if use_cache else None

        for j, ts2 in enumerate(ts_bunch2):
            cache_entry = fft_cache.get((key1, keys2[j])) if use_cache else None

            if cache_entry is not None:
                # Cached MI for every circular shift: just index into it.
                mi_all = cache_entry.mi_all
                mi_values = mi_all[_circular_shift_indices(shifts, len(mi_all))]
                best_shift = shifts[np.argmax(mi_values)]

            elif fft_types is not None and fft_types[j] is not None:
                # No cache but FFT applicable: extract data fresh.
                fft_type = fft_types[j]
                data1, data2 = _extract_fft_data(ts1, ts2, fft_type, ds)
                compute_fn = _FFT_COMPUTE[fft_type]
                # Length of the last axis, whether data1 is 1-D or not.
                n = len(data1) if data1.ndim == 1 else data1.shape[-1]
                mi_values = compute_fn(
                    data1, data2, _circular_shift_indices(shifts, n)
                )
                best_shift = shifts[np.argmax(mi_values)]

            else:
                # Loop fallback: evaluate the metric at every shift.
                shifted_me = [
                    get_sim(
                        ts1,
                        ts2,
                        metric,
                        ds=ds,
                        shift=int(shift),
                        estimator=mi_estimator,
                        mi_estimator_kwargs=mi_estimator_kwargs,
                    )
                    for shift in shifts
                ]
                best_shift = shifts[np.argmax(shifted_me)]

            # Convert from downsampled frames back to original frames.
            optimal_delays[i, j] = int(best_shift * ds)

    return optimal_delays
def calculate_optimal_delays_parallel(
    ts_bunch1,
    ts_bunch2,
    metric,
    shift_window,
    ds,
    verbose=True,
    n_jobs=-1,
    mi_estimator="gcmi",
    engine="auto",
    fft_cache: dict = None,
    mi_estimator_kwargs=None,
) -> np.ndarray:
    """
    Calculate optimal temporal delays between pairs of time series using parallel processing.

    Parallel version of calculate_optimal_delays that distributes computation
    across multiple CPU cores for improved performance with large datasets.

    Parameters
    ----------
    ts_bunch1 : list of TimeSeries
        First set of time series (typically neural signals).
    ts_bunch2 : list of TimeSeries
        Second set of time series (typically behavioral variables).
    metric : str
        Similarity metric to maximize. See validate_metric for supported options.
    shift_window : int
        Maximum shift to test in each direction (frames). Will test shifts
        from -shift_window to +shift_window inclusive.
    ds : int
        Downsampling factor. Every ds-th point is used from the time series.
    verbose : bool, default=True
        Whether to print progress information. Also gates the warning issued
        when fewer workers than requested are used.
    n_jobs : int, default=-1
        Number of parallel jobs to run. -1 uses all available cores; other
        negative values follow the joblib convention
        (``cpu_count + 1 + n_jobs``, at least 1). The number of workers never
        exceeds ``len(ts_bunch1)``.
    mi_estimator : str, default='gcmi'
        MI estimator to use when metric='mi'. Options: 'gcmi' or 'ksg'.
    engine : {'auto', 'fft', 'loop'}, default='auto'
        Computation engine for delay optimization:

        - 'auto': Use FFT when applicable (univariate continuous GCMI with >= 20 shifts)
        - 'fft': Force FFT (raises error if not applicable)
        - 'loop': Force per-shift loop (original behavior)
    fft_cache : dict, optional
        Pre-computed FFT cache from _build_fft_cache. Keys are (key1, key2)
        tuples using stable identifiers from _get_ts_key(). A per-worker
        subset is extracted and passed to each worker.
    mi_estimator_kwargs : dict, optional
        Additional keyword arguments passed to the MI estimator function.

    Returns
    -------
    optimal_delays : np.ndarray of shape (len(ts_bunch1), len(ts_bunch2))
        Optimal delay (in frames) for each pair. Positive values indicate that
        ts2 leads ts1, negative values indicate ts1 leads ts2.

    Raises
    ------
    ValueError
        If ``n_jobs`` is 0.

    Notes
    -----
    - Parallelization is done by splitting ts_bunch1 across workers
    - Each worker processes a subset of ts_bunch1 against all of ts_bunch2
    - Memory usage scales with number of workers
    - Speedup is typically sublinear due to overhead and memory bandwidth
    - FFT optimization within each worker provides additional speedup

    See Also
    --------
    calculate_optimal_delays : Sequential version of this function

    Examples
    --------
    >>> # Demonstrate parallel processing with minimal data
    >>> import numpy as np
    >>> from driada.information.info_base import TimeSeries
    >>> # Create 3 neurons and 2 behaviors with 50 timepoints each
    >>> np.random.seed(42)  # For reproducible example
    >>> neurons = [TimeSeries(np.random.randn(50), discrete=False) for _ in range(3)]
    >>> behaviors = [TimeSeries(np.random.randn(50), discrete=False) for _ in range(2)]
    >>> # Use 2 cores with small shift window
    >>> delays = calculate_optimal_delays_parallel(neurons, behaviors, 'mi',
    ...                                            shift_window=3, ds=1, n_jobs=2, verbose=False)
    >>> delays.shape
    (3, 2)
    >>> # All delays should be within the window
    >>> np.all(np.abs(delays) <= 3)
    True
    """
    # Validate inputs
    validate_time_series_bunches(ts_bunch1, ts_bunch2)
    validate_metric(metric)
    validate_common_parameters(shift_window=shift_window, ds=ds)

    if verbose:
        print("Calculating optimal delays in parallel mode:")

    n_items = len(ts_bunch1)
    optimal_delays = np.zeros((n_items, len(ts_bunch2)), dtype=int)

    # Nothing to distribute: np.array_split cannot split into 0 sections.
    if n_items == 0:
        return optimal_delays

    if n_jobs == 0:
        raise ValueError(
            "n_jobs == 0 is not valid; use a positive integer or a negative "
            "value such as -1 (all cores)."
        )

    if n_jobs < 0:
        # joblib convention: -1 -> all cores, -2 -> all but one, ...
        # Capped at n_items so that no "fewer workers" warning is emitted
        # for a relative request.
        available = max(1, multiprocessing.cpu_count() + 1 + n_jobs)
        n_jobs = min(available, n_items)

    # Limit n_jobs to number of items to avoid empty worker splits
    n_jobs_effective = min(n_jobs, n_items)
    if n_jobs_effective < n_jobs and verbose:
        warnings.warn(
            f"Requested {n_jobs} parallel jobs but only {n_items} items to process. "
            f"Using {n_jobs_effective} workers to avoid empty splits.",
            UserWarning,
        )

    split_ts_bunch1_inds = np.array_split(np.arange(n_items), n_jobs_effective)
    # Plain list indexing keeps the objects untouched; wrapping them in
    # np.array() would let NumPy try to interpret sequence-like objects as
    # nested arrays.
    split_ts_bunch1 = [
        [ts_bunch1[k] for k in idxs] for idxs in split_ts_bunch1_inds
    ]

    # Split cache per worker - each worker only gets entries it needs.
    # This avoids serializing the entire cache (potentially GBs) to each worker.
    split_caches = [
        _extract_cache_subset(fft_cache, subset, ts_bunch2)
        for subset in split_ts_bunch1
    ]

    # Parallel execution with backend-specific config
    with _parallel_executor(n_jobs_effective) as parallel:
        parallel_delays = parallel(
            delayed(calculate_optimal_delays)(
                small_ts_bunch,
                ts_bunch2,
                metric,
                shift_window,
                ds,
                verbose=False,
                enable_progressbar=False,
                mi_estimator=mi_estimator,
                engine=engine,
                fft_cache=split_cache,
                mi_estimator_kwargs=mi_estimator_kwargs,
            )
            for small_ts_bunch, split_cache in zip(split_ts_bunch1, split_caches)
        )

    # Write each worker's rows back at the indices it was given.
    for inds_of_interest, chunk_delays in zip(split_ts_bunch1_inds, parallel_delays):
        optimal_delays[inds_of_interest, :] = chunk_delays

    return optimal_delays