Spike Reconstruction Methods
This module provides methods for reconstructing spike trains from calcium imaging data.
Functions
- driada.experiment.spike_reconstruction.reconstruct_spikes(calcium, method='wavelet', fps=20.0, params=None, wavelet=None, rel_wvt_times=None, use_gpu=False)
Reconstruct spike trains from calcium signals.
This function serves as a router to the different spike reconstruction methods. All built-in methods operate on the scaled calcium data (values normalized to [0, 1]).
- Parameters:
calcium (MultiTimeSeries) – Calcium imaging data with each component being a neuron. Must have valid scaled data accessible via calcium.scdata.
method (str or callable, optional) – Reconstruction method. Options:
  - ‘wavelet’: wavelet-based detection (default)
  - ‘threshold’: simple threshold-based detection
  - callable: custom function with signature (calcium, fps, params) -> (spikes, metadata)
fps (float, optional) – Sampling rate in frames per second. Must be positive. Default: 20.0.
params (dict, optional) – Method-specific parameters. Contents depend on chosen method. Default: empty dict.
wavelet (Wavelet, optional) – Pre-computed wavelet object for batch optimization (wavelet method only). If None, a new one is created. Reusing one object across calls significantly speeds up batch processing.
rel_wvt_times (array-like, optional) – Pre-computed time resolutions for batch optimization (wavelet method only). If None, they are computed. Use together with the wavelet parameter.
use_gpu (bool, default=False) – Whether to use GPU acceleration for wavelet transform computation. Requires PyTorch and CuPy. Ridge extraction remains CPU-only.
- Return type:
Tuple[MultiTimeSeries, Dict[str, Any]]
- Returns:
spikes (MultiTimeSeries) – Reconstructed spike trains as binary (discrete) time series. Each component represents one neuron.
metadata (dict) – Reconstruction metadata, including:
  - ‘method’: str - method used
  - ‘parameters’: dict - parameters used
  - method-specific fields (see the individual method docs)
- Raises:
ValueError – If method is an unknown string, or if fps is not positive.
AttributeError – If calcium lacks required attributes (e.g., scdata).
TypeError – If callable method has wrong signature.
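The routing behaviour described above (dispatch by name or to a callable, plus the ValueError and TypeError cases) can be sketched in plain Python. This is a minimal sketch, not DRIADA's implementation: `route_method`, `_wavelet_stub` and `_threshold_stub` are hypothetical, they operate on a NumPy array instead of a MultiTimeSeries, and checking a custom callable with `inspect.signature(...).bind` is an assumed way of detecting a wrong signature.

```python
import inspect
from typing import Any, Callable, Dict, Optional, Union

import numpy as np

# Hypothetical stand-ins for the built-in methods; each returns (spikes, metadata).
def _wavelet_stub(scdata: np.ndarray, fps: float, params: Dict[str, Any]):
    return np.zeros_like(scdata, dtype=int), {"method": "wavelet", "parameters": dict(params, fps=fps)}

def _threshold_stub(scdata: np.ndarray, fps: float, params: Dict[str, Any]):
    return np.zeros_like(scdata, dtype=int), {"method": "threshold", "parameters": dict(params, fps=fps)}

_BUILTIN = {"wavelet": _wavelet_stub, "threshold": _threshold_stub}

def route_method(scdata: np.ndarray, method: Union[str, Callable] = "wavelet",
                 fps: float = 20.0, params: Optional[Dict[str, Any]] = None):
    """Dispatch to a built-in method by name, or to a user-supplied callable."""
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")
    params = {} if params is None else dict(params)
    if isinstance(method, str):
        if method not in _BUILTIN:
            raise ValueError(f"Unknown method {method!r}; choose from {sorted(_BUILTIN)}")
        return _BUILTIN[method](scdata, fps, params)
    if callable(method):
        # The custom method must accept three positional arguments (calcium, fps, params).
        try:
            inspect.signature(method).bind(scdata, fps, params)
        except TypeError as exc:
            raise TypeError(f"Custom method has the wrong signature: {exc}") from exc
        return method(scdata, fps, params)
    # Assumption: anything that is neither a string nor a callable is a type error.
    raise TypeError(f"method must be a string or a callable, got {type(method).__name__}")
```

Custom callables are passed through unchanged, which mirrors the documented example where the callable's own metadata (such as 'custom_info') is returned as-is.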
Examples
>>> # Create example calcium data
>>> import numpy as np
>>> from driada.information import TimeSeries, MultiTimeSeries
>>> n_neurons, n_frames = 5, 1000
>>> raw_data = np.random.rand(n_neurons, n_frames)
>>> calcium_ts_list = [TimeSeries(raw_data[i], discrete=False) for i in range(n_neurons)]
>>> calcium_data = MultiTimeSeries(calcium_ts_list)
>>>
>>> # Using wavelet method (default)
>>> spikes, meta = reconstruct_spikes(calcium_data, fps=30.0)
>>> meta['method']
'wavelet'
>>> spikes.n_dim == n_neurons
True

>>> # Using threshold method with custom parameters
>>> params = {'threshold_std': 3.0, 'smooth_sigma': 1.5}
>>> spikes, meta = reconstruct_spikes(calcium_data, 'threshold', 30.0, params)
>>> meta['method']
'threshold'
>>> meta['parameters']['threshold_std']
3.0

>>> # Using custom reconstruction function
>>> def custom_method(calcium, fps, params):
...     # Simple mock implementation
...     n_neurons, n_frames = calcium.scdata.shape
...     spike_data = np.zeros((n_neurons, n_frames))
...     # Add a few spikes to avoid zero columns error
...     for i in range(n_neurons):
...         spike_data[i, i*10:(i+1)*10] = 1
...     spike_ts = [TimeSeries(spike_data[i], discrete=True) for i in range(n_neurons)]
...     return MultiTimeSeries(spike_ts, allow_zero_columns=True), {'custom_info': 'test'}
>>> spikes, meta = reconstruct_spikes(calcium_data, custom_method, 30.0)
>>> meta['custom_info']
'test'
Notes
All built-in methods use the scaled calcium data (calcium.scdata) which is normalized to [0, 1]. This ensures consistent behavior across different calcium indicator types and experimental conditions.
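The Notes above state that scdata is normalized to [0, 1]. Per-neuron min-max scaling is one common way to obtain such data; the sketch below assumes that scheme purely for illustration (DRIADA may compute scdata differently), and `minmax_scale_rows` is a hypothetical helper.

```python
import numpy as np

def minmax_scale_rows(data: np.ndarray) -> np.ndarray:
    """Scale each row (neuron) of a (n_neurons, n_frames) array to [0, 1].

    Assumption: min-max scaling per neuron; this is an illustration, not DRIADA's scdata code.
    """
    data = np.asarray(data, dtype=float)
    lo = data.min(axis=1, keepdims=True)
    hi = data.max(axis=1, keepdims=True)
    span = hi - lo
    # Constant traces have zero span; map them to 0 instead of dividing by zero.
    safe_span = np.where(span > 0, span, 1.0)
    return np.where(span > 0, (data - lo) / safe_span, 0.0)

# Two neurons recorded on very different raw scales end up on the same [0, 1] range.
raw = np.array([[100.0, 150.0, 200.0],
                [0.1, 0.3, 0.2]])
scaled = minmax_scale_rows(raw)
```

Because each neuron is scaled independently, a threshold expressed on scaled data refers to the same relative level for every trace, regardless of the raw amplitude of the indicator.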
- driada.experiment.spike_reconstruction.wavelet_reconstruction(calcium, fps, params, wavelet=None, rel_wvt_times=None, use_gpu=False)
Wavelet-based spike reconstruction.
Uses continuous wavelet transform to detect calcium transients. The method operates on scaled calcium data (normalized to [0, 1]).
- Parameters:
calcium (MultiTimeSeries) – Calcium imaging signals. Must have scdata attribute.
fps (float) – Sampling rate in frames per second. Overrides default fps in parameters.
params (dict) – Parameters that update the WVT_EVENT_DETECTION_PARAMS defaults:
  - ‘sigma’: int - smoothing parameter for peak detection (frames)
  - ‘eps’: int - minimum spacing between consecutive events (frames)
  - ‘scale_length_thr’: int - minimum number of scales on which a ridge must be present
  - ‘max_scale_thr’: int - index of the scale with maximum ridge intensity
  - ‘max_ampl_thr’: float - maximum ridge intensity threshold
  - ‘max_dur_thr’: int - maximum event duration threshold
See WVT_EVENT_DETECTION_PARAMS for defaults.
wavelet (Wavelet, optional) – Pre-computed wavelet object for batch optimization. If None, a new one is created.
rel_wvt_times (array-like, optional) – Pre-computed time resolutions for batch optimization. If None, they are computed.
use_gpu (bool, default False) – Whether to use GPU acceleration for wavelet computation if available.
- Return type:
Tuple[MultiTimeSeries, Dict[str, Any]]
- Returns:
spikes (MultiTimeSeries) – Binary (discrete) spike trains, created with allow_zero_columns=True so that neurons without detected events are permitted.
metadata (dict) – Contains:
  - ‘method’: ‘wavelet’
  - ‘parameters’: dict - all parameters used
  - ‘start_events’: list - event start indices per neuron
  - ‘end_events’: list - event end indices per neuron
  - ‘ridges’: list - ridge information per neuron
- Raises:
AttributeError – If calcium lacks scdata attribute.
ValueError – If calcium data is empty or has an invalid shape.
Notes
Default parameters are defined in WVT_EVENT_DETECTION_PARAMS. The fps parameter always overrides the default fps value.
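The precedence described here (defaults, then user params, then the fps argument) can be sketched as a dictionary merge. The keys follow the params list above, but `EXAMPLE_DEFAULTS` holds None placeholders rather than the real contents of WVT_EVENT_DETECTION_PARAMS, and `merge_wavelet_params` is a hypothetical helper.

```python
from typing import Any, Dict, Optional

# Keys mirror the documented params; None marks a placeholder for the real default,
# which lives in WVT_EVENT_DETECTION_PARAMS.
EXAMPLE_DEFAULTS: Dict[str, Any] = {
    "fps": None,
    "sigma": None,
    "eps": None,
    "scale_length_thr": None,
    "max_scale_thr": None,
    "max_ampl_thr": None,
    "max_dur_thr": None,
}

def merge_wavelet_params(fps: float, params: Optional[Dict[str, Any]] = None,
                         defaults: Dict[str, Any] = EXAMPLE_DEFAULTS) -> Dict[str, Any]:
    """Precedence: defaults < user params < explicit fps argument."""
    merged = dict(defaults)        # copy so the shared defaults are never mutated
    merged.update(params or {})    # user-supplied keys replace defaults
    merged["fps"] = fps            # fps argument always wins, even over params['fps']
    return merged

merged = merge_wavelet_params(30.0, {"sigma": 2, "fps": 5.0})
```

Copying the defaults before updating keeps repeated calls independent, which matters when the same function is called once per recording in a batch.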
For batch processing, pre-compute wavelet and rel_wvt_times once and reuse across multiple calls for significant speedup (8-10x faster).
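The speedup from reusing wavelet and rel_wvt_times comes from doing the expensive setup once per batch instead of once per neuron. The NumPy sketch below shows the same pattern with a hand-built bank of Ricker (Mexican hat) kernels. It is a swapped-in illustration, not DRIADA's wavelet or its ridge detection, and `ricker`, `build_kernel_bank` and `transform` are hypothetical helpers.

```python
import numpy as np

def ricker(points: int, width: float) -> np.ndarray:
    """Ricker (Mexican hat) wavelet sampled at `points` positions."""
    t = np.arange(points) - (points - 1) / 2.0
    a = 2.0 / (np.sqrt(3.0 * width) * np.pi ** 0.25)
    return a * (1.0 - (t / width) ** 2) * np.exp(-0.5 * (t / width) ** 2)

def build_kernel_bank(widths) -> list:
    """Expensive setup, done once per batch: one kernel per scale."""
    return [ricker(int(10 * w) + 1, float(w)) for w in widths]

def transform(trace: np.ndarray, bank: list) -> np.ndarray:
    """Cheap per-trace step: convolve one trace with every precomputed kernel.

    Assumes each kernel is shorter than the trace, so mode='same' keeps the trace length.
    """
    return np.vstack([np.convolve(trace, k, mode="same") for k in bank])

# Build the bank once, then reuse it for every neuron in the batch.
bank = build_kernel_bank(np.arange(1, 9))
rng = np.random.default_rng(0)
traces = rng.random((5, 500))
coeffs = [transform(tr, bank) for tr in traces]  # one (n_scales, n_frames) array per neuron
```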
- driada.experiment.spike_reconstruction.threshold_reconstruction(calcium, fps, params)
Simple threshold-based spike reconstruction.
This method detects spikes when the derivative of the calcium signal exceeds a threshold. Operates on scaled calcium data (normalized to [0, 1]).
- Parameters:
calcium (MultiTimeSeries or TimeSeries) – Calcium signals. Must have scdata attribute. If TimeSeries, will be treated as single neuron.
fps (float) – Sampling rate in frames per second. Must be positive.
params (dict) –
Parameters including:
threshold_std : float, number of STDs above mean for detection. Must be positive. Default: 2.5.
smooth_sigma : float, gaussian smoothing sigma in frames. Must be non-negative. Default: 2.
min_spike_interval : float, minimum interval between spikes in seconds. Must be non-negative. Default: 0.1.
- Return type:
Tuple[MultiTimeSeries, Dict[str, Any]]
- Returns:
spikes (MultiTimeSeries) – Binary (discrete) spike trains, created with allow_zero_columns=True so that neurons without detected spikes are permitted.
metadata (dict) – Contains:
  - ‘method’: ‘threshold’
  - ‘parameters’: dict with all parameters used, including fps
  - ‘spike_times’: list of arrays - frame indices of detected spikes per neuron
- Raises:
AttributeError – If calcium lacks scdata attribute.
ValueError – If calcium data is empty or has an invalid shape; if any parameter is out of its valid range; or if min_spike_interval * fps < 0.5 (the minimum distance in frames would round to zero).
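The parameter checks listed under Raises can be sketched as follows. The defaults come from the params list above; `validate_threshold_params` is a hypothetical helper, and converting min_spike_interval to frames with round() is an assumption consistent with the < 0.5 rule.

```python
from typing import Any, Dict, Optional

# Defaults as documented for threshold_reconstruction.
THRESHOLD_DEFAULTS = {"threshold_std": 2.5, "smooth_sigma": 2, "min_spike_interval": 0.1}

def validate_threshold_params(fps: float, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Merge user params over the defaults and enforce the documented ranges."""
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")
    p = {**THRESHOLD_DEFAULTS, **(params or {})}
    if p["threshold_std"] <= 0:
        raise ValueError("threshold_std must be positive")
    if p["smooth_sigma"] < 0:
        raise ValueError("smooth_sigma must be non-negative")
    if p["min_spike_interval"] < 0:
        raise ValueError("min_spike_interval must be non-negative")
    interval_frames = p["min_spike_interval"] * fps
    if interval_frames < 0.5:
        raise ValueError(
            f"min_spike_interval * fps = {interval_frames:.3g} < 0.5; "
            "the minimum spike distance would be zero frames"
        )
    p["fps"] = fps
    # Assumption: the interval is converted to whole frames by rounding.
    p["min_distance_frames"] = int(round(interval_frames))
    return p
```

With the defaults, 0.1 s at 20 fps is 2 frames and is accepted, while 0.1 s at 4 fps is 0.4 frames and is rejected.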
Notes
The derivative is computed with np.diff and zero-padded at the start so that it has the same length as the signal. As a result, no spike can be detected in the first frame.
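The steps above can be combined into a single-neuron sketch using SciPy: smooth, take a zero-padded np.diff, threshold the derivative at mean + threshold_std × std, and enforce the minimum spike distance with scipy.signal.find_peaks. This is an illustrative reconstruction under those assumptions; whether DRIADA smooths before differentiating, computes the statistics on the derivative, or selects peaks this way is not stated above, and `threshold_spikes_1d` is a hypothetical name.

```python
import numpy as np
from scipy.ndimage import gaussian_filter1d
from scipy.signal import find_peaks

def threshold_spikes_1d(trace, fps, threshold_std=2.5, smooth_sigma=2.0, min_spike_interval=0.1):
    """Return (binary spike train, spike frame indices) for one scaled calcium trace."""
    x = np.asarray(trace, dtype=float)
    if smooth_sigma > 0:
        x = gaussian_filter1d(x, smooth_sigma)      # Gaussian smoothing, sigma in frames
    # Zero-pad at the start so the derivative has the same length as the trace;
    # frame 0 therefore always has derivative 0.
    deriv = np.concatenate(([0.0], np.diff(x)))
    thr = deriv.mean() + threshold_std * deriv.std()
    # Minimum distance in frames; the max() guard keeps find_peaks' distance >= 1.
    min_dist = max(1, int(round(min_spike_interval * fps)))
    peaks, _ = find_peaks(deriv, height=thr, distance=min_dist)
    spikes = np.zeros(x.size, dtype=int)
    spikes[peaks] = 1
    return spikes, peaks

# Synthetic trace with two step-like rises at frames 100 and 250.
trace = np.zeros(400)
trace[100:] += 0.5
trace[250:] += 0.5
spikes, idx = threshold_spikes_1d(trace, fps=20.0)
```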
Usage Examples
Basic Spike Reconstruction
from driada.experiment import reconstruct_spikes, load_demo_experiment
# Load sample experiment
exp = load_demo_experiment()
# Reconstruct spikes using the default method (wavelet);
# reconstruct_spikes returns a (spikes, metadata) tuple
spikes, metadata = reconstruct_spikes(
    exp.calcium,
    fps=exp.fps,
    method='wavelet'
)
# Add to experiment
exp.spikes = spikes
Wavelet-based Reconstruction
from driada.experiment import wavelet_reconstruction, load_demo_experiment
# Load sample experiment
exp = load_demo_experiment()
# Wavelet reconstruction with custom parameters
params = {
    'sigma': 2,             # Smoothing parameter
    'eps': 3,               # Min spacing between events
    'scale_length_thr': 3,  # Min scales for ridge
    'max_scale_thr': 5      # Scale with max intensity
}
spikes, metadata = wavelet_reconstruction(
    exp.calcium,
    fps=exp.fps,
    params=params
)
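# Access reconstruction info: start_events holds one array of event starts per neuron
n_events = sum(len(starts) for starts in metadata['start_events'])
print(f"Detected {n_events} spike events across {len(metadata['start_events'])} neurons")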
Threshold-based Method
from driada.experiment import threshold_reconstruction, load_demo_experiment
# Load sample experiment
exp = load_demo_experiment()
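# Simple threshold-based detection, using the documented parameter names
params = {
    'threshold_std': 3.0,       # threshold in standard deviations above the mean
    'smooth_sigma': 2,          # Gaussian smoothing sigma, in frames
    'min_spike_interval': 0.1   # minimum interval between spikes, in seconds
}
spikes, metadata = threshold_reconstruction(
    exp.calcium,
    fps=exp.fps,
    params=params
)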
Validation and Comparison
from driada.experiment import reconstruct_spikes, load_demo_experiment
# Load sample experiment
exp = load_demo_experiment()
# Compare different methods
methods = ['wavelet', 'threshold']
reconstructions = {}
for method in methods:
    spikes, metadata = reconstruct_spikes(
        exp.calcium,
        method=method,
        fps=exp.fps
    )
    reconstructions[method] = spikes
# Compare reconstructions using correlation
from scipy.stats import pearsonr
corr, _ = pearsonr(
reconstructions['wavelet'].data.flatten(),
reconstructions['threshold'].data.flatten()
)
print(f"Method correlation: {corr:.3f}")
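Flattening all neurons into one vector mixes neurons together, and for sparse binary trains the correlation is dominated by frames where neither method detects anything. A per-neuron comparison with a small timing tolerance can be more informative. The sketch below is one such approach on plain NumPy arrays of shape (n_neurons, n_frames), such as the `.data` arrays used above; `tolerant_agreement` and its ±tol matching rule are assumptions for illustration, not a DRIADA function.

```python
import numpy as np

def tolerant_agreement(a: np.ndarray, b: np.ndarray, tol: int = 2) -> np.ndarray:
    """Per-neuron fraction of spikes in `a` that have a spike in `b` within ±tol frames.

    Returns NaN for neurons where `a` has no spikes. Assumes n_frames >= 2 * tol + 1.
    """
    a = np.asarray(a) > 0
    b = np.asarray(b) > 0
    if a.shape != b.shape:
        raise ValueError("a and b must have the same (n_neurons, n_frames) shape")
    # Dilate b by ±tol frames: a centred moving-window sum is > 0 near any spike.
    kernel = np.ones(2 * tol + 1)
    b_dilated = np.array([np.convolve(row, kernel, mode="same") > 0 for row in b.astype(float)])
    scores = np.full(a.shape[0], np.nan)
    for i in range(a.shape[0]):
        n = a[i].sum()
        if n:
            scores[i] = (a[i] & b_dilated[i]).sum() / n
    return scores

# Toy example: neuron 0 has one matched and one unmatched spike; neuron 1 has none.
wav = np.zeros((2, 20), dtype=int)
thr = np.zeros((2, 20), dtype=int)
wav[0, [3, 10]] = 1
thr[0, [4, 17]] = 1
scores = tolerant_agreement(wav, thr, tol=2)
```

The score is asymmetric: tolerant_agreement(reconstructions['wavelet'].data, reconstructions['threshold'].data) measures how many wavelet events the threshold method recovers, and swapping the arguments gives the reverse direction.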
Method Selection Guide
Wavelet Method (default):
- Pros: robust to noise, captures event timing well
- Cons: may miss very small events
- Best for: standard calcium imaging data
Threshold Method:
- Pros: simple, fast, interpretable
- Cons: sensitive to baseline fluctuations
- Best for: high-SNR data with a stable baseline
Parameter Guidelines
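Wavelet parameters:
- Start from the defaults in WVT_EVENT_DETECTION_PARAMS and override only the keys you need (sigma, eps, scale_length_thr, max_scale_thr, max_ampl_thr, max_dur_thr).
- scale_length_thr is the minimum number of scales on which a ridge must be present, so raising it makes detection stricter.
- The wavelet argument is a pre-computed Wavelet object used for batch optimization, not a method-name string.
Threshold parameters:
- threshold_std: 2.5 (default) to 4.0 standard deviations (higher = fewer false positives)
- smooth_sigma: Gaussian smoothing sigma in frames (default: 2)
- min_spike_interval: minimum interval between spikes in seconds (default: 0.1); min_spike_interval * fps must be at least 0.5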
Output Format
All methods return a (spikes, metadata) tuple. The spikes element is a MultiTimeSeries with:
- Binary spike indicators (0 or 1)
- Same dimensions as the input calcium data
- Preserved neuron ordering
- Time alignment with the calcium data
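The link between per-neuron frame indices (for example the 'spike_times' metadata of the threshold method) and this binary output can be sketched as below. `indices_to_binary` and `indices_to_seconds` are hypothetical helpers working on NumPy arrays rather than MultiTimeSeries objects, and dividing by fps to get seconds assumes frame 0 is at time 0.

```python
from typing import List, Sequence

import numpy as np

def indices_to_binary(spike_idx: Sequence, n_frames: int) -> np.ndarray:
    """Build a (n_neurons, n_frames) 0/1 array, keeping neuron order and frame alignment."""
    out = np.zeros((len(spike_idx), n_frames), dtype=np.int8)
    for i, idx in enumerate(spike_idx):
        out[i, np.asarray(idx, dtype=int)] = 1   # an empty index array leaves an all-zero row
    return out

def indices_to_seconds(spike_idx: Sequence, fps: float) -> List[np.ndarray]:
    """Convert frame indices to times in seconds, assuming frame 0 is t = 0."""
    return [np.asarray(idx, dtype=float) / fps for idx in spike_idx]

binary = indices_to_binary([np.array([1, 4]), np.array([], dtype=int), [0]], n_frames=6)
seconds = indices_to_seconds([[0, 20]], fps=20.0)
```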