"""
Spike reconstruction module for DRIADA.
This module provides functions for reconstructing spike trains from calcium
imaging data using various methods.
"""
import numpy as np
from typing import Tuple, Dict, Any, Optional, Union, List
from scipy.ndimage import gaussian_filter1d
from scipy.signal import find_peaks
from ..information.info_base import TimeSeries, MultiTimeSeries
from .wavelet_event_detection import (
WVT_EVENT_DETECTION_PARAMS,
extract_wvt_events,
events_to_ts_array,
ridges_to_containers,
)
from ..utils.data import check_positive, check_nonnegative
[docs]
def reconstruct_spikes(
calcium: MultiTimeSeries,
method: str = "wavelet",
fps: float = 20.0,
params: Optional[Dict[str, Any]] = None,
wavelet=None,
rel_wvt_times=None,
use_gpu: bool = False,
) -> Tuple[MultiTimeSeries, Dict[str, Any]]:
"""
Reconstruct spike trains from calcium signals.
This function serves as a router to different spike reconstruction methods.
All methods operate on the scaled calcium data (values normalized to [0, 1]).
Parameters
----------
calcium : MultiTimeSeries
Calcium imaging data with each component being a neuron. Must have
valid scaled data accessible via calcium.scdata.
method : str or callable, optional
Reconstruction method. Options:
- 'wavelet': Wavelet-based detection (default)
- 'threshold': Simple threshold-based detection
- callable: Custom function with signature (calcium, fps, params) -> (spikes, metadata)
fps : float, optional
Sampling rate in frames per second. Must be positive. Default: 20.0.
params : dict, optional
Method-specific parameters. Contents depend on chosen method.
Default: empty dict.
wavelet : Wavelet, optional
Pre-computed wavelet object for batch optimization (wavelet method only).
If None, creates new. Significantly speeds up batch processing.
rel_wvt_times : array-like, optional
Pre-computed time resolutions for batch optimization (wavelet method only).
If None, computes new. Used together with wavelet parameter.
use_gpu : bool, default=False
Whether to use GPU acceleration for wavelet transform computation.
Requires PyTorch and CuPy. Ridge extraction remains CPU-only.
Returns
-------
spikes : MultiTimeSeries
Reconstructed spike trains as binary (discrete) time series.
Each component represents one neuron.
metadata : dict
Reconstruction metadata including:
- 'method': str - Method used
- 'parameters': dict - Parameters used
- Method-specific fields (see individual method docs)
Raises
------
ValueError
If method is unknown string.
If fps is not positive.
AttributeError
If calcium lacks required attributes (e.g., scdata).
TypeError
If callable method has wrong signature.
Examples
--------
>>> # Create example calcium data
>>> import numpy as np
>>> from driada.information import TimeSeries, MultiTimeSeries
>>> n_neurons, n_frames = 5, 1000
>>> raw_data = np.random.rand(n_neurons, n_frames)
>>> calcium_ts_list = [TimeSeries(raw_data[i], discrete=False) for i in range(n_neurons)]
>>> calcium_data = MultiTimeSeries(calcium_ts_list)
>>>
>>> # Using wavelet method (default)
>>> spikes, meta = reconstruct_spikes(calcium_data, fps=30.0)
>>> meta['method']
'wavelet'
>>> spikes.n_dim == n_neurons
True
>>> # Using threshold method with custom parameters
>>> params = {'threshold_std': 3.0, 'smooth_sigma': 1.5}
>>> spikes, meta = reconstruct_spikes(calcium_data, 'threshold', 30.0, params)
>>> meta['method']
'threshold'
>>> meta['parameters']['threshold_std']
3.0
>>> # Using custom reconstruction function
>>> def custom_method(calcium, fps, params):
... # Simple mock implementation
... n_neurons, n_frames = calcium.scdata.shape
... spike_data = np.zeros((n_neurons, n_frames))
... # Add a few spikes to avoid zero columns error
... for i in range(n_neurons):
... spike_data[i, i*10:(i+1)*10] = 1
... spike_ts = [TimeSeries(spike_data[i], discrete=True) for i in range(n_neurons)]
... return MultiTimeSeries(spike_ts, allow_zero_columns=True), {'custom_info': 'test'}
>>> spikes, meta = reconstruct_spikes(calcium_data, custom_method, 30.0)
>>> meta['custom_info']
'test'
Notes
-----
All built-in methods use the scaled calcium data (calcium.scdata) which
is normalized to [0, 1]. This ensures consistent behavior across different
calcium indicator types and experimental conditions."""
# Input validation
check_positive(fps=fps)
# Validate calcium has required attributes
if not hasattr(calcium, "scdata"):
raise AttributeError("calcium must have 'scdata' attribute (scaled data)")
params = params or {}
if callable(method):
# Custom method
try:
return method(calcium, fps, params)
except TypeError as e:
raise TypeError(
f"Custom method must have signature (calcium, fps, params) -> "
f"(MultiTimeSeries, dict). Error: {e}"
)
elif method == "wavelet":
return wavelet_reconstruction(calcium, fps, params, wavelet, rel_wvt_times, use_gpu)
elif method == "threshold":
return threshold_reconstruction(calcium, fps, params)
else:
raise ValueError(
f"Unknown method '{method}'. Use 'wavelet', 'threshold', " f"or provide a callable."
)
[docs]
def wavelet_reconstruction(
calcium: MultiTimeSeries,
fps: float,
params: Dict[str, Any],
wavelet=None,
rel_wvt_times=None,
use_gpu: bool = False,
) -> Tuple[MultiTimeSeries, Dict[str, Any]]:
"""
Wavelet-based spike reconstruction.
Uses continuous wavelet transform to detect calcium transients. The method
operates on scaled calcium data (normalized to [0, 1]).
Parameters
----------
calcium : MultiTimeSeries
Calcium imaging signals. Must have scdata attribute.
fps : float
Sampling rate in frames per second. Overrides default fps in parameters.
params : dict
Parameters that update WVT_EVENT_DETECTION_PARAMS defaults:
- 'sigma': int - Smoothing parameter for peak detection (frames)
- 'eps': int - Minimum spacing between consecutive events (frames)
- 'scale_length_thr': int - Min scales where ridge is present
- 'max_scale_thr': int - Index of scale with max ridge intensity
- 'max_ampl_thr': float - Max ridge intensity threshold
- 'max_dur_thr': int - Max event duration threshold
See WVT_EVENT_DETECTION_PARAMS for defaults.
wavelet : Wavelet, optional
Pre-computed wavelet object for batch optimization. If None, creates new.
rel_wvt_times : array-like, optional
Pre-computed time resolutions for batch optimization. If None, computes new.
use_gpu : bool, default False
Whether to use GPU acceleration for wavelet computation if available.
Returns
-------
spikes : MultiTimeSeries
Binary spike trains (discrete). Allow_zero_columns=True for empty neurons.
metadata : dict
Contains:
- 'method': 'wavelet'
- 'parameters': dict - All parameters used
- 'start_events': list - Event start indices per neuron
- 'end_events': list - Event end indices per neuron
- 'ridges': list - Ridge information per neuron
Raises
------
AttributeError
If calcium lacks scdata attribute.
ValueError
If calcium data is empty or invalid shape.
Notes
-----
Default parameters are defined in WVT_EVENT_DETECTION_PARAMS. The fps
parameter always overrides the default fps value.
For batch processing, pre-compute wavelet and rel_wvt_times once and reuse
across multiple calls for significant speedup (8-10x faster)."""
# Input validation
check_positive(fps=fps)
if not hasattr(calcium, "scdata"):
raise AttributeError("calcium must have 'scdata' attribute")
# Get scaled calcium data as numpy array for better spike detection
calcium_data = np.asarray(calcium.scdata) # Use scaled data
# Validate data shape
if calcium_data.ndim != 2:
raise ValueError(
f"calcium data must be 2D (neurons x time), got shape {calcium_data.shape}"
)
if calcium_data.size == 0:
raise ValueError("calcium data cannot be empty")
# Pre-compute wavelet objects if not provided (batch optimization)
if wavelet is None:
from ssqueezepy.wavelets import Wavelet, time_resolution
from .wavelet_event_detection import get_adaptive_wavelet_scales
wavelet = Wavelet(("gmw", {"gamma": 3, "beta": 2, "centered_scale": True}), N=8196)
manual_scales = get_adaptive_wavelet_scales(fps)
rel_wvt_times = [
time_resolution(wavelet, scale=sc, nondim=False, min_decay=200) for sc in manual_scales
]
# Set up wavelet parameters
wvt_kwargs = WVT_EVENT_DETECTION_PARAMS.copy()
wvt_kwargs["fps"] = fps
wvt_kwargs.update(params)
# Extract events with pre-computed wavelet objects
st_ev_inds, end_ev_inds, all_ridges = extract_wvt_events(
calcium_data, wvt_kwargs, wavelet=wavelet, rel_wvt_times=rel_wvt_times, use_gpu=use_gpu
)
# Convert to spike array
spikes_data = events_to_ts_array(calcium_data.shape[1], st_ev_inds, end_ev_inds, fps)
# Create spike MultiTimeSeries
spike_ts_list = [
TimeSeries(spikes_data[i, :], discrete=True) for i in range(spikes_data.shape[0])
]
spikes = MultiTimeSeries(spike_ts_list, allow_zero_columns=True)
# Prepare metadata
metadata = {
"method": "wavelet",
"parameters": wvt_kwargs,
"start_events": st_ev_inds,
"end_events": end_ev_inds,
"ridges": [ridges_to_containers(ridges) for ridges in all_ridges],
}
return spikes, metadata
[docs]
def threshold_reconstruction(
calcium: Union[MultiTimeSeries, TimeSeries], fps: float, params: Dict[str, Any]
) -> Tuple[MultiTimeSeries, Dict[str, Any]]:
"""
Simple threshold-based spike reconstruction.
This method detects spikes when the derivative of the calcium signal
exceeds a threshold. Operates on scaled calcium data (normalized to [0, 1]).
Parameters
----------
calcium : MultiTimeSeries or TimeSeries
Calcium signals. Must have scdata attribute.
If TimeSeries, will be treated as single neuron.
fps : float
Sampling rate in frames per second. Must be positive.
params : dict
Parameters including:
* threshold_std : float, number of STDs above mean for detection. Must be positive. Default: 2.5.
* smooth_sigma : float, gaussian smoothing sigma in frames. Must be non-negative. Default: 2.
* min_spike_interval : float, minimum interval between spikes in seconds. Must be non-negative. Default: 0.1.
Returns
-------
spikes : MultiTimeSeries
Binary spike trains (discrete). Allow_zero_columns=True for empty neurons.
metadata : dict
Contains:
- 'method': 'threshold'
- 'parameters': dict with all parameters used including fps
- 'spike_times': list of arrays - Frame indices of detected spikes per neuron
Raises
------
AttributeError
If calcium lacks scdata attribute.
ValueError
If calcium data is empty or invalid shape.
If parameters are out of valid range.
If min_spike_interval * fps < 0.5 (would result in zero minimum distance).
Notes
-----
The derivative is computed with np.diff and zero-padded at the start.
This affects the first frame which cannot have a spike detected."""
# Input validation
check_positive(fps=fps)
if not hasattr(calcium, "scdata"):
raise AttributeError("calcium must have 'scdata' attribute")
# Default parameters
threshold_std = params.get("threshold_std", 2.5)
smooth_sigma = params.get("smooth_sigma", 2)
min_spike_interval = params.get("min_spike_interval", 0.1)
# Validate parameters
check_positive(threshold_std=threshold_std)
check_nonnegative(smooth_sigma=smooth_sigma, min_spike_interval=min_spike_interval)
min_spike_frames = int(min_spike_interval * fps)
if min_spike_interval > 0 and min_spike_frames < 1:
raise ValueError(
f"min_spike_interval * fps = {min_spike_interval * fps:.2f} < 1, "
"would result in zero minimum distance between spikes"
)
calcium_data = np.asarray(calcium.scdata) # Use scaled data
# Normalize to 2D array (neurons x time)
if calcium_data.ndim == 1:
calcium_data = calcium_data.reshape(1, -1)
elif calcium_data.ndim != 2:
raise ValueError(f"calcium data must be 1D or 2D, got shape {calcium_data.shape}")
if calcium_data.size == 0:
raise ValueError("calcium data cannot be empty")
n_neurons, n_frames = calcium_data.shape
spikes_data = np.zeros_like(calcium_data)
all_spike_times = []
for i in range(n_neurons):
# Get calcium trace
trace = calcium_data[i, :]
# Smooth the signal
smoothed = gaussian_filter1d(trace, sigma=smooth_sigma)
# Compute derivative (rate of calcium increase)
diff = np.diff(smoothed)
diff = np.concatenate([[0], diff]) # Pad to maintain size
# Compute threshold
threshold = np.mean(diff) + threshold_std * np.std(diff)
# Find peaks in derivative
peaks, properties = find_peaks(diff, height=threshold, distance=min_spike_frames)
# Mark spikes
spikes_data[i, peaks] = 1
all_spike_times.append(peaks)
# Create spike MultiTimeSeries
spike_ts_list = [TimeSeries(spikes_data[i, :], discrete=True) for i in range(n_neurons)]
spikes = MultiTimeSeries(spike_ts_list, allow_zero_columns=True)
# Prepare metadata
metadata = {
"method": "threshold",
"parameters": {
"threshold_std": threshold_std,
"smooth_sigma": smooth_sigma,
"min_spike_interval": min_spike_interval,
"fps": fps,
},
"spike_times": all_spike_times,
}
return spikes, metadata