Wavelet Event Detection
This module provides wavelet-based methods for detecting events in neural time series data.
Functions
- driada.experiment.wavelet_event_detection.extract_wvt_events(traces, wvt_kwargs, show_progress=None, wavelet=None, rel_wvt_times=None, use_gpu=False)
Extract calcium events from multiple traces using wavelet ridge detection.
Detects calcium transient events by finding ridges in the continuous wavelet transform (CWT) of calcium signals. Uses Generalized Morse Wavelets and ridge filtering to identify significant events.
- Parameters:
traces (ndarray) – 2D array of calcium traces (neurons x time).
wvt_kwargs (dict) –
Wavelet detection parameters:
fps : float, frame rate in Hz (default: 20)
beta : float, GMW beta parameter (default: 2)
gamma : float, GMW gamma parameter (default: 3)
sigma : float, Gaussian smoothing sigma in frames (adaptive: 0.4 * fps if not provided)
sigma_sec : float, Gaussian smoothing sigma in seconds (alternative to sigma)
eps : int, minimum spacing between events in frames (adaptive: 0.5 * fps if not provided)
eps_sec : float, minimum spacing in seconds (alternative to eps)
manual_scales : array, wavelet scales to use (adaptive: covers 0.25-2.5s at given fps if not provided)
scale_length_thr : int, minimum ridge length, in number of scales (default: 40)
max_scale_thr : int, max scale index threshold (default: 7)
max_ampl_thr : float, minimum ridge amplitude (default: 0.05)
max_event_dur : float, maximum event duration in seconds (default: 2.5)
max_dur_thr : int, maximum event duration in frames (adaptive: computed from max_event_dur * fps if not provided)
show_progress (bool, optional) – Whether to show progress bar. If None (default), automatically shows progress bar only when processing multiple traces (>1).
wavelet (Wavelet, optional) – Pre-computed wavelet object. If None, a new Generalized Morse Wavelet with beta=2, gamma=3 is created. For batch processing, create it once and reuse it to avoid the ~0.09 s construction overhead per call.
rel_wvt_times (array-like, optional) – Pre-computed time resolutions for each scale in manual_scales. If None, they are computed from the wavelet. For batch processing, compute them once and reuse them to avoid ~0.18 s of overhead per call (for 50 scales).
use_gpu (bool, default=False) – Whether to use GPU acceleration for the CWT computation via ssqueezepy. Requires PyTorch and CuPy to be installed. Note that ridge extraction remains CPU-only (Numba JIT), so speedup is limited to the CWT phase.
- Returns:
st_ev_inds (list of lists) – Start indices for each detected event per neuron.
end_ev_inds (list of lists) – End indices for each detected event per neuron.
all_ridges (list of lists) – Ridge objects containing detailed event information per neuron.
- Raises:
ValueError – If traces is not 2D or empty.
TypeError – If wvt_kwargs is not a dictionary.
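The adaptive defaults listed under wvt_kwargs follow simple rules at a given fps. The sketch below is a hypothetical helper (resolve_wvt_params is not part of driada) that mirrors those documented rules. It assumes that a frame-based value takes precedence over its seconds-based alternative when both are given, and that eps and max_dur_thr are rounded to whole frames; neither point is stated above. It does not reproduce the adaptive choice of manual_scales, which depends on the wavelet.

```python
def resolve_wvt_params(wvt_kwargs):
    """Fill in documented wvt_kwargs defaults (hypothetical helper, not driada API).

    Assumptions not confirmed by the docs: a frame-based value (sigma, eps,
    max_dur_thr) takes precedence over its seconds-based alternative, and
    frame counts for eps and max_dur_thr are rounded to the nearest integer.
    manual_scales is not resolved here because its default depends on the wavelet.
    """
    p = dict(wvt_kwargs)  # copy so the caller's dict is not modified
    fps = p.setdefault("fps", 20)
    # fixed defaults from the parameter list
    for key, value in [("beta", 2), ("gamma", 3), ("scale_length_thr", 40),
                       ("max_scale_thr", 7), ("max_ampl_thr", 0.05),
                       ("max_event_dur", 2.5)]:
        p.setdefault(key, value)

    # sigma: Gaussian smoothing width in frames (adaptive default 0.4 * fps)
    if "sigma" not in p:
        p["sigma"] = p["sigma_sec"] * fps if "sigma_sec" in p else 0.4 * fps
    # eps: minimum spacing between events in frames (adaptive default 0.5 * fps)
    if "eps" not in p:
        eps_sec = p.get("eps_sec", 0.5)
        p["eps"] = int(round(eps_sec * fps))
    # max_dur_thr: maximum event duration in frames, derived from max_event_dur (s)
    if "max_dur_thr" not in p:
        p["max_dur_thr"] = int(round(p["max_event_dur"] * fps))
    return p


params = resolve_wvt_params({"fps": 30})
print(params["sigma"], params["eps"], params["max_dur_thr"])
```

At 20 fps these rules give sigma = 8 and eps = 10 frames, the values used in the examples below.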
Notes
The algorithm:
1. Smooths traces with a Gaussian filter.
2. Computes the CWT using Generalized Morse Wavelets.
3. Detects ridges (connected paths through the scale-time plane).
4. Filters ridges based on length, amplitude, and duration criteria.
5. Returns event start/end indices.
Ridge filtering removes noise and artifacts by requiring events to:
- Persist across multiple scales (scale_length_thr)
- Have sufficient amplitude (max_ampl_thr)
- Have a reasonable duration (max_dur_thr)
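The three ridge-filtering criteria can be expressed as a single predicate. This is a minimal sketch, assuming each ridge has already been summarized by the number of scales it spans, its maximum amplitude, and its duration in frames. RidgeSummary and keep_ridge are hypothetical names, not driada's Ridge API; the inclusive comparisons are an assumption, and max_scale_thr is left out because its exact role is not described here.

```python
from dataclasses import dataclass


@dataclass
class RidgeSummary:
    """Hypothetical per-ridge summary (not driada's Ridge class)."""
    n_scales: int    # number of scales the ridge persists across
    max_ampl: float  # largest wavelet amplitude along the ridge
    duration: int    # event duration in frames


def keep_ridge(r, scale_length_thr=40, max_ampl_thr=0.05, max_dur_thr=50):
    """Return True if the ridge passes the length, amplitude and duration filters.

    max_dur_thr=50 corresponds to max_event_dur=2.5 s at 20 fps.
    """
    return (
        r.n_scales >= scale_length_thr  # persists across enough scales
        and r.max_ampl >= max_ampl_thr  # sufficiently large amplitude
        and r.duration <= max_dur_thr   # not implausibly long
    )


ridges = [
    RidgeSummary(n_scales=45, max_ampl=0.20, duration=30),  # plausible event
    RidgeSummary(n_scales=12, max_ampl=0.30, duration=25),  # spans too few scales
    RidgeSummary(n_scales=48, max_ampl=0.01, duration=20),  # too weak
    RidgeSummary(n_scales=50, max_ampl=0.40, duration=90),  # too long
]
kept = [r for r in ridges if keep_ridge(r)]
print(len(kept))  # 1
```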
- driada.experiment.wavelet_event_detection.get_cwt_ridges(sig, wavelet=None, fps=20, scmin=150, scmax=250, all_wvt_times=None, wvt_scales=None)
Extract ridges from continuous wavelet transform of a signal.
Identifies ridges (connected paths of local maxima across scales) in the wavelet transform, which correspond to transient events in the signal. Ridges are tracked by connecting maxima at adjacent scales within a time window determined by wavelet support.
- Parameters:
sig (array-like) – Input signal to analyze.
wavelet (ssqueezepy.wavelets.Wavelet, optional) – Wavelet object to use. If None, uses default from cwt().
fps (float, default=20) – Sampling rate in Hz.
scmin (int, default=150) – Minimum scale index; scales are processed from scmax down to scmin.
scmax (int, default=250) – Maximum scale index.
all_wvt_times (list, optional) – Pre-computed wavelet time resolutions for each scale. If None, computes them using time_resolution().
wvt_scales (array-like, optional) – Wavelet scales to use. If None, uses ‘log-piecewise’ default.
- Returns:
All detected ridges. Each Ridge object contains indices, amplitudes, scales, and time resolutions along the ridge path.
- Return type:
list of Ridge
- Raises:
ValueError – If fps is not positive, if scmin or scmax is negative, if scmin >= scmax, or if sig is not 1-dimensional.
Notes
The algorithm processes scales from coarse (scmax) to fine (scmin), extending existing ridges when possible and creating new ones for unmatched maxima. Ridges are terminated when no maxima fall within the expected time window at the next scale.
See also
Ridge – Ridge object storing ridge properties
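The coarse-to-fine tracking described in the Notes can be sketched with NumPy. This is a simplified illustration, not driada's get_cwt_ridges: it takes a precomputed magnitude array coefs (scales x time), treats a sample as a maximum when it is strictly greater than both neighbours, and uses a hypothetical per-scale half-window windows (in frames) in place of the wavelet time resolution. A ridge that finds no maximum inside its window at the next scale is closed; unmatched maxima start new ridges.

```python
import numpy as np


def local_maxima(row):
    """Indices where row[i] is strictly greater than both neighbours."""
    return np.where((row[1:-1] > row[:-2]) & (row[1:-1] > row[2:]))[0] + 1


def track_ridges(coefs, windows, scmin, scmax):
    """Link local maxima across scales, from scmax down to scmin (inclusive).

    coefs   : 2D array (n_scales, n_times) of |CWT| values
    windows : per-scale half-width (frames) of the matching window
    Returns a list of ridges, each a list of (scale_idx, time_idx, amplitude).
    """
    active, finished = [], []
    for sc in range(scmax, scmin - 1, -1):
        maxima = list(local_maxima(coefs[sc]))
        still_active = []
        for ridge in active:
            last_t = ridge[-1][1]
            # maxima inside the window around the ridge's last position
            cands = [t for t in maxima if abs(t - last_t) <= windows[sc]]
            if cands:
                t = min(cands, key=lambda c: abs(c - last_t))  # nearest maximum
                maxima.remove(t)  # each maximum extends at most one ridge
                ridge.append((sc, t, coefs[sc, t]))
                still_active.append(ridge)
            else:
                finished.append(ridge)  # no continuation: terminate the ridge
        # unmatched maxima start new ridges at this scale
        still_active.extend([[(sc, t, coefs[sc, t])] for t in maxima])
        active = still_active
    return finished + active


# Toy "CWT": a bump at t=100 present at every scale, plus a spike at one scale
t = np.arange(200)
n_scales = 20
coefs = np.array([np.exp(-((t - 100) / (3 + s)) ** 2) for s in range(n_scales)])
coefs[5, 30] += 0.5
ridges = track_ridges(coefs, windows=np.full(n_scales, 3), scmin=0, scmax=n_scales - 1)
longest = max(ridges, key=len)
print(len(longest), longest[0][1])
```

The bump yields one ridge spanning all 20 scales, while the single-scale spike yields a one-point ridge that a length threshold such as scale_length_thr would discard.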
- class driada.experiment.wavelet_ridge.Ridge(start_index, ampl, start_scale, wvt_time)
Container for wavelet ridge information during ridge detection.
A ridge represents a connected path through wavelet transform scales where significant coefficients are found. This class tracks the evolution of a ridge as it is being constructed.
Notes
When Numba is available, this class is JIT-compiled for performance, with explicit type annotations for its fields.
- __init__(*args, **kwargs)
- Parameters:
args (Any)
kwargs (Any)
- Return type:
None
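The "compile with Numba when available" pattern noted for Ridge can be sketched as below. MiniRidge, its fields, and SPEC are hypothetical (driada's actual field list is not shown here). The sketch applies numba.experimental.jitclass with an explicit type spec when Numba imports cleanly, and otherwise leaves the class as plain Python.

```python
try:
    from numba import float64, int64
    from numba.experimental import jitclass

    # Explicit field types, as jitclass requires (hypothetical fields)
    SPEC = [
        ("start_index", int64),
        ("end_index", int64),
        ("max_ampl", float64),
        ("start_scale", float64),
        ("wvt_time", float64),
        ("length", int64),
    ]

    def maybe_jit(cls):
        return jitclass(SPEC)(cls)

except ImportError:  # Numba unavailable: keep the plain Python class

    def maybe_jit(cls):
        return cls


@maybe_jit
class MiniRidge:
    """Hypothetical, scalar-only ridge container (not driada's Ridge)."""

    def __init__(self, start_index, ampl, start_scale, wvt_time):
        self.start_index = start_index
        self.end_index = start_index
        self.max_ampl = ampl
        self.start_scale = start_scale
        self.wvt_time = wvt_time
        self.length = 1

    def extend(self, index, ampl):
        # record one more scale on the ridge and track the peak amplitude
        self.end_index = index
        self.max_ampl = max(self.max_ampl, ampl)
        self.length += 1


r = MiniRidge(120, 0.3, 4.0, 1.5)
r.extend(121, 0.5)
r.extend(122, 0.4)
print(r.length, r.max_ampl, r.end_index)  # 3 0.5 122
```

Keeping the fields scalar sidesteps typed containers inside a jitclass; a ridge that stores per-scale indices and amplitudes would need typed arrays or lists in its spec.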
Usage Examples
Basic Event Detection
from driada.experiment import extract_wvt_events, load_demo_experiment
# Load sample experiment
exp = load_demo_experiment()
# Extract events from calcium traces
wvt_kwargs = {
'fps': exp.fps,
'sigma': 8, # smoothing parameter (frames)
'eps': 10 # minimum spacing between events (frames)
}
st_ev_inds, end_ev_inds, all_ridges = extract_wvt_events(
exp.calcium.scdata, # scaled data as numpy array
wvt_kwargs
)
# Results are lists of start/end indices per neuron
neuron_0_events = st_ev_inds[0]
print(f"Neuron 0: {len(neuron_0_events)} events detected")
Single Neuron Event Detection
from driada.experiment import events_from_trace, load_demo_experiment
from ssqueezepy import Wavelet
import numpy as np
# Load sample experiment
exp = load_demo_experiment()
# Get single neuron trace (scaled data)
trace = exp.calcium.scdata[0, :]
# Setup wavelet and scales for calcium imaging
wavelet = Wavelet(
("gmw", {"gamma": 3, "beta": 2, "centered_scale": True}),
N=8196
)
manual_scales = np.logspace(2.5, 5.5, 50, base=2) # calcium-appropriate scales
# Precompute time resolutions
from ssqueezepy import time_resolution
rel_wvt_times = [
time_resolution(wavelet, scale=sc, nondim=False, min_decay=200)
for sc in manual_scales
]
# Detect events
all_ridges, st_ev, end_ev = events_from_trace(
trace, wavelet, manual_scales, rel_wvt_times,
fps=exp.fps, sigma=8, eps=10
)
print(f"Found {len(st_ev)} events")
for i, (start, end) in enumerate(zip(st_ev[:3], end_ev[:3])):
    duration_ms = (end - start) / exp.fps * 1000
    print(f"Event {i}: frames {start}-{end} ({duration_ms:.0f} ms)")
Advanced Usage
Custom Event Detection Parameters
# Customize event detection parameters
from driada.experiment import extract_wvt_events, load_demo_experiment
from driada.experiment.wavelet_event_detection import WVT_EVENT_DETECTION_PARAMS
# Load sample experiment
exp = load_demo_experiment()
# Customize parameters for more/less sensitive detection
custom_params = WVT_EVENT_DETECTION_PARAMS.copy()
custom_params.update({
'fps': exp.fps,
'sigma': 4, # less smoothing for sharper events
'eps': 20, # require more spacing between events
'max_ampl_thr': 0.1, # higher threshold, fewer events
'scale_length_thr': 30 # events must persist across 30+ scales
})
# Extract events with custom parameters
st_ev_inds, end_ev_inds, all_ridges = extract_wvt_events(
exp.calcium.scdata,
custom_params
)
# Compare default vs custom
default_params = {'fps': exp.fps}
st_def, _, _ = extract_wvt_events(exp.calcium.scdata, default_params)
print(f"Default params: {sum(len(s) for s in st_def)} total events")
print(f"Custom params: {sum(len(s) for s in st_ev_inds)} total events")
Batch Processing
# Process all neurons
import numpy as np
from driada.experiment import extract_wvt_events, load_demo_experiment
# Load sample experiment
exp = load_demo_experiment()
# Extract events for all neurons
wvt_kwargs = {
'fps': exp.fps,
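'sigma': 8, # smoothing (frames); equals the adaptive default 0.4 * fps at 20 fps
'eps': 10 # event spacing (frames); equals the adaptive default 0.5 * fps at 20 fps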
}
st_ev_inds, end_ev_inds, all_ridges = extract_wvt_events(
exp.calcium.scdata, # pass all neurons at once
wvt_kwargs
)
# Analyze results
n_events_per_neuron = [len(events) for events in st_ev_inds]
print(f"Average events per neuron: {np.mean(n_events_per_neuron):.1f}")
print(f"Total events detected: {sum(n_events_per_neuron)}")
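For analyses that expect events as a binary (neurons x frames) array rather than as start/end lists, the per-neuron indices can be rasterized. A minimal sketch: events_to_raster is a hypothetical helper, end indices are assumed to be inclusive (not stated above), and the toy indices below are made up for illustration.

```python
import numpy as np


def events_to_raster(st_ev_inds, end_ev_inds, n_frames):
    """Build a (neurons x frames) 0/1 array marking detected events.

    Assumption: end indices are inclusive; drop the "+ 1" below if your
    version returns exclusive ends.
    """
    raster = np.zeros((len(st_ev_inds), n_frames), dtype=np.uint8)
    for i, (starts, ends) in enumerate(zip(st_ev_inds, end_ev_inds)):
        for s, e in zip(starts, ends):
            raster[i, int(s):min(int(e) + 1, n_frames)] = 1
    return raster


# Toy indices standing in for the output of extract_wvt_events
st = [[10, 40], [5]]
end = [[15, 44], [9]]
raster = events_to_raster(st, end, n_frames=60)
print(raster.sum(axis=1))  # frames spent in events, per neuron
```

With the demo experiment, n_frames would be exp.calcium.scdata.shape[1].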
Theory
The wavelet transform provides time-frequency localization of events:
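Continuous Wavelet Transform (CWT):
\[ W(a, b) = \frac{1}{\sqrt{a}} \int_{-\infty}^{\infty} x(t)\, \psi^{*}\!\left(\frac{t - b}{a}\right) dt \]
where \(x(t)\) is the signal, \(\psi\) is the wavelet function (\(\psi^{*}\) its complex conjugate), \(a\) is scale, and \(b\) is translation.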
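A direct discretization of the CWT, \(W(a,b) = a^{-1/2}\int x(t)\,\psi^{*}\big((t-b)/a\big)\,dt\), replaces the integral with a sum over samples. The sketch below uses the real-valued Mexican hat (Ricker) wavelet so that \(\psi^{*} = \psi\); it is a NumPy-only illustration, not how driada computes the transform (driada uses Generalized Morse Wavelets via ssqueezepy). The truncation of the kernel at 5 scale units is an assumption chosen for illustration.

```python
import numpy as np


def ricker(t):
    """Mexican hat (Ricker) wavelet: real, even, unit L2 norm."""
    return 2 / (np.sqrt(3) * np.pi ** 0.25) * (1 - t**2) * np.exp(-t**2 / 2)


def cwt_direct(x, scales):
    """Discretized W(a, b) = a**-0.5 * sum_n x[n] * psi((n - b) / a).

    Returns an array of shape (len(scales), len(x)). The kernel is truncated
    at |t| <= 5 (in units of a), where the Ricker wavelet is negligible.
    """
    x = np.asarray(x, dtype=float)
    out = np.empty((len(scales), x.size))
    for i, a in enumerate(scales):
        half = int(np.ceil(5 * a))
        if 2 * half + 1 > x.size:
            raise ValueError("scale too large for signal length")
        k = np.arange(-half, half + 1)
        kernel = ricker(k / a) / np.sqrt(a)
        # psi is even, so convolution equals the correlation in the formula
        out[i] = np.convolve(x, kernel, mode="same")
    return out


n = np.arange(400)
x = np.exp(-0.5 * ((n - 200) / 6.0) ** 2)  # symmetric toy transient at frame 200
W = cwt_direct(x, scales=[2, 4, 8, 16])
print(W.shape, np.abs(W).argmax(axis=1))
```

At every scale the largest coefficient magnitude sits at the transient's centre, which is exactly the scale-aligned chain of maxima that ridge tracking links together.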
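Ridge Detection: Here a ridge is a chain of local maxima of the CWT magnitude linked across adjacent scales. A transient event produces maxima near the same time at many scales and therefore a long ridge, whereas noise tends to produce short ridges that the length threshold discards.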
Event Detection Pipeline:
1. Compute the CWT of the (smoothed) signal.
2. Find ridges in scale-time space.
3. Identify peaks along ridges.
4. Threshold ridges (extract_wvt_events uses ridge length, amplitude, and duration).
5. Extract event times and properties.
Wavelet Selection
Morse Wavelet (default):
- Analytic wavelet with good time-frequency localization
- Flexible shape, tuned by the beta and gamma parameters
- Best for: general-purpose event detection
Morlet Wavelet:
- Gaussian-windowed complex sinusoid
- Good frequency resolution
- Best for: oscillatory events
Mexican Hat (Ricker):
- Negative second derivative of a Gaussian
- Simple, real-valued
- Best for: sharp, transient events
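The contrasts listed above are easiest to see in the frequency domain. The sketch below evaluates textbook frequency-domain forms of the three wavelets with NumPy: the Generalized Morse form uses the common normalization that sets its peak value to 2, while the Morlet (uncorrected) and Mexican hat spectra are given only up to a constant. It is illustrative and independent of ssqueezepy's implementations.

```python
import numpy as np


def morse_hat(omega, beta=2.0, gamma=3.0):
    """Generalized Morse wavelet spectrum (analytic: zero for omega <= 0).

    psi_hat(w) = 2 * (e * gamma / beta)**(beta / gamma) * w**beta * exp(-w**gamma)
    for w > 0; it peaks at w = (beta / gamma)**(1 / gamma) with value 2.
    """
    w = np.asarray(omega, dtype=float)
    out = np.zeros_like(w)
    pos = w > 0
    a = 2 * (np.e * gamma / beta) ** (beta / gamma)
    out[pos] = a * w[pos] ** beta * np.exp(-w[pos] ** gamma)
    return out


def morlet_hat(omega, omega0=6.0):
    """Uncorrected Morlet spectrum up to a constant: a Gaussian centred at omega0."""
    w = np.asarray(omega, dtype=float)
    return np.exp(-((w - omega0) ** 2) / 2)


def mexican_hat_hat(omega):
    """Mexican hat spectrum up to a constant: w**2 * exp(-w**2 / 2) (real, even)."""
    w = np.asarray(omega, dtype=float)
    return w**2 * np.exp(-w**2 / 2)


w = np.linspace(-10, 10, 20001)
m = morse_hat(w)
peak = w[np.argmax(m)]
print(peak, (2 / 3) ** (1 / 3))  # numerical vs analytic peak for beta=2, gamma=3
print(m[w < 0].max())            # analytic: no energy at negative frequencies
print(mexican_hat_hat(-2.0) == mexican_hat_hat(2.0))  # real wavelet: even spectrum
print(morlet_hat(0.0))           # exp(-18): small but nonzero at zero frequency
```

The Morse spectrum is one-sided (analytic), the Mexican hat spectrum is symmetric and vanishes at zero frequency, and the uncorrected Morlet spectrum leaks a tiny amount to zero frequency.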