"""
FFT dispatch, data extraction, and caching for INTENSE.
Provides FFT type classification, data extraction helpers, and cache building
functions used by the INTENSE pipeline and delay optimization.
"""
import multiprocessing
import numpy as np
from dataclasses import dataclass
from joblib import delayed
from ..information.info_base import (
TimeSeries,
MultiTimeSeries,
compute_mi_batch_fft,
compute_mi_gd_fft,
compute_mi_mts_fft,
compute_mi_mts_mts_fft,
compute_mi_mts_discrete_fft,
compute_mi_dd_fft,
compute_pearson_batch_fft,
compute_av_batch_fft,
)
from ..information.info_fft_utils import REG_VARIANCE_THRESHOLD
from ..information.info_utils import py_fast_digamma
from ..utils.parallel import parallel_executor as _parallel_executor
from .validation import (
validate_time_series_bunches,
validate_metric,
validate_common_parameters,
)
_LN2 = np.log(2)
# Minimum number of shuffles to benefit from FFT optimization
# FFT is always beneficial due to high per-call overhead in loop fallback
MIN_SHUFFLES_FOR_FFT = 1
# Minimum number of shifts to benefit from FFT in delay optimization
# FFT is always beneficial due to high per-call overhead in loop fallback
MIN_SHIFTS_FOR_FFT_DELAYS = 1
# Maximum dimensions for FFT acceleration of MultiTimeSeries
MAX_FFT_MTS_DIMENSIONS = 3
MAX_MTS_MTS_FFT_DIMENSIONS = 6 # Total d1+d2 limit for MTS-MTS pairs
# FFT type constants
FFT_CONTINUOUS = "cc" # Continuous-continuous (univariate 1D-1D)
FFT_DISCRETE = "gd" # Gaussian-discrete (one discrete, one continuous)
FFT_DISCRETE_DISCRETE = "dd" # Discrete-discrete (both variables discrete)
FFT_MULTIVARIATE = "mts" # MultiTimeSeries + univariate TimeSeries
FFT_MTS_MTS = "mts_mts" # MultiTimeSeries + MultiTimeSeries
FFT_MTS_DISCRETE = "mts_discrete" # MultiTimeSeries + discrete
FFT_PEARSON_CONTINUOUS = "pearson_cc" # Pearson correlation (continuous-continuous)
FFT_PEARSON_DISCRETE = "pearson_gd" # Pearson correlation (continuous-discrete, point-biserial)
FFT_AV_DISCRETE = "av_gd" # Activity ratio (continuous-discrete, binary)
def _get_ts_key(ts):
"""
Get stable identifier for TimeSeries based on its name attribute.
This function generates a stable key for cache lookups that works across
both threading and loky (pickling) joblib backends.
All TimeSeries/MultiTimeSeries objects MUST have a name attribute set.
Names should be assigned at creation or by INTENSE pipelines before
calling compute_me_stats.
Parameters
----------
ts : TimeSeries or MultiTimeSeries
Time series object to get key for.
Returns
-------
str
The name attribute of the TimeSeries.
Raises
------
ValueError
If the TimeSeries does not have a name attribute or the name is empty.
Notes
-----
- Keys are stable across pickling (names are pickled with objects)
- Assumes names are unique within a bundle
- Used for FFT cache keys and random seed generation
- INTENSE pipelines assign temporary names to unnamed objects on entry
"""
if hasattr(ts, 'name') and ts.name:
return ts.name
# Should never reach here if naming strategy is complete
raise ValueError(
f"TimeSeries missing name attribute. "
f"All TimeSeries/MultiTimeSeries must have names for FFT cache and seeding. "
f"Shape: {ts.data.shape}, discrete: {ts.discrete}"
)
[docs]
@dataclass
class FFTCacheEntry:
"""Cache entry for pre-computed MI values.
Stores MI values for ALL possible shifts, enabling O(1) lookup
without redundant FFT computation. The FFT is computed once when
building the cache, then MI for any shift is just array indexing.
Attributes
----------
fft_type : str
FFT type constant (FFT_CONTINUOUS, FFT_DISCRETE, FFT_MULTIVARIATE).
mi_all : np.ndarray
MI values for ALL n shifts (shape: (n,)).
"""
fft_type: str
mi_all: np.ndarray
[docs]
def get_fft_type(
ts1,
ts2,
metric: str,
mi_estimator: str,
count: int,
engine: str,
for_delays: bool = False,
):
"""Determine which FFT optimization to use for a time series pair.
Unified function that replaces _should_use_fft, _should_use_fft_gd,
_should_use_fft_mts, and _should_use_fft_for_delays.
Parameters
----------
ts1 : TimeSeries or MultiTimeSeries
First time series.
ts2 : TimeSeries or MultiTimeSeries
Second time series.
metric : str
Similarity metric being used.
mi_estimator : str
MI estimator ('gcmi' or 'ksg').
count : int
Number of shuffles (nsh) or shifts (for delay optimization).
engine : str
Computation engine: 'auto', 'fft', or 'loop'.
for_delays : bool
If True, use delay optimization threshold (MIN_SHIFTS_FOR_FFT_DELAYS).
Default False uses shuffle threshold (MIN_SHUFFLES_FOR_FFT).
Returns
-------
str or None
FFT type constant (FFT_CONTINUOUS, FFT_DISCRETE, FFT_MULTIVARIATE),
or None for loop fallback.
Raises
------
ValueError
If engine='fft' but no FFT optimization is applicable.
"""
# Early exit for loop engine
if engine == "loop":
return None
# FFT works with MI (GCMI estimator), fast_pearsonr, and av
if metric == "mi" and mi_estimator == "gcmi":
pass # proceed to type classification
elif metric == "fast_pearsonr":
pass # proceed — continuous-continuous only
elif metric == "av":
pass # proceed — binary-discrete + continuous only
else:
if engine == "fft":
raise ValueError(
f"engine='fft' requires metric='mi' (with mi_estimator='gcmi'), "
f"'fast_pearsonr', or 'av'. "
f"Got: metric={metric}, mi_estimator={mi_estimator}"
)
return None
# Determine threshold based on context
threshold = MIN_SHIFTS_FOR_FFT_DELAYS if for_delays else MIN_SHUFFLES_FOR_FFT
# Classify the pair type and check applicability
fft_type = None
error_msg = None
# For fast_pearsonr, only univariate continuous-continuous is supported via FFT
is_pearson = (metric == "fast_pearsonr")
is_av = (metric == "av")
# Check for MultiTimeSeries + univariate TimeSeries pair
if isinstance(ts1, MultiTimeSeries) and isinstance(ts2, TimeSeries) and not isinstance(ts2, MultiTimeSeries):
if is_pearson or is_av:
error_msg = f"FFT {metric} only supports univariate TimeSeries pairs"
else:
mts, ts = ts1, ts2
# Check for MTS + discrete pair first
if not mts.discrete and ts.discrete and mts.data.shape[0] <= MAX_FFT_MTS_DIMENSIONS:
fft_type = FFT_MTS_DISCRETE
# Then check for MTS + continuous pair
elif not mts.discrete and not ts.discrete and mts.data.shape[0] <= MAX_FFT_MTS_DIMENSIONS:
fft_type = FFT_MULTIVARIATE
else:
error_msg = (
f"MultiTimeSeries FFT requires continuous MTS with d <= {MAX_FFT_MTS_DIMENSIONS} "
f"paired with either continuous or discrete TimeSeries. "
f"Got: MTS discrete={mts.discrete}, TS discrete={ts.discrete}, MTS shape={mts.data.shape}"
)
elif isinstance(ts2, MultiTimeSeries) and isinstance(ts1, TimeSeries) and not isinstance(ts1, MultiTimeSeries):
if is_pearson or is_av:
error_msg = f"FFT {metric} only supports univariate TimeSeries pairs"
else:
mts, ts = ts2, ts1
# Check for MTS + discrete pair first
if not mts.discrete and ts.discrete and mts.data.shape[0] <= MAX_FFT_MTS_DIMENSIONS:
fft_type = FFT_MTS_DISCRETE
# Then check for MTS + continuous pair
elif not mts.discrete and not ts.discrete and mts.data.shape[0] <= MAX_FFT_MTS_DIMENSIONS:
fft_type = FFT_MULTIVARIATE
else:
error_msg = (
f"MultiTimeSeries FFT requires continuous MTS with d <= {MAX_FFT_MTS_DIMENSIONS} "
f"paired with either continuous or discrete TimeSeries. "
f"Got: MTS discrete={mts.discrete}, TS discrete={ts.discrete}, MTS shape={mts.data.shape}"
)
# Check for MultiTimeSeries + MultiTimeSeries pair
elif isinstance(ts1, MultiTimeSeries) and isinstance(ts2, MultiTimeSeries):
if is_pearson or is_av:
error_msg = f"FFT {metric} only supports univariate TimeSeries pairs"
else:
d1 = ts1.data.shape[0]
d2 = ts2.data.shape[0]
if (not ts1.discrete and not ts2.discrete and
d1 + d2 <= MAX_MTS_MTS_FFT_DIMENSIONS):
fft_type = FFT_MTS_MTS
else:
error_msg = (
f"MTS-MTS FFT requires continuous variables with "
f"d1+d2 <= {MAX_MTS_MTS_FFT_DIMENSIONS}. "
f"Got: d1={d1}, d2={d2}, discrete=({ts1.discrete},{ts2.discrete})"
)
# Check for univariate TimeSeries pairs
elif (isinstance(ts1, TimeSeries) and not isinstance(ts1, MultiTimeSeries) and
isinstance(ts2, TimeSeries) and not isinstance(ts2, MultiTimeSeries)):
is_av = (metric == "av")
if is_av:
# AV requires one binary-discrete and one continuous variable
if ts1.discrete != ts2.discrete:
discrete_ts = ts1 if ts1.discrete else ts2
if discrete_ts.is_binary:
fft_type = FFT_AV_DISCRETE
else:
error_msg = "FFT av requires binary discrete variable"
else:
error_msg = "FFT av requires one discrete and one continuous variable"
elif is_pearson:
if not ts1.discrete and not ts2.discrete:
fft_type = FFT_PEARSON_CONTINUOUS
elif ts1.discrete != ts2.discrete:
fft_type = FFT_PEARSON_DISCRETE
else:
error_msg = "FFT Pearson does not support two discrete variables"
else:
# Discrete-continuous pair
if ts1.discrete != ts2.discrete:
fft_type = FFT_DISCRETE
# Continuous-continuous pair
elif not ts1.discrete and not ts2.discrete:
fft_type = FFT_CONTINUOUS
else:
# Both discrete - use discrete-discrete FFT
fft_type = FFT_DISCRETE_DISCRETE
else:
error_msg = (
f"FFT requires univariate TimeSeries or MultiTimeSeries+TimeSeries pair. "
f"Got: ts1={type(ts1).__name__}, ts2={type(ts2).__name__}"
)
# Handle engine='fft' validation
if engine == "fft":
if fft_type is None:
raise ValueError(
f"engine='fft' requested but no FFT optimization is applicable. {error_msg}"
)
return fft_type
# engine='auto': check threshold
if fft_type is not None and count >= threshold:
return fft_type
return None
def _extract_fft_data(ts1, ts2, fft_type, ds: int):
"""Extract and prepare data for FFT computation.
Parameters
----------
ts1 : TimeSeries or MultiTimeSeries
First time series in the pair.
ts2 : TimeSeries or MultiTimeSeries
Second time series in the pair.
fft_type : str
FFT type constant (FFT_CONTINUOUS, FFT_DISCRETE, FFT_MULTIVARIATE).
ds : int
Downsampling factor.
Returns
-------
tuple
(data1, data2) ready for the corresponding FFT function.
"""
if fft_type == FFT_CONTINUOUS:
return ts1.copula_normal_data[::ds], ts2.copula_normal_data[::ds]
elif fft_type == FFT_DISCRETE:
if ts1.discrete:
return ts2.copula_normal_data[::ds], ts1.int_data[::ds]
else:
return ts1.copula_normal_data[::ds], ts2.int_data[::ds]
elif fft_type == FFT_MULTIVARIATE:
if isinstance(ts1, MultiTimeSeries):
return ts2.copula_normal_data[::ds], ts1.copula_normal_data[:, ::ds]
else:
return ts1.copula_normal_data[::ds], ts2.copula_normal_data[:, ::ds]
elif fft_type == FFT_MTS_DISCRETE:
# Handle both orientations (MTS, discrete) or (discrete, MTS)
if isinstance(ts1, MultiTimeSeries):
return ts1.copula_normal_data[:, ::ds], ts2.int_data[::ds]
else:
return ts2.copula_normal_data[:, ::ds], ts1.int_data[::ds]
elif fft_type == FFT_MTS_MTS:
return ts1.copula_normal_data[:, ::ds], ts2.copula_normal_data[:, ::ds]
elif fft_type == FFT_DISCRETE_DISCRETE:
return ts1.int_data[::ds], ts2.int_data[::ds]
elif fft_type == FFT_PEARSON_CONTINUOUS:
return ts1.data[::ds], ts2.data[::ds]
elif fft_type == FFT_PEARSON_DISCRETE:
# Treat discrete data as float for point-biserial correlation
if ts1.discrete:
return ts2.data[::ds], ts1.data[::ds].astype(float)
else:
return ts1.data[::ds], ts2.data[::ds].astype(float)
elif fft_type == FFT_AV_DISCRETE:
# AV uses RAW data (not copula-normalized): (continuous, binary)
if ts1.discrete:
return ts2.data[::ds], ts1.data[::ds].astype(float)
else:
return ts1.data[::ds], ts2.data[::ds].astype(float)
else:
raise ValueError(f"Unknown FFT type: {fft_type}")
# Dispatch table for FFT compute functions
_FFT_COMPUTE = {
FFT_CONTINUOUS: compute_mi_batch_fft,
FFT_DISCRETE: compute_mi_gd_fft,
FFT_DISCRETE_DISCRETE: compute_mi_dd_fft,
FFT_MULTIVARIATE: compute_mi_mts_fft,
FFT_MTS_MTS: compute_mi_mts_mts_fft,
FFT_MTS_DISCRETE: compute_mi_mts_discrete_fft,
FFT_PEARSON_CONTINUOUS: compute_pearson_batch_fft,
FFT_PEARSON_DISCRETE: compute_pearson_batch_fft,
FFT_AV_DISCRETE: compute_av_batch_fft,
}
def _precompute_cc_signals(ts_list, fft_type, ds):
"""Precompute (rfft_of_demeaned, std) for each unique signal.
For FFT_CONTINUOUS and FFT_PEARSON_CONTINUOUS pair types, the per-signal
quantities (demean, rfft, std) are independent of the pairing partner.
Precomputing them once per unique signal eliminates redundant work when
the same signal appears in many pairs (e.g., each neuron paired with
all features).
Parameters
----------
ts_list : list
Time series to precompute for.
fft_type : str
FFT_CONTINUOUS (uses copula_normal_data) or
FFT_PEARSON_CONTINUOUS (uses raw data).
ds : int
Downsampling factor.
Returns
-------
dict
Mapping key -> (fft_array, std_value, n_samples).
"""
precomp = {}
for ts in ts_list:
key = _get_ts_key(ts)
if key in precomp:
continue
if isinstance(ts, MultiTimeSeries):
continue
if ts.discrete:
continue
if fft_type == FFT_CONTINUOUS:
data = ts.copula_normal_data[::ds].astype(np.float64)
else: # FFT_PEARSON_CONTINUOUS
data = ts.data[::ds].astype(np.float64)
demeaned = data - data.mean()
precomp[key] = (np.fft.rfft(demeaned), np.std(demeaned, ddof=1), len(data))
return precomp
def _mi_from_precomp(fft_x, std_x, fft_y, std_y, n, bias_correction):
"""Compute MI for all n shifts from precomputed rfft and std values.
Equivalent to compute_mi_batch_fft but skips the redundant demean/rfft/std
steps by accepting precomputed values.
Parameters
----------
fft_x : ndarray
Precomputed rfft of demeaned signal x.
std_x : float
Precomputed std(ddof=1) of demeaned signal x.
fft_y : ndarray
Precomputed rfft of demeaned signal y.
std_y : float
Precomputed std(ddof=1) of demeaned signal y.
n : int
Signal length.
bias_correction : float
Precomputed Panzeri-Treves bias correction term.
Returns
-------
ndarray of shape (n,)
MI values for all shifts, in bits.
"""
if std_x < REG_VARIANCE_THRESHOLD or std_y < REG_VARIANCE_THRESHOLD:
return np.zeros(n, dtype=np.float32)
cross_corr = np.fft.irfft(fft_x * np.conj(fft_y), n=n)
r_all = cross_corr / ((n - 1) * std_x * std_y)
r_squared = np.clip(r_all ** 2, 0, 1 - 1e-10)
mi_all = -0.5 * np.log(1 - r_squared) / _LN2
mi_all = mi_all + bias_correction
return np.maximum(0, mi_all).astype(np.float32)
def _pearson_from_precomp(fft_x, std_x, fft_y, std_y, n):
"""Compute |Pearson r| for all n shifts from precomputed rfft and std.
Equivalent to compute_pearson_batch_fft but skips redundant demean/rfft/std.
Parameters
----------
fft_x : ndarray
Precomputed rfft of demeaned signal x.
std_x : float
Precomputed std(ddof=1) of demeaned signal x.
fft_y : ndarray
Precomputed rfft of demeaned signal y.
std_y : float
Precomputed std(ddof=1) of demeaned signal y.
n : int
Signal length.
Returns
-------
ndarray of shape (n,)
|Pearson r| values for all shifts.
"""
if std_x < REG_VARIANCE_THRESHOLD or std_y < REG_VARIANCE_THRESHOLD:
return np.zeros(n, dtype=np.float32)
cross_corr = np.fft.irfft(fft_x * np.conj(fft_y), n=n)
r_all = cross_corr / ((n - 1) * std_x * std_y)
return np.abs(np.clip(r_all, -1, 1)).astype(np.float32)
def _build_fft_cache_core(ts_bunch1, ts_bunch2, metric, mi_estimator, ds, engine,
pair_mask=None):
"""Core cache building logic shared by serial and parallel paths.
For continuous-continuous (cc) and Pearson continuous (pearson_cc) pairs,
precomputes per-signal quantities (demean, rfft, std) once per unique signal
instead of once per pair, giving ~2x speedup. Other pair types (discrete,
multivariate, etc.) use the standard per-pair computation path.
Parameters
----------
ts_bunch1 : list
First set of time series (subset or full).
ts_bunch2 : list
Second set of time series (always full).
metric : str
Similarity metric.
mi_estimator : str
MI estimator ('gcmi' or 'ksg').
ds : int
Downsampling factor.
engine : str
Computation engine.
pair_mask : ndarray or None, optional
Boolean/int mask of shape (len(ts_bunch1), len(ts_bunch2)).
Only pairs where pair_mask[i, j] != 0 are cached.
None means cache all pairs (backward compatible).
Returns
-------
tuple
(cache, fft_type_counts).
"""
cache = {}
fft_type_counts = {}
# Detect if any cc/pearson_cc pairs exist by checking one representative pair
precomp_type = None
for ts1 in ts_bunch1:
for ts2 in ts_bunch2:
ft = get_fft_type(ts1, ts2, metric, mi_estimator, 1, engine)
if ft in (FFT_CONTINUOUS, FFT_PEARSON_CONTINUOUS):
precomp_type = ft
break
if precomp_type is not None:
break
# Precompute per-signal quantities for eligible types
precomp = {}
bias_correction = 0.0
n_samples = None
if precomp_type is not None:
precomp = _precompute_cc_signals(
list(ts_bunch1) + list(ts_bunch2), precomp_type, ds
)
if precomp:
first_entry = next(iter(precomp.values()))
n_samples = first_entry[2]
# Precompute bias correction (only for MI, constant across all pairs)
if precomp_type == FFT_CONTINUOUS and n_samples > 2:
psi_1 = py_fast_digamma((n_samples - 1) / 2.0)
psi_2 = py_fast_digamma((n_samples - 2) / 2.0)
bias_correction = (psi_2 - psi_1) / (2.0 * _LN2)
# Build cache for pairs (optionally filtered by pair_mask)
for i, ts1 in enumerate(ts_bunch1):
key1 = _get_ts_key(ts1)
for j, ts2 in enumerate(ts_bunch2):
if pair_mask is not None and not pair_mask[i, j]:
continue
fft_type = get_fft_type(ts1, ts2, metric, mi_estimator, 1, engine)
key2 = _get_ts_key(ts2)
if fft_type is not None:
# Fast path: use precomputed per-signal quantities
if fft_type == FFT_CONTINUOUS and key1 in precomp and key2 in precomp:
fft_x, std_x, n = precomp[key1]
fft_y, std_y, _ = precomp[key2]
mi_all = _mi_from_precomp(
fft_x, std_x, fft_y, std_y, n, bias_correction
)
elif fft_type == FFT_PEARSON_CONTINUOUS and key1 in precomp and key2 in precomp:
fft_x, std_x, n = precomp[key1]
fft_y, std_y, _ = precomp[key2]
mi_all = _pearson_from_precomp(
fft_x, std_x, fft_y, std_y, n
)
else:
# Standard path for non-precomputable types
data1, data2 = _extract_fft_data(ts1, ts2, fft_type, ds)
compute_fn = _FFT_COMPUTE[fft_type]
n = len(data1) if data1.ndim == 1 else data1.shape[1]
mi_all = compute_fn(data1, data2, np.arange(n)).astype(np.float32)
cache[(key1, key2)] = FFTCacheEntry(
fft_type=fft_type, mi_all=mi_all
)
fft_type_counts[fft_type] = fft_type_counts.get(fft_type, 0) + 1
else:
cache[(key1, key2)] = None
fft_type_counts['loop'] = fft_type_counts.get('loop', 0) + 1
return cache, fft_type_counts
def _build_fft_cache(
ts_bunch1: list,
ts_bunch2: list,
metric: str,
mi_estimator: str,
ds: int,
engine: str,
n_jobs: int = 1,
pair_mask=None,
) -> dict:
"""Build FFT cache for all pairs using stable keys.
Pre-computes FFT data for all neuron-feature pairs that support FFT,
enabling reuse across delay optimization, Stage 1, and Stage 2.
When n_jobs != 1, uses parallel processing to
split ts_bunch1 across workers for faster cache building.
Uses stable keys (TimeSeries names or data hashes) instead of positional
indices, allowing cache to work correctly with both threading and loky
joblib backends, and enabling automatic cache reuse when the same
TimeSeries objects appear in different contexts (e.g., ts_with_delays subset).
Parameters
----------
ts_bunch1 : list
First set of time series (e.g., neurons).
ts_bunch2 : list
Second set of time series (e.g., features).
metric : str
Similarity metric being used.
mi_estimator : str
MI estimator ('gcmi' or 'ksg').
ds : int
Downsampling factor.
engine : str
Computation engine: 'auto', 'fft', or 'loop'.
n_jobs : int, optional
Number of parallel jobs. Default: 1 (serial).
Use -1 for all CPU cores. Parallelization is disabled
when ts_bunch1 has only one element.
pair_mask : ndarray or None, optional
Boolean/int mask of shape (len(ts_bunch1), len(ts_bunch2)).
Only pairs where pair_mask[i, j] != 0 are cached.
None means cache all pairs (backward compatible).
Returns
-------
tuple
A tuple of (cache, fft_type_counts):
- cache: Dictionary mapping (key1, key2) tuple to FFTCacheEntry or None.
Keys are stable identifiers from _get_ts_key() (names or hashes).
None indicates loop fallback should be used for that pair.
- fft_type_counts: Dictionary mapping FFT type strings to counts.
Includes 'loop' for pairs that require loop fallback.
"""
# Efficient duplicate name check: only validates when duplicate names exist
all_ts = list(ts_bunch1) + list(ts_bunch2)
# Cache (ts, name) pairs to avoid calling _get_ts_key() twice
ts_name_pairs = [(ts, _get_ts_key(ts)) for ts in all_ts] # Will raise if any unnamed
all_names = [name for _, name in ts_name_pairs]
# Fast path: if all names unique, no duplicates possible
if len(set(all_names)) != len(all_ts):
# Slow path: duplicates exist, check if they have different data
# Use data equality instead of id() for pickle stability (loky backend)
name_to_data = {}
for ts, name in ts_name_pairs:
if name not in name_to_data:
# First occurrence of this name - store reference data
name_to_data[name] = ts.data
else:
# Duplicate name - check if data matches
if not np.array_equal(name_to_data[name], ts.data):
# Same name, different data = COLLISION!
raise ValueError(
f"Cache collision: TimeSeries name '{name}' maps to different data! "
f"Same names must have identical data to share FFT cache."
)
# Delegate to parallel version if enabled and appropriate
if n_jobs != 1 and len(ts_bunch1) > 1:
return _build_fft_cache_parallel(
ts_bunch1, ts_bunch2, metric, mi_estimator, ds, engine, n_jobs,
pair_mask=pair_mask,
)
# Serial implementation
return _build_fft_cache_core(
ts_bunch1, ts_bunch2, metric, mi_estimator, ds, engine,
pair_mask=pair_mask,
)
def _build_fft_cache_worker(
ts_bunch1_subset: list,
ts_bunch2: list,
metric: str,
mi_estimator: str,
ds: int,
engine: str,
pair_mask=None,
) -> tuple:
"""Build FFT cache for a subset of ts_bunch1 (worker function).
This is a worker function used by _build_fft_cache_parallel to process
a subset of the first time series bunch in parallel.
Parameters
----------
ts_bunch1_subset : list
Subset of first time series to process.
ts_bunch2 : list
Full set of second time series (features).
metric : str
Similarity metric being used.
mi_estimator : str
MI estimator ('gcmi' or 'ksg').
ds : int
Downsampling factor.
engine : str
Computation engine: 'auto', 'fft', or 'loop'.
pair_mask : ndarray or None, optional
Mask rows for this subset. Shape (len(ts_bunch1_subset), len(ts_bunch2)).
Returns
-------
tuple
(partial_cache, partial_fft_type_counts) for this subset.
"""
return _build_fft_cache_core(
ts_bunch1_subset, ts_bunch2, metric, mi_estimator, ds, engine,
pair_mask=pair_mask,
)
def _build_fft_cache_parallel(
ts_bunch1: list,
ts_bunch2: list,
metric: str,
mi_estimator: str,
ds: int,
engine: str,
n_jobs: int = -1,
pair_mask=None,
) -> tuple:
"""Parallel version of _build_fft_cache.
Splits ts_bunch1 across workers, each builds a partial cache,
then merges all results into a single cache dictionary.
Parameters
----------
ts_bunch1 : list
First set of time series (e.g., neurons).
ts_bunch2 : list
Second set of time series (e.g., features).
metric : str
Similarity metric being used.
mi_estimator : str
MI estimator ('gcmi' or 'ksg').
ds : int
Downsampling factor.
engine : str
Computation engine: 'auto', 'fft', or 'loop'.
n_jobs : int, optional
Number of parallel jobs. Default: -1 (all CPU cores).
pair_mask : ndarray or None, optional
Boolean/int mask of shape (len(ts_bunch1), len(ts_bunch2)).
Rows are sliced per worker to match ts_bunch1 splits.
Returns
-------
tuple
(merged_cache, merged_fft_type_counts) from all workers.
"""
if n_jobs == -1:
n_jobs = min(multiprocessing.cpu_count(), len(ts_bunch1))
n_jobs_effective = min(n_jobs, len(ts_bunch1))
# Split ts_bunch1 across workers
split_inds = np.array_split(np.arange(len(ts_bunch1)), n_jobs_effective)
split_ts_bunch1 = [[ts_bunch1[i] for i in idxs] for idxs in split_inds if len(idxs) > 0]
# Slice mask rows for each worker
split_masks = (
[pair_mask[idxs] for idxs in split_inds if len(idxs) > 0]
if pair_mask is not None
else [None] * len(split_ts_bunch1)
)
# Parallel execution with backend-specific config
with _parallel_executor(n_jobs_effective) as parallel:
results = parallel(
delayed(_build_fft_cache_worker)(
subset, ts_bunch2, metric, mi_estimator, ds, engine,
pair_mask=mask_slice,
)
for subset, mask_slice in zip(split_ts_bunch1, split_masks)
)
# Merge results
merged_cache = {}
merged_counts = {}
for partial_cache, partial_counts in results:
merged_cache.update(partial_cache)
for k, v in partial_counts.items():
merged_counts[k] = merged_counts.get(k, 0) + v
return merged_cache, merged_counts