Spike Reconstruction Methods

This module provides methods for reconstructing spike trains from calcium imaging data.

Functions

driada.experiment.spike_reconstruction.reconstruct_spikes(calcium, method='wavelet', fps=20.0, params=None, wavelet=None, rel_wvt_times=None, use_gpu=False)

Reconstruct spike trains from calcium signals.

This function serves as a router to different spike reconstruction methods. All methods operate on the scaled calcium data (values normalized to [0, 1]).

Parameters:
  • calcium (MultiTimeSeries) – Calcium imaging data with each component being a neuron. Must have valid scaled data accessible via calcium.scdata.

  • method (str or callable, optional) – Reconstruction method. Options:

    • ‘wavelet’: Wavelet-based detection (default)

    • ‘threshold’: Simple threshold-based detection

    • callable: Custom function with signature (calcium, fps, params) -> (spikes, metadata)

  • fps (float, optional) – Sampling rate in frames per second. Must be positive. Default: 20.0.

  • params (dict, optional) – Method-specific parameters. Contents depend on the chosen method. Default: None (treated as an empty dict).

  • wavelet (Wavelet, optional) – Pre-computed wavelet object for batch optimization (wavelet method only). If None, a new one is created. Reusing one object significantly speeds up batch processing.

  • rel_wvt_times (array-like, optional) – Pre-computed time resolutions for batch optimization (wavelet method only). If None, they are computed on each call. Used together with the wavelet parameter.

  • use_gpu (bool, default=False) – Whether to use GPU acceleration for wavelet transform computation. Requires PyTorch and CuPy. Ridge extraction remains CPU-only.

Return type:

Tuple[MultiTimeSeries, Dict[str, Any]]

Returns:

  • spikes (MultiTimeSeries) – Reconstructed spike trains as binary (discrete) time series. Each component represents one neuron.

  • metadata (dict) – Reconstruction metadata including:

    • ‘method’: str - Method used

    • ‘parameters’: dict - Parameters used

    • Method-specific fields (see individual method docs)

Raises:
  • ValueError – If method is an unknown string, or if fps is not positive.

  • AttributeError – If calcium lacks required attributes (e.g., scdata).

  • TypeError – If callable method has wrong signature.

Examples

>>> # Create example calcium data
>>> import numpy as np
>>> from driada.information import TimeSeries, MultiTimeSeries
>>> from driada.experiment import reconstruct_spikes
>>> n_neurons, n_frames = 5, 1000
>>> raw_data = np.random.rand(n_neurons, n_frames)
>>> calcium_ts_list = [TimeSeries(raw_data[i], discrete=False) for i in range(n_neurons)]
>>> calcium_data = MultiTimeSeries(calcium_ts_list)
>>>
>>> # Using wavelet method (default)
>>> spikes, meta = reconstruct_spikes(calcium_data, fps=30.0)
>>> meta['method']
'wavelet'
>>> spikes.n_dim == n_neurons
True
>>> # Using threshold method with custom parameters
>>> params = {'threshold_std': 3.0, 'smooth_sigma': 1.5}
>>> spikes, meta = reconstruct_spikes(calcium_data, 'threshold', 30.0, params)
>>> meta['method']
'threshold'
>>> meta['parameters']['threshold_std']
3.0
>>> # Using custom reconstruction function
>>> def custom_method(calcium, fps, params):
...     # Simple mock implementation
...     n_neurons, n_frames = calcium.scdata.shape
...     spike_data = np.zeros((n_neurons, n_frames))
...     # Add a few spikes to avoid zero columns error
...     for i in range(n_neurons):
...         spike_data[i, i*10:(i+1)*10] = 1
...     spike_ts = [TimeSeries(spike_data[i], discrete=True) for i in range(n_neurons)]
...     return MultiTimeSeries(spike_ts, allow_zero_columns=True), {'custom_info': 'test'}
>>> spikes, meta = reconstruct_spikes(calcium_data, custom_method, 30.0)
>>> meta['custom_info']
'test'

Notes

All built-in methods use the scaled calcium data (calcium.scdata) which is normalized to [0, 1]. This ensures consistent behavior across different calcium indicator types and experimental conditions.
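
As an illustration of what "normalized to [0, 1]" means here, the following is a minimal sketch of per-neuron min-max scaling. It is an assumption about the kind of scaling involved, not DRIADA's actual scdata implementation, and the helper name minmax_scale_rows is hypothetical.

```python
import numpy as np


def minmax_scale_rows(data, eps=1e-12):
    """Scale each row (one neuron) to [0, 1] independently.

    Hypothetical helper for illustration; the library computes
    calcium.scdata itself.
    """
    data = np.asarray(data, dtype=float)
    lo = data.min(axis=1, keepdims=True)
    hi = data.max(axis=1, keepdims=True)
    # Guard against constant rows, which would otherwise divide by zero
    return (data - lo) / np.maximum(hi - lo, eps)


raw = np.array([[1.0, 2.0, 3.0],
                [10.0, 10.0, 10.0]])
scaled = minmax_scale_rows(raw)
```

Scaling each neuron separately is what makes thresholds comparable across indicators with different baseline fluorescence.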

driada.experiment.spike_reconstruction.wavelet_reconstruction(calcium, fps, params, wavelet=None, rel_wvt_times=None, use_gpu=False)

Wavelet-based spike reconstruction.

Uses continuous wavelet transform to detect calcium transients. The method operates on scaled calcium data (normalized to [0, 1]).

Parameters:
  • calcium (MultiTimeSeries) – Calcium imaging signals. Must have scdata attribute.

  • fps (float) – Sampling rate in frames per second. Overrides default fps in parameters.

  • params (dict) – Parameters that update WVT_EVENT_DETECTION_PARAMS defaults:

    • ‘sigma’: int - Smoothing parameter for peak detection (frames)

    • ‘eps’: int - Minimum spacing between consecutive events (frames)

    • ‘scale_length_thr’: int - Min scales where ridge is present

    • ‘max_scale_thr’: int - Index of scale with max ridge intensity

    • ‘max_ampl_thr’: float - Max ridge intensity threshold

    • ‘max_dur_thr’: int - Max event duration threshold

    See WVT_EVENT_DETECTION_PARAMS for defaults.

  • wavelet (Wavelet, optional) – Pre-computed wavelet object for batch optimization. If None, creates new.

  • rel_wvt_times (array-like, optional) – Pre-computed time resolutions for batch optimization. If None, computes new.

  • use_gpu (bool, default False) – Whether to use GPU acceleration for wavelet computation if available.

Return type:

Tuple[MultiTimeSeries, Dict[str, Any]]

Returns:

  • spikes (MultiTimeSeries) – Binary spike trains (discrete). Created with allow_zero_columns=True so that neurons with no detected events are permitted.

  • metadata (dict) – Contains:

    • ‘method’: ‘wavelet’

    • ‘parameters’: dict - All parameters used

    • ‘start_events’: list - Event start indices per neuron

    • ‘end_events’: list - Event end indices per neuron

    • ‘ridges’: list - Ridge information per neuron

Notes

Default parameters are defined in WVT_EVENT_DETECTION_PARAMS. The fps parameter always overrides the default fps value.

For batch processing, pre-compute wavelet and rel_wvt_times once and reuse across multiple calls for significant speedup (8-10x faster).

driada.experiment.spike_reconstruction.threshold_reconstruction(calcium, fps, params)

Simple threshold-based spike reconstruction.

This method detects spikes when the derivative of the calcium signal exceeds a threshold. Operates on scaled calcium data (normalized to [0, 1]).

Parameters:
  • calcium (MultiTimeSeries or TimeSeries) – Calcium signals. Must have scdata attribute. If TimeSeries, will be treated as single neuron.

  • fps (float) – Sampling rate in frames per second. Must be positive.

  • params (dict) –

    Parameters including:

    • threshold_std : float, number of STDs above mean for detection. Must be positive. Default: 2.5.

    • smooth_sigma : float, gaussian smoothing sigma in frames. Must be non-negative. Default: 2.

    • min_spike_interval : float, minimum interval between spikes in seconds. Must be non-negative. Default: 0.1.

Return type:

Tuple[MultiTimeSeries, Dict[str, Any]]

Returns:

  • spikes (MultiTimeSeries) – Binary spike trains (discrete). Created with allow_zero_columns=True so that neurons with no detected spikes are permitted.

  • metadata (dict) – Contains:

    • ‘method’: ‘threshold’

    • ‘parameters’: dict with all parameters used including fps

    • ‘spike_times’: list of arrays - Frame indices of detected spikes per neuron

Raises:
  • AttributeError – If calcium lacks scdata attribute.

  • ValueError – If calcium data is empty or invalid shape. If parameters are out of valid range. If min_spike_interval * fps < 0.5 (would result in zero minimum distance).

Notes

The derivative is computed with np.diff and zero-padded at the start. This affects the first frame which cannot have a spike detected.
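
The steps described above can be sketched for a single trace with NumPy and SciPy. This is an illustrative approximation under stated assumptions, not DRIADA's implementation: the function name sketch_threshold_detection is hypothetical, the threshold is taken as mean plus threshold_std standard deviations of the derivative, and scipy.signal.find_peaks is an assumed way to enforce the minimum spike interval.

```python
import numpy as np
from scipy.ndimage import gaussian_filter1d
from scipy.signal import find_peaks


def sketch_threshold_detection(trace, fps, threshold_std=2.5,
                               smooth_sigma=2.0, min_spike_interval=0.1):
    """Illustrative derivative-threshold detector (not the library's code)."""
    trace = np.asarray(trace, dtype=float)
    # Optional Gaussian smoothing; sigma is expressed in frames
    if smooth_sigma > 0:
        trace = gaussian_filter1d(trace, smooth_sigma)
    # np.diff shortens the array by one, so pad a zero at the start;
    # frame 0 therefore never exceeds a positive threshold
    deriv = np.concatenate([[0.0], np.diff(trace)])
    # Threshold: mean plus a number of standard deviations of the derivative
    thr = deriv.mean() + threshold_std * deriv.std()
    # Convert the minimum interval from seconds to frames (assumed: at least 1)
    min_dist = max(1, int(round(min_spike_interval * fps)))
    peaks, _ = find_peaks(deriv, height=thr, distance=min_dist)
    spikes = np.zeros(trace.size, dtype=int)
    spikes[peaks] = 1
    return spikes, peaks


# Synthetic trace: two transients with instant rise and exponential decay
t = np.arange(200)
trace = np.zeros(200)
for onset in (50, 150):
    trace[onset:] += np.exp(-(t[onset:] - onset) / 20.0)

spikes, peaks = sketch_threshold_detection(trace, fps=20.0, smooth_sigma=0.0)
```

With smoothing disabled, the detected frames coincide with the transient onsets; enabling smoothing spreads each rise over several frames and can shift the detected peak slightly.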

Usage Examples

Basic Spike Reconstruction

from driada.experiment import reconstruct_spikes, load_demo_experiment

# Load sample experiment
exp = load_demo_experiment()

# Reconstruct spikes using default method (wavelet)
# reconstruct_spikes returns a (spikes, metadata) tuple
spikes, metadata = reconstruct_spikes(
    exp.calcium,
    fps=exp.fps,
    method='wavelet'
)

# Add to experiment
exp.spikes = spikes

Wavelet-based Reconstruction

from driada.experiment import wavelet_reconstruction, load_demo_experiment

# Load sample experiment
exp = load_demo_experiment()

# Wavelet reconstruction with custom parameters
params = {
    'sigma': 2,            # Smoothing parameter
    'eps': 3,              # Min spacing between events
    'scale_length_thr': 3, # Min scales for ridge
    'max_scale_thr': 5     # Scale with max intensity
}

spikes, metadata = wavelet_reconstruction(
    exp.calcium,
    fps=exp.fps,
    params=params
)

# Access reconstruction info ('start_events' holds one list per neuron)
n_events = sum(len(starts) for starts in metadata['start_events'])
print(f"Detected {n_events} spike events")

Threshold-based Method

from driada.experiment import threshold_reconstruction, load_demo_experiment

# Load sample experiment
exp = load_demo_experiment()

# Simple threshold-based detection
params = {
    'threshold_std': 3.0,       # threshold in STDs above the mean
    'smooth_sigma': 1.5,        # Gaussian smoothing sigma (frames)
    'min_spike_interval': 0.1   # minimum interval between spikes (seconds)
}

spikes, metadata = threshold_reconstruction(
    exp.calcium,
    fps=exp.fps,
    params=params
)

Validation and Comparison

from driada.experiment import reconstruct_spikes, load_demo_experiment

# Load sample experiment
exp = load_demo_experiment()

# Compare different methods
methods = ['wavelet', 'threshold']
reconstructions = {}

for method in methods:
    spikes, metadata = reconstruct_spikes(
        exp.calcium,
        method=method,
        fps=exp.fps
    )
    reconstructions[method] = spikes

# Compare reconstructions using correlation
from scipy.stats import pearsonr

# Flatten the binary trains across neurons and correlate them frame by frame
corr, _ = pearsonr(
    reconstructions['wavelet'].data.flatten(),
    reconstructions['threshold'].data.flatten()
)
print(f"Method correlation: {corr:.3f}")
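
Frame-by-frame correlation of binary trains penalizes events that two methods place only a frame or two apart. As a complementary check, the sketch below counts how many events from one method have a counterpart within a tolerance window in the other. The helper matched_fraction and the tol value are hypothetical; the frame indices could come from np.flatnonzero on one neuron's binary train or, for the threshold method, from metadata['spike_times'].

```python
import numpy as np


def matched_fraction(times_a, times_b, tol=2):
    """Fraction of events in times_a with an event in times_b within tol frames.

    Hypothetical helper for illustration; inputs are frame indices.
    """
    times_a = np.asarray(times_a)
    times_b = np.sort(np.asarray(times_b))
    if times_a.size == 0 or times_b.size == 0:
        return 0.0
    # For each event in A, locate its neighbours in the sorted B array
    idx = np.searchsorted(times_b, times_a)
    left = times_b[np.clip(idx - 1, 0, times_b.size - 1)]
    right = times_b[np.clip(idx, 0, times_b.size - 1)]
    nearest = np.minimum(np.abs(times_a - left), np.abs(times_a - right))
    return float(np.mean(nearest <= tol))


frac = matched_fraction([10, 50, 100], [11, 53, 100], tol=2)
```

The measure is asymmetric, so evaluating it in both directions gives a more complete picture of agreement.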

Method Selection Guide

Wavelet Method (default):

  • Pros: Robust to noise, captures event timing well

  • Cons: May miss very small events

  • Best for: Standard calcium imaging data

Threshold Method:

  • Pros: Simple, fast, interpretable

  • Cons: Sensitive to baseline fluctuations

  • Best for: High SNR data with stable baseline

Parameter Guidelines

Wavelet parameters:

  • wavelet: ‘morse’ (default) or ‘morlet’

  • threshold_factor: 2.5-3.5 (higher = fewer false positives)

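Threshold parameters:

  • threshold_std: 2.5-4.0 standard deviations above the mean (default: 2.5)

  • smooth_sigma: Gaussian smoothing sigma in frames (default: 2)

  • min_spike_interval: minimum interval between spikes in seconds (default: 0.1)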

Output Format

All methods return a (spikes, metadata) tuple, where spikes is a MultiTimeSeries object with:

  • Binary spike indicators (0 or 1)

  • Same dimensions as input calcium data

  • Preserved neuron ordering

  • Time alignment with calcium data
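
Because the output is a binary array aligned with the calcium frames, spike times and firing rates follow directly from the frame indices. The sketch below works on a plain NumPy array of shape (n_neurons, n_frames), such as the one exposed through the .data attribute used in the comparison example above; the helper name summarize_binary_spikes is hypothetical.

```python
import numpy as np


def summarize_binary_spikes(binary, fps):
    """Return per-neuron spike frames, spike times (s), and mean rates (Hz)."""
    binary = np.asarray(binary)
    n_frames = binary.shape[1]
    # Frame indices where each neuron's train equals 1
    frames = [np.flatnonzero(row) for row in binary]
    # Convert frame indices to seconds using the sampling rate
    times = [f / fps for f in frames]
    # Events per second over the whole recording
    rates = binary.sum(axis=1) / (n_frames / fps)
    return frames, times, rates


binary = np.zeros((2, 10), dtype=int)
binary[0, [2, 5]] = 1  # neuron 0 spikes at frames 2 and 5; neuron 1 is silent
frames, times, rates = summarize_binary_spikes(binary, fps=10.0)
```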