# Source code for driada.experiment.neuron

import numpy as np
import logging
from scipy.stats import median_abs_deviation
from scipy.optimize import minimize
from scipy.signal import savgol_filter

from ..information.info_base import TimeSeries
from ..utils.data import check_positive, check_nonnegative
from .wavelet_event_detection import extract_wvt_events, events_to_ts_array

# Import refactored modules
from .calcium_kinetics import (
    spike_form,
    get_restored_calcium,
    ca_mse_error,
    DEFAULT_T_RISE,
    DEFAULT_T_OFF,
)
from .signal_preprocessing import (
    calcium_preprocessing,
    MAD_SCALE_FACTOR,
)
from .event_detection import (
    SimpleEvent,
    extract_event_amplitudes,
    deconvolve_given_event_times,
    compute_kernel_peak_offset,
    estimate_onset_times,
    amplitudes_to_point_events,
    _calculate_event_r2,
    DEFAULT_FPS,
    DEFAULT_MIN_BEHAVIOUR_TIME,
    CA_SHIFT_N_TOFF,
    BASELINE_WINDOW_SEC,
    MAX_FRAMES_FORWARD_SEC,
    MAX_FRAMES_BACK_SEC,
    MIN_DECAY_FRAMES_SEC,
    MIN_RISE_FRAMES_SEC,
    MIN_VALID_POINTS_SEC,
    SAVGOL_WINDOW_SEC,
    BASELINE_OFFSET_SEC,
)


class Neuron:
    """Container for one neuron's calcium trace and (optional) spike train.

    Holds the preprocessed calcium signal, an optional binary spike train,
    kinetics parameters, and cached reconstruction/quality metrics. See
    ``__init__`` for the complete list of constructor arguments.

    Parameters
    ----------
    cell_id : int or str
        Unique identifier for the neuron.
    ca : array-like
        Calcium imaging time series (expected to be 1D). Must not be empty.
    sp : array-like or None
        Spike train (binary array where 1 indicates a spike). Pass None when
        no spike data is available.
    default_t_rise : float, default=DEFAULT_T_RISE
        Rise time constant in seconds. Must be positive.
    default_t_off : float, default=DEFAULT_T_OFF
        Decay time constant in seconds. Must be positive.
    fps : float, default=DEFAULT_FPS
        Sampling rate in frames per second. Must be positive.
    fit_individual_t_off : bool, default=False
        Deprecated. Fit the decay time for this neuron by optimization.

    Attributes
    ----------
    cell_id : int or str
        The neuron identifier.
    t_rise : float or None
        Rise time constant in frames; None until it has been set.
    t_off : float or None
        Decay time constant in frames; None until it has been set.
    fps : float
        Sampling rate in frames per second.
    ca : TimeSeries
        Preprocessed calcium time series.
    sp : TimeSeries or None
        Spike train time series (None if no spikes are available).

    Notes
    -----
    The class assumes spike data is binary (0 or 1 values).
    Non-binary spike data may produce incorrect results in spike counting.
    """

    # Static-method aliases kept for backward compatibility: they preserve the
    # ``Neuron.method(...)`` call style while the implementations live in the
    # refactored modules imported at the top of this file.
    spike_form = staticmethod(spike_form)
    get_restored_calcium = staticmethod(get_restored_calcium)
    ca_mse_error = staticmethod(ca_mse_error)
    calcium_preprocessing = staticmethod(calcium_preprocessing)
    extract_event_amplitudes = staticmethod(extract_event_amplitudes)
    deconvolve_given_event_times = staticmethod(deconvolve_given_event_times)
    compute_kernel_peak_offset = staticmethod(compute_kernel_peak_offset)
    estimate_onset_times = staticmethod(estimate_onset_times)
    amplitudes_to_point_events = staticmethod(amplitudes_to_point_events)
    _calculate_event_r2 = staticmethod(_calculate_event_r2)

    def _get_t_rise(self):
        """Rise time constant in frames, or None if it has not been set.

        Assigning to the property validates the value and clears the cached
        metrics.
        """
        return self._t_rise

    def _set_t_rise(self, value):
        """Store a new rise time and invalidate cached metrics.

        Parameters
        ----------
        value : float or None
            Rise time in frames. None is stored as-is; any other value is
            validated with ``check_positive``.
        """
        needs_validation = value is not None
        if needs_validation:
            check_positive(value=value)
        self._t_rise = value
        # Cached reconstructions/metrics depend on the kinetics.
        self._clear_cached_metrics()

    t_rise = property(_get_t_rise, _set_t_rise)

    def _get_t_off(self):
        """Decay time constant in frames, or None if it has not been set.

        Assigning to the property validates the value and clears the cached
        metrics.
        """
        return self._t_off

    def _set_t_off(self, value):
        """Store a new decay time and invalidate cached metrics.

        Parameters
        ----------
        value : float or None
            Decay time in frames. None is stored as-is; any other value is
            validated with ``check_positive`` and must exceed the current
            rise time when one is set.

        Raises
        ------
        ValueError
            If ``value`` is not greater than the currently stored rise time.
        """
        if value is not None:
            check_positive(value=value)
            rise_is_known = self._t_rise is not None
            if rise_is_known and value <= self._t_rise:
                raise ValueError(
                    f"t_off ({value:.2f}) must be greater than t_rise ({self._t_rise:.2f})"
                )
        self._t_off = value
        # Cached reconstructions/metrics depend on the kinetics.
        self._clear_cached_metrics()

    t_off = property(_get_t_off, _set_t_off)
def __init__(
    self,
    cell_id,
    ca,
    sp,
    default_t_rise=DEFAULT_T_RISE,
    default_t_off=DEFAULT_T_OFF,
    fps=DEFAULT_FPS,
    fit_individual_t_off=False,
    optimize_kinetics=False,
    asp=None,
    wvt_ridges=None,
    seed=None,
    reconstructed=None,
    metrics=None,
):
    """Initialize Neuron object with calcium and spike data.

    Parameters
    ----------
    cell_id : str or int
        Unique identifier for this neuron.
    ca : array-like
        Calcium fluorescence signal (expected to be 1D). Must not be empty.
    sp : array-like or None
        Binary spike train. If provided, must have the same length as ca.
    default_t_rise : float or None, optional
        Default rise time constant in seconds. None means DEFAULT_T_RISE.
    default_t_off : float or None, optional
        Default decay time constant in seconds. None means DEFAULT_T_OFF.
    fps : float or None, optional
        Sampling rate in Hz. Must be positive. None means DEFAULT_FPS.
    fit_individual_t_off : bool, optional
        **DEPRECATED**: Use `optimize_kinetics` instead. If True and spike
        data (sp or asp) is available, the decay time is fitted with
        ``_fit_t_off()``; otherwise the default decay time is used.
        Default is False.
    optimize_kinetics : bool or str, optional
        If True, ``get_kinetics('direct', fps)`` is called; a string is
        forwarded to ``get_kinetics`` as the method name. If False, the
        default parameters are used. Default is False.
    asp : array-like or None, optional
        Pre-computed amplitude spikes (from prior reconstruction). Must have
        the same length as ca.
    wvt_ridges : list or None, optional
        Pre-computed wavelet ridges (from prior reconstruction). Stored
        as-is on ``self.wvt_ridges``.
    seed : int, optional
        Random seed forwarded to ``calcium_preprocessing``.
    reconstructed : array-like or None, optional
        Pre-computed reconstructed calcium signal. If provided, must have
        the same length as ca; it populates the internal reconstructed trace
        and marks the neuron as already reconstructed.
    metrics : dict or None, optional
        Pre-computed quality metrics dictionary. Recognized keys include
        't_rise' and 't_off' (kinetics in seconds, converted to frames),
        'r2_score' and 'snr_recon' (reconstruction quality),
        'event_r2_score' and 'event_snr' (event-level quality), and
        'noise_level' (noise amplitude). Entries whose value is None are
        skipped. Unrecognized keys are stored but otherwise ignored.

    Raises
    ------
    ValueError
        If ca is empty, if sp/asp/reconstructed lengths do not match ca, or
        if both fit_individual_t_off and optimize_kinetics are requested.
        Non-positive fps/default_t_rise/default_t_off are rejected by
        ``check_positive``.

    Notes
    -----
    The shuffle mask excludes CA_SHIFT_N_TOFF * t_off frames from each end
    to prevent near-zero circular shifts from contaminating the null
    distribution.

    **New workflow** (recommended):

    >>> neuron = Neuron(cell_id=0, ca=calcium, sp=None, fps=20)
    >>> neuron.reconstruct_spikes(method='wavelet')
    >>> neuron.get_kinetics()  # Optimize after reconstruction

    **Or pass pre-computed data**:

    >>> neuron = Neuron(cell_id=0, ca=calcium, sp=None, fps=20,
    ...                 asp=amp_spikes, wvt_ridges=ridges,
    ...                 optimize_kinetics=True)
    """
    # Resolve None to the module defaults FIRST, so that validation,
    # self.fps and the seconds->frames conversions below all see the same
    # effective values (previously these fallbacks ran too late to matter).
    if fps is None:
        fps = DEFAULT_FPS
    if default_t_rise is None:
        default_t_rise = DEFAULT_T_RISE
    if default_t_off is None:
        default_t_off = DEFAULT_T_OFF

    ca = np.asarray(ca)
    if ca.size == 0:
        raise ValueError("Calcium signal cannot be empty")
    check_positive(fps=fps, default_t_rise=default_t_rise, default_t_off=default_t_off)

    self.cell_id = cell_id
    self.ca = TimeSeries(
        Neuron.calcium_preprocessing(ca, seed=seed),
        discrete=False,
        name=f"neuron_{self.cell_id}_ca",
    )

    # Local import, as in the original: sklearn is only needed here.
    from sklearn.preprocessing import MinMaxScaler

    # Scaler fitted on the preprocessed calcium; reused later to bring
    # reconstructions onto the same scale.
    self.ca_scaler = MinMaxScaler()
    self.ca_scaler.fit(self.ca.data.reshape(-1, 1))

    if sp is None:
        self.sp = None
    else:
        sp = np.asarray(sp)
        if len(sp) != len(ca):
            raise ValueError(
                f"Spike train length {len(sp)} must match calcium length {len(ca)}"
            )
        self.sp = TimeSeries(sp.astype(int), discrete=True, name=f"neuron_{self.cell_id}_sp")

    self.n_frames = len(self.ca.data)
    self.fps = fps
    self.sp_count = np.sum(self.sp.data) if self.sp is not None else 0
    self.events = None

    if asp is not None:
        asp = np.asarray(asp)
        if len(asp) != len(ca):
            raise ValueError(f"asp length {len(asp)} must match calcium length {len(ca)}")
        self.asp = TimeSeries(asp, discrete=False, name=f"neuron_{self.cell_id}_asp")
    else:
        self.asp = None

    self.wvt_ridges = wvt_ridges
    self.metrics = metrics

    # Kinetics and cached metrics start empty; some may be pre-populated
    # from `metrics` / `reconstructed` below.
    self._t_rise = None
    self._t_off = None
    self.noise_ampl = None
    self.mad = None
    self.snr = None
    self.wavelet_snr = None
    self._kinetics_info = None
    self.reconstruction_r2 = None
    self.snr_reconstruction = None
    self.mae = None
    self.event_count = None
    self.event_r2 = None
    self.event_snr = None
    self._reconstructed = None
    self._reconstructed_scaled = None
    self._has_reconstructed = False

    if metrics is not None:
        # Kinetics are given in seconds; the private fields hold frames.
        # They are written directly (not through the properties) so that no
        # cache clearing is triggered while the cache is being filled.
        t_rise_sec = metrics.get("t_rise")
        if t_rise_sec is not None and t_rise_sec > 0:
            self._t_rise = t_rise_sec * fps

        t_off_sec = metrics.get("t_off")
        if t_off_sec is not None:
            # A missing OR explicitly-None t_rise compares as 0; comparing
            # against None would raise TypeError.
            rise_floor_sec = t_rise_sec if t_rise_sec is not None else 0
            if t_off_sec > rise_floor_sec and t_off_sec > 0:
                self._t_off = t_off_sec * fps

        # Quality metrics: metrics key -> cached attribute.
        cached_metric_attrs = (
            ("r2_score", "reconstruction_r2"),
            ("snr_recon", "snr_reconstruction"),
            ("event_r2_score", "event_r2"),
            ("event_snr", "event_snr"),
            ("noise_level", "noise_ampl"),
        )
        for metric_key, attr_name in cached_metric_attrs:
            metric_value = metrics.get(metric_key)
            if metric_value is not None:
                setattr(self, attr_name, float(metric_value))

    # Pre-computed reconstruction, if provided.
    if reconstructed is not None:
        reconstructed = np.asarray(reconstructed)
        if len(reconstructed) != len(ca):
            raise ValueError(
                f"reconstructed length {len(reconstructed)} must match calcium length {len(ca)}"
            )
        self._reconstructed = TimeSeries(
            reconstructed, discrete=False, name=f"neuron_{self.cell_id}_reconstructed"
        )
        self._has_reconstructed = True

    # Defaults are supplied in seconds and stored in frames.
    self.default_t_off = default_t_off * fps
    self.default_t_rise = default_t_rise * fps

    if fit_individual_t_off and optimize_kinetics:
        raise ValueError("Cannot specify both fit_individual_t_off and optimize_kinetics")

    if fit_individual_t_off:
        import warnings

        warnings.warn(
            "fit_individual_t_off is deprecated and will be removed in v1.0. "
            "Use optimize_kinetics=True instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        if self.sp is not None or self.asp is not None:
            # Order matters: assigning t_off goes through the property and
            # clears cached metrics (including noise_ampl), which is then
            # set from the fit result.
            self.t_off, self.noise_ampl = self._fit_t_off()
            t_off = self.t_off
        else:
            t_off = self.default_t_off
    elif optimize_kinetics:
        method = "direct" if optimize_kinetics is True else optimize_kinetics
        self.get_kinetics(method, fps)
        t_off = self.t_off if self.t_off is not None else self.default_t_off
    else:
        t_off = self.default_t_off

    self.ca.shuffle_mask = np.ones(self.n_frames, dtype=bool)
    # Exclude shifts within CA_SHIFT_N_TOFF multiples of t_off at both ends.
    min_shift = int(t_off * CA_SHIFT_N_TOFF)
    self.ca.shuffle_mask[:min_shift] = False
    self.ca.shuffle_mask[self.n_frames - min_shift :] = False
def reconstruct_spikes(
    self,
    method="wavelet",
    iterative=True,
    n_iter=3,
    min_events_threshold=2,
    adaptive_thresholds=False,
    amplitude_method="deconvolution",
    show_progress=False,
    create_event_regions=False,
    event_mask_expansion_sec=None,
    wavelet=None,
    rel_wvt_times=None,
    use_gpu=False,
    **kwargs,
):
    """Reconstruct spikes from calcium signal.

    Detects calcium events with a wavelet- or threshold-based detector,
    extracts their amplitudes and stores the result as amplitude spikes
    (``self.asp``) and binary spikes (``self.sp``).

    Parameters
    ----------
    method : str, optional
        Reconstruction method: 'wavelet' or 'threshold'. Default is 'wavelet'.
    iterative : bool, optional
        If True, run up to ``n_iter`` detection passes, each on the residual
        left after subtracting a quick reconstruction of the events found
        so far. If False, run a single detection pass. Default is True.
    n_iter : int, optional
        Number of detection passes (only if iterative=True). Default is 3.
    min_events_threshold : int, optional
        Minimum number of events a pass must find for its events to be
        accumulated (only if iterative=True). Default is 2.
    adaptive_thresholds : bool, optional
        Progressively relax detection thresholds across iterations (only if
        iterative=True): pass ``i`` uses a factor ``1 - 0.2 * i`` on
        ``min_event_dur`` (wavelet, floored at 0.1 s) or on ``n_mad``
        (threshold, floored at 1.0). Default is False.
    amplitude_method : str, optional
        'deconvolution' (default) or 'peak'; forwarded to amplitude
        extraction.
    show_progress : bool, optional
        Forwarded to ``extract_wvt_events`` (wavelet method). Default False.
    create_event_regions : bool, optional
        If True, also create ``self.events`` (binary event regions) and
        return them. Default is False.
    event_mask_expansion_sec : float or None, optional
        Time in seconds by which the event mask is expanded around detected
        events for deconvolution. None (default) selects the per-method
        default: 5.0 for 'wavelet', 2.0 for 'threshold'. An explicit value
        is honoured for both methods.
    wavelet : Wavelet, optional
        Pre-computed wavelet object forwarded to ``extract_wvt_events``.
    rel_wvt_times : array-like, optional
        Pre-computed time resolutions forwarded to ``extract_wvt_events``.
    use_gpu : bool, default=False
        Forwarded to ``extract_wvt_events``.
    **kwargs
        Method-specific options.

        For 'wavelet':

        * fps : float — sampling rate in Hz. Default is ``self.fps``.
        * min_event_dur : float — minimum event duration in seconds
          (default 0.5).
        * max_event_dur : float — maximum event duration in seconds
          (default 2.5).
        * scale_length_thr (15), max_scale_thr (6), max_ampl_thr (0.04),
          sigma (8) — forwarded unchanged to ``extract_wvt_events``.

        For 'threshold':

        * fps : float — sampling rate in Hz. Default is ``self.fps``, or
          DEFAULT_FPS if that is None.
        * threshold : float or None — explicit detection threshold. If None
          in iterative mode, ``median + n_mad * MAD`` of the original
          signal is used; in single-pass mode it is forwarded as-is to
          ``detect_events_threshold``.
        * n_mad : float — number of MADs above the median (default 4.0).
        * min_duration_frames : int — default 3.
        * merge_gap_frames : int — default 2.
        * use_scaled : bool — use the scaled calcium signal (default True).

    Returns
    -------
    ndarray or None
        Binary event regions of shape (n_frames,) if
        create_event_regions=True, otherwise None.

    Raises
    ------
    NotImplementedError
        If method is not 'wavelet' or 'threshold'.
    ValueError
        If ``min_event_dur >= max_event_dur`` (wavelet method). Further
        validation is delegated to ``check_positive``.

    Notes
    -----
    Besides ``self.asp``/``self.sp``, the wavelet method stores the
    collected ridges on ``self.wvt_ridges`` and the threshold method stores
    the detected events on ``self.threshold_events``.

    In iterative threshold mode the threshold statistics are computed once
    from the original signal, not from the residuals, so that later passes
    do not keep finding tail exceedances in pure noise.
    """
    # Warn if re-running reconstruction without optimized kinetics
    # (skipped on the first reconstruction - events are needed before
    # kinetics can be optimized).
    if self._has_reconstructed and (self.t_rise is None or self.t_off is None):
        import warnings

        fps_for_warning = kwargs.get("fps", self.fps if self.fps is not None else DEFAULT_FPS)
        warnings.warn(
            f"Neuron {self.cell_id}: Re-running reconstruction with default kinetics. "
            f"Consider optimizing first: neuron.optimize_kinetics(method='direct', fps={fps_for_warning})",
            UserWarning,
        )
    self._has_reconstructed = True

    # Kinetics in frames: optimized values if available, otherwise defaults.
    t_rise = self.t_rise if self.t_rise is not None else self.default_t_rise
    t_off = self.t_off if self.t_off is not None else self.default_t_off

    if method == "wavelet":
        if event_mask_expansion_sec is None:
            event_mask_expansion_sec = 5.0

        fps = kwargs.get("fps", self.fps)
        min_event_dur = kwargs.get("min_event_dur", 0.5)
        max_event_dur = kwargs.get("max_event_dur", 2.5)
        check_positive(fps=fps, min_event_dur=min_event_dur, max_event_dur=max_event_dur)
        if min_event_dur >= max_event_dur:
            raise ValueError(
                f"min_event_dur ({min_event_dur}) must be less than max_event_dur ({max_event_dur})"
            )

        wvt_kwargs = {
            "fps": fps,
            "min_event_dur": min_event_dur,
            "max_event_dur": max_event_dur,
            "scale_length_thr": kwargs.get("scale_length_thr", 15),
            "max_scale_thr": kwargs.get("max_scale_thr", 6),
            "max_ampl_thr": kwargs.get("max_ampl_thr", 0.04),
            "sigma": kwargs.get("sigma", 8),
        }
        ca_data = self.ca.scdata.reshape(1, -1)

        if iterative:
            # Wavelet iterative: detect events in residuals across passes.
            check_positive(n_iter=n_iter, min_events_threshold=min_events_threshold)
            all_st_inds_list = []
            all_end_inds_list = []
            all_ridges_list = []
            current_signal = self.ca.scdata.copy()

            # Per-iteration detector settings (relaxed if adaptive).
            if adaptive_thresholds:
                iter_kwargs = []
                for i in range(n_iter):
                    relax_factor = 1 - i * 0.2
                    iter_kw = wvt_kwargs.copy()
                    iter_kw["min_event_dur"] = max(min_event_dur * relax_factor, 0.1)
                    iter_kw["max_event_dur"] = max_event_dur
                    iter_kwargs.append(iter_kw)
            else:
                iter_kwargs = [wvt_kwargs.copy() for _ in range(n_iter)]

            for iter_idx in range(n_iter):
                current_signal_2d = current_signal.reshape(1, -1)
                st_ev_inds, end_ev_inds, filtered_ridges = extract_wvt_events(
                    current_signal_2d,
                    iter_kwargs[iter_idx],
                    show_progress=show_progress,
                    wavelet=wavelet,
                    rel_wvt_times=rel_wvt_times,
                    use_gpu=use_gpu,
                )
                st_inds = st_ev_inds[0] if len(st_ev_inds) > 0 else []
                end_inds = end_ev_inds[0] if len(end_ev_inds) > 0 else []
                ridges = filtered_ridges[0] if len(filtered_ridges) > 0 else []

                if len(st_inds) >= min_events_threshold:
                    all_st_inds_list.extend(st_inds)
                    all_end_inds_list.extend(end_inds)
                    all_ridges_list.extend(ridges)
                    # Residual for the next pass.
                    # NOTE(review): the scraped source lost its indentation;
                    # this update is assumed to belong to the accumulate
                    # branch (only recorded events are subtracted) - confirm
                    # against the repository.
                    current_signal = self._compute_residual(
                        st_inds, end_inds, current_signal, t_rise, t_off, fps
                    )

            # After ALL iterations: finalize with ALL collected events.
            self.wvt_ridges = all_ridges_list
            return self._finalize_detection(
                all_st_inds_list,
                all_end_inds_list,
                t_rise,
                t_off,
                fps,
                amplitude_method,
                event_mask_expansion_sec,
                create_event_regions,
            )

        # Wavelet non-iterative: single-pass detection.
        st_ev_inds, end_ev_inds, filtered_ridges = extract_wvt_events(
            ca_data,
            wvt_kwargs,
            show_progress=show_progress,
            wavelet=wavelet,
            rel_wvt_times=rel_wvt_times,
            use_gpu=use_gpu,
        )
        self.wvt_ridges = filtered_ridges[0] if len(filtered_ridges) > 0 else []
        st_inds = st_ev_inds[0] if len(st_ev_inds) > 0 else []
        end_inds = end_ev_inds[0] if len(end_ev_inds) > 0 else []
        return self._finalize_detection(
            st_inds,
            end_inds,
            t_rise,
            t_off,
            fps,
            amplitude_method,
            event_mask_expansion_sec,
            create_event_regions,
        )

    if method == "threshold":
        # `event_mask_expansion_sec` is a declared parameter and therefore
        # never appears in **kwargs; looking it up there used to overwrite
        # an explicit caller value with 2.0. Only fill in the default.
        if event_mask_expansion_sec is None:
            event_mask_expansion_sec = 2.0

        fps = kwargs.get("fps", self.fps if self.fps is not None else DEFAULT_FPS)
        threshold = kwargs.get("threshold", None)
        n_mad = kwargs.get("n_mad", 4.0)
        min_duration_frames = kwargs.get("min_duration_frames", 3)
        merge_gap_frames = kwargs.get("merge_gap_frames", 2)
        use_scaled = kwargs.get("use_scaled", True)

        if iterative:
            # Threshold iterative: detect events in residuals across passes.
            check_positive(n_iter=n_iter, min_events_threshold=min_events_threshold)
            all_st_inds_list = []
            all_end_inds_list = []
            all_events_list = []
            original_signal = self.ca.scdata.copy() if use_scaled else self.ca.data.copy()
            current_signal = original_signal.copy()

            # Statistics are computed ONCE from the original signal (not the
            # residuals) so later passes do not chase noise.
            if threshold is None:
                signal_median = np.median(original_signal)
                signal_mad = median_abs_deviation(original_signal, scale="normal")
                if adaptive_thresholds:
                    # Progressively relax: n_mad, 0.8*n_mad, 0.6*n_mad, ...
                    iter_n_mads = [max(n_mad * (1 - i * 0.2), 1.0) for i in range(n_iter)]
                else:
                    iter_n_mads = [n_mad] * n_iter
                iter_thresholds = [signal_median + nm * signal_mad for nm in iter_n_mads]
            else:
                # Explicit threshold: used unchanged for every pass.
                iter_thresholds = [threshold] * n_iter

            for iter_idx in range(n_iter):
                detected_events = self.detect_events_threshold(
                    threshold=iter_thresholds[iter_idx],
                    min_duration_frames=min_duration_frames,
                    merge_gap_frames=merge_gap_frames,
                    use_scaled=False,  # a custom signal is passed instead
                    signal=current_signal,
                )
                st_inds = [int(e.start) for e in detected_events]
                end_inds = [int(e.end) for e in detected_events]

                if len(st_inds) >= min_events_threshold:
                    all_st_inds_list.extend(st_inds)
                    all_end_inds_list.extend(end_inds)
                    all_events_list.extend(detected_events)
                    # Residual for the next pass.
                    # NOTE(review): placement inside this branch is assumed
                    # (indentation was lost in the scraped source) - confirm.
                    current_signal = self._compute_residual(
                        st_inds, end_inds, current_signal, t_rise, t_off, fps
                    )

            # After ALL iterations: finalize with ALL collected events.
            self.threshold_events = all_events_list
            return self._finalize_detection(
                all_st_inds_list,
                all_end_inds_list,
                t_rise,
                t_off,
                fps,
                amplitude_method,
                event_mask_expansion_sec,
                create_event_regions,
            )

        # Threshold non-iterative: single-pass detection.
        detected_events = self.detect_events_threshold(
            threshold=threshold,
            n_mad=n_mad,
            min_duration_frames=min_duration_frames,
            merge_gap_frames=merge_gap_frames,
            use_scaled=use_scaled,
        )
        self.threshold_events = detected_events
        st_inds = [int(e.start) for e in detected_events]
        end_inds = [int(e.end) for e in detected_events]
        return self._finalize_detection(
            st_inds,
            end_inds,
            t_rise,
            t_off,
            fps,
            amplitude_method,
            event_mask_expansion_sec,
            create_event_regions,
        )

    raise NotImplementedError(
        f"Method '{method}' not implemented. Available methods: 'wavelet', 'threshold'"
    )
def _clear_cached_metrics(self):
    """Reset every cached reconstruction and quality metric to None.

    Meant to be called whenever the amplitude spikes or the kinetics
    parameters change, because the cached values depend on them.

    Attributes reset: ``_reconstructed``, ``_reconstructed_scaled``,
    ``reconstruction_r2``, ``snr_reconstruction``, ``mae``, ``event_count``,
    ``noise_ampl``, ``mad`` and ``snr``.

    Notes
    -----
    ``_kinetics_info`` is deliberately left untouched.
    """
    cached_attribute_names = (
        "_reconstructed",
        "_reconstructed_scaled",
        "reconstruction_r2",
        "snr_reconstruction",
        "mae",
        "event_count",
        "noise_ampl",
        "mad",
        "snr",
    )
    for attribute_name in cached_attribute_names:
        setattr(self, attribute_name, None)


def _compute_scaled_reconstruction(self):
    """Rebuild and cache the reconstructed calcium trace from ``self.asp``.

    Uses the current kinetics (optimized values if set, otherwise the
    defaults). Stores the reconstruction as a TimeSeries on
    ``self._reconstructed`` and its ``ca_scaler``-transformed version on
    ``self._reconstructed_scaled``. Does nothing when ``self.asp`` is None.
    """
    if self.asp is None:
        return None

    rise_frames = self.default_t_rise if self.t_rise is None else self.t_rise
    decay_frames = self.default_t_off if self.t_off is None else self.t_off
    restored = Neuron.get_restored_calcium(self.asp.data, rise_frames, decay_frames)

    self._reconstructed = TimeSeries(
        restored, discrete=False, name=f"neuron_{self.cell_id}_reconstructed"
    )
    # The scaler expects a column vector; flatten the result back to 1D.
    scaled_column = self.ca_scaler.transform(restored.reshape(-1, 1))
    self._reconstructed_scaled = scaled_column.reshape(-1)


def _extract_amplitudes(
    self, st_inds, end_inds, t_rise, t_off, fps, amplitude_method, event_mask_expansion_sec
):
    """Extract one amplitude per detected event.

    Parameters
    ----------
    st_inds : list of int
        Event start frame indices.
    end_inds : list of int
        Event end frame indices.
    t_rise : float
        Rise time in frames.
    t_off : float
        Decay time in frames.
    fps : float
        Sampling rate in Hz.
    amplitude_method : str
        'deconvolution' or 'peak'.
    event_mask_expansion_sec : float
        Time in seconds by which the deconvolution mask is widened on both
        sides of each event.

    Returns
    -------
    list
        Extracted amplitudes; an empty list when there are no events.

    Raises
    ------
    ValueError
        If ``amplitude_method`` is unknown (only checked when events exist).
    """
    if len(st_inds) == 0:
        return []

    if amplitude_method == "peak":
        return Neuron.extract_event_amplitudes(
            self.ca.data,
            st_inds,
            end_inds,
            baseline_window=int(BASELINE_WINDOW_SEC * fps),
            already_dff=True,
        )

    if amplitude_method != "deconvolution":
        raise ValueError(f"Unknown amplitude_method: {amplitude_method}")

    # Detection boundaries are converted to estimated onset times first.
    onsets = Neuron.estimate_onset_times(self.ca.data, st_inds, end_inds, t_rise, t_off)

    # Mask covers each event from its onset through the later of the
    # detected end and five decay constants, widened by the expansion on
    # both sides and clipped to the trace.
    total_frames = self.n_frames
    fit_mask = np.zeros(total_frames, dtype=bool)
    pad_frames = int(event_mask_expansion_sec * fps)
    tail_frames = int(5 * t_off)
    for onset, end in zip(onsets, end_inds):
        window_start = max(0, onset - pad_frames)
        window_stop = min(total_frames, max(int(end), onset + tail_frames) + pad_frames)
        fit_mask[window_start:window_stop] = True

    return Neuron.deconvolve_given_event_times(
        self.ca.data, onsets, t_rise, t_off, event_mask=fit_mask
    )


def _create_asp_sp(self, st_inds, end_inds, amplitudes, t_rise, t_off, fps):
    """Build and store the amplitude-spike and binary-spike TimeSeries.

    Sets ``self.asp``, ``self.sp`` and ``self.sp_count``.

    Parameters
    ----------
    st_inds : list of int
        Event start frame indices.
    end_inds : list of int
        Event end frame indices.
    amplitudes : list
        Extracted amplitudes for each event.
    t_rise : float
        Rise time in frames.
    t_off : float
        Decay time in frames.
    fps : float
        Sampling rate in Hz.

    Returns
    -------
    ndarray
        ASP array (amplitude spikes); all zeros when there are no events or
        no amplitudes.
    """
    asp_name = f"neuron_{self.cell_id}_asp"
    sp_name = f"neuron_{self.cell_id}_sp"

    nothing_detected = len(st_inds) == 0 or len(amplitudes) == 0
    if nothing_detected:
        self.asp = TimeSeries(np.zeros(self.n_frames), discrete=False, name=asp_name)
        self.sp = TimeSeries(np.zeros(self.n_frames, dtype=int), discrete=True, name=sp_name)
        self.sp_count = 0
        return np.zeros(self.n_frames)

    amplitude_train = Neuron.amplitudes_to_point_events(
        self.n_frames,
        self.ca.data,
        st_inds,
        end_inds,
        amplitudes,
        placement="onset",
        t_rise_frames=t_rise,
        t_off_frames=t_off,
        fps=fps,
    )
    # Binary spikes mark every frame with a positive amplitude.
    binary_train = (amplitude_train > 0).astype(int)

    self.asp = TimeSeries(amplitude_train, discrete=False, name=asp_name)
    self.sp = TimeSeries(binary_train, discrete=True, name=sp_name)
    self.sp_count = int(np.sum(binary_train))
    return amplitude_train


def _create_event_regions(self, st_inds, end_inds, fps, create_event_regions):
    """Optionally build the binary event-region TimeSeries (``self.events``).

    Parameters
    ----------
    st_inds : list of int
        Event start frame indices.
    end_inds : list of int
        Event end frame indices.
    fps : float
        Sampling rate in Hz.
    create_event_regions : bool
        If False, ``self.events`` is reset to None and nothing is built.

    Returns
    -------
    ndarray or None
        Event regions array if create_event_regions=True, else None.
    """
    if not create_event_regions:
        self.events = None
        return None

    events_name = f"neuron_{self.cell_id}_events"

    if len(st_inds) == 0:
        empty_regions = np.zeros(self.n_frames, dtype=int)
        self.events = TimeSeries(empty_regions, discrete=True, name=events_name)
        return empty_regions

    # events_to_ts_array is called with one list of events per neuron, so
    # the single neuron's indices are wrapped in an outer list.
    regions = events_to_ts_array(self.n_frames, [list(st_inds)], [list(end_inds)], fps)
    self.events = TimeSeries(regions.flatten(), discrete=True, name=events_name)
    return regions.flatten()


def _finalize_reconstruction(self, fps):
    """Drop stale cached metrics, then rebuild the scaled reconstruction.

    Parameters
    ----------
    fps : float
        Sampling rate in Hz (unused but kept for API consistency).
    """
    self._clear_cached_metrics()
    if self.asp is None:
        return
    self._compute_scaled_reconstruction()


def _finalize_detection(
    self,
    st_inds,
    end_inds,
    t_rise,
    t_off,
    fps,
    amplitude_method,
    event_mask_expansion_sec,
    create_event_regions,
):
    """Shared tail of every detection branch.

    Runs, in order: amplitude extraction, ASP/SP creation, optional event
    region creation, and cache clearing plus reconstruction refresh.

    Parameters
    ----------
    st_inds : list of int
        Event start frame indices.
    end_inds : list of int
        Event end frame indices.
    t_rise : float
        Rise time in frames.
    t_off : float
        Decay time in frames.
    fps : float
        Sampling rate in Hz.
    amplitude_method : str
        'deconvolution' or 'peak'.
    event_mask_expansion_sec : float
        Time in seconds to expand event mask.
    create_event_regions : bool
        Whether to create binary event regions.

    Returns
    -------
    ndarray or None
        Event regions if create_event_regions=True, else None.
    """
    event_amplitudes = self._extract_amplitudes(
        st_inds, end_inds, t_rise, t_off, fps, amplitude_method, event_mask_expansion_sec
    )
    self._create_asp_sp(st_inds, end_inds, event_amplitudes, t_rise, t_off, fps)
    event_regions = self._create_event_regions(st_inds, end_inds, fps, create_event_regions)
    self._finalize_reconstruction(fps)
    return event_regions


def _compute_residual(self, st_inds, end_inds, signal, t_rise, t_off, fps):
    """Subtract a quick reconstruction of the given events from ``signal``.

    Used by iterative detection to search for further events in what is
    left of the signal.

    Parameters
    ----------
    st_inds : list of int
        Event start frame indices.
    end_inds : list of int
        Event end frame indices.
    signal : ndarray
        Current signal to subtract from.
    t_rise : float
        Rise time in frames.
    t_off : float
        Decay time in frames.
    fps : float
        Sampling rate in Hz.

    Returns
    -------
    ndarray
        Residual signal; a copy of ``signal`` when there are no events.
    """
    if len(st_inds) == 0:
        return signal.copy()

    # Amplitudes are taken from the CURRENT signal (which may already be a
    # residual), not from the original trace.
    peak_amplitudes = Neuron.extract_event_amplitudes(
        signal,
        st_inds,
        end_inds,
        baseline_window=int(BASELINE_WINDOW_SEC * fps),
        already_dff=True,
    )

    # One impulse per event at its start frame; out-of-range starts are
    # ignored.
    impulse_train = np.zeros(self.n_frames)
    for start_frame, amplitude in zip(st_inds, peak_amplitudes):
        if 0 <= start_frame < self.n_frames:
            impulse_train[start_frame] = amplitude

    return signal - Neuron.get_restored_calcium(impulse_train, t_rise, t_off)
def get_mad(self):
    """Return the median absolute deviation (MAD) of the calcium signal.

    The value is cached on ``self.mad``. When spike data is present
    (non-zero ``asp`` or ``sp``), the MAD is obtained as a by-product of
    ``_calc_snr_simple()``; if that raises ValueError, or when there are no
    spikes, the MAD is computed directly from ``self.ca.data``.

    Returns
    -------
    float
        MAD of the calcium signal as returned by
        ``scipy.stats.median_abs_deviation`` with its default scale.

    Notes
    -----
    MAD is more robust to outliers than standard deviation, which is why it
    is used as the noise estimate for signals containing transients.
    """
    if self.mad is not None:
        return self.mad

    # Same short-circuit order as before: amplitude spikes are checked
    # first, binary spikes only if those are absent or all zero.
    spikes_present = self.asp is not None and np.sum(np.abs(self.asp.data)) > 0
    if not spikes_present:
        spikes_present = self.sp is not None and np.sum(self.sp.data) > 0

    if spikes_present:
        try:
            # Fills in both self.snr and self.mad.
            self._calc_snr_simple()
        except ValueError:
            # Spikes exist but the SNR could not be computed.
            self.mad = median_abs_deviation(self.ca.data)
    else:
        self.mad = median_abs_deviation(self.ca.data)

    return self.mad
def get_snr(self, method="simple"):
    """Return the signal-to-noise ratio of the calcium signal.

    Single entry point for the available SNR estimators.

    Parameters
    ----------
    method : {'simple', 'wavelet'}, optional
        - 'simple' (default): peak-based SNR from spike locations, computed
          by ``_calc_snr_simple()`` and cached on ``self.snr``.
        - 'wavelet': delegates to ``get_wavelet_snr()``.

    Returns
    -------
    float
        Signal-to-noise ratio (dimensionless).

    Raises
    ------
    ValueError
        If ``method`` is neither 'simple' nor 'wavelet'. Errors raised by
        the selected estimator (e.g. no spike data) propagate unchanged.

    Examples
    --------
    >>> neuron.get_snr()                  # Simple SNR (default)
    >>> neuron.get_snr(method='simple')   # Explicit simple
    >>> neuron.get_snr(method='wavelet')  # Wavelet-based SNR
    """
    if method == "wavelet":
        return self.get_wavelet_snr()

    if method != "simple":
        raise ValueError(f"Invalid method '{method}'. Must be 'simple' or 'wavelet'")

    # Compute lazily; later calls reuse the cached value.
    if self.snr is None:
        self._calc_snr_simple()
    return self.snr
def _calc_snr_simple(self): """Calculate simple peak-based SNR and MAD. Internal method that computes both SNR and MAD in a single pass for efficiency. Preferentially uses amplitude spikes (asp) over binary spikes (sp) for more accurate signal estimation. Returns ------- tuple (snr, mad) where snr is signal-to-noise ratio and mad is median absolute deviation. Raises ------ ValueError If no spikes are present, if MAD is zero, or if SNR calculation results in NaN.""" # Prefer asp (amplitude spikes) over sp (binary spikes) if self.asp is not None and np.sum(np.abs(self.asp.data)) > 0: spk_inds = np.nonzero(self.asp.data)[0] elif self.sp is not None and np.sum(self.sp.data) > 0: spk_inds = np.nonzero(self.sp.data)[0] else: raise ValueError("No spike data available") if len(spk_inds) == 0: raise ValueError("No spikes found!") mad = median_abs_deviation(self.ca.data) if mad == 0: raise ValueError("MAD is zero, cannot compute SNR") sn = np.mean(self.ca.data[spk_inds]) / mad if np.isnan(sn): raise ValueError("Error in SNR calculation") # Cache both values self.snr = sn self.mad = mad return (sn, mad)
def get_reconstruction_r2(self, event_only=False, n_mad=3.0, use_detected_events=True):
    """Return R² of the calcium reconstruction.

    Two modes are available. The standard mode scores the whole trace and
    caches its result in ``self.reconstruction_r2``. The event-only mode
    scores event regions via ``Neuron._calculate_event_r2`` and is never
    cached because it depends on the call arguments.

    Parameters
    ----------
    event_only : bool, default=False
        Select event-only scoring instead of the whole-trace score.
    n_mad : float, default=3.0
        Forwarded to ``Neuron._calculate_event_r2`` in event-only mode.
    use_detected_events : bool, default=True
        Event-only mode only. When True, the first available source among
        ``self.events`` (passed as a mask), ``self.threshold_events`` and
        ``self.wvt_ridges`` (passed as ridges) is handed to the scorer.
        When False, or when none is available, neither is passed.

    Returns
    -------
    float
        R² value; 1 is a perfect fit and values may be negative.

    Raises
    ------
    ValueError
        If ``self.asp`` is None, if the event-only score is NaN, or if
        the total variance of the calcium trace is zero (standard mode).

    Notes
    -----
    Standard mode: ``R² = 1 - SS_residual / SS_total`` where
    ``SS_residual = get_noise_ampl()**2 * n_frames`` and ``SS_total`` is
    the summed squared deviation of ``self.ca.data`` from its mean.
    """
    if self.asp is None:
        raise ValueError(
            "Amplitude spikes required for reconstruction R². Call reconstruct_spikes() first."
        )

    if not event_only:
        if self.reconstruction_r2 is None:
            rmse = self.get_noise_ampl()
            # RMSE is sqrt(mean(residual²)), so residual SS = rmse² · N.
            ss_residual = rmse**2 * self.n_frames
            trace = self.ca.data
            ss_total = np.sum((trace - np.mean(trace)) ** 2)
            if ss_total == 0:
                raise ValueError("Total variance is zero, cannot compute R²")
            self.reconstruction_r2 = 1 - ss_residual / ss_total
        return self.reconstruction_r2

    # Event-only mode: rebuild the model trace with current kinetics.
    rise = self.t_rise if self.t_rise is not None else self.default_t_rise
    decay = self.t_off if self.t_off is not None else self.default_t_off
    modelled = Neuron.get_restored_calcium(self.asp.data, rise, decay)

    # Pick the first available description of event regions.
    mask = None
    ridges = None
    if use_detected_events:
        if self.events is not None:
            mask = self.events.data
        elif getattr(self, "threshold_events", None):
            ridges = self.threshold_events
        elif self.wvt_ridges:
            ridges = self.wvt_ridges

    score = Neuron._calculate_event_r2(
        self.ca.data,
        modelled,
        n_mad,
        event_mask=mask,
        wvt_ridges=ridges,
        fps=self.fps,
    )
    if np.isnan(score):
        raise ValueError(
            "Event R² calculation failed. No events detected or insufficient event data. Consider lowering n_mad parameter."
        )
    return score
def get_snr_reconstruction(self):
    """Return the reconstruction-based SNR, ``std(Ca) / RMSE``.

    The standard deviation of ``self.ca.data`` is divided by the
    reconstruction error reported by ``get_noise_ampl()``. The result is
    cached in ``self.snr_reconstruction``.

    Returns
    -------
    float
        Reconstruction SNR; larger means the model error is small
        relative to the signal spread.

    Raises
    ------
    ValueError
        If the RMSE is exactly zero. ``get_noise_ampl()`` may raise its
        own ``ValueError`` when no reconstruction is available.
    """
    if self.snr_reconstruction is not None:
        return self.snr_reconstruction

    error = self.get_noise_ampl()
    spread = np.std(self.ca.data)
    if error == 0:
        raise ValueError("RMSE is zero, cannot compute reconstruction SNR")

    self.snr_reconstruction = spread / error
    return self.snr_reconstruction
def get_reconstruction_scaled(self, t_rise=None, t_off=None):
    """Return the calcium reconstruction mapped through ``self.ca_scaler``.

    The reconstruction is built from ``self.asp.data`` with
    ``Neuron.get_restored_calcium`` and then passed through
    ``self.ca_scaler.transform`` (not ``fit_transform``), so it lands in
    the same scaled space as the scaler's training data.

    Parameters
    ----------
    t_rise : float, optional
        Rise time handed to ``get_restored_calcium``. When None, falls
        back to ``self.t_rise`` and then ``self.default_t_rise``.
    t_off : float, optional
        Decay time handed to ``get_restored_calcium``. When None, falls
        back to ``self.t_off`` and then ``self.default_t_off``.

    Returns
    -------
    ndarray
        1-D scaled reconstruction. Because the scaler is only applied
        (not refitted), values are not clipped and may leave the range
        the scaler produced for its training data.

    Raises
    ------
    ValueError
        If ``self.asp`` is None.

    Notes
    -----
    The result is cached in ``self._reconstructed_scaled`` whenever the
    kinetics used equal the neuron's current ones. The cache is returned
    directly only when both ``t_rise`` and ``t_off`` are left as None.

    Reconstruction-quality metrics (R², MAE, RMSE) in this class work on
    ``self.ca.data``, not on this scaled output.

    Examples
    --------
    >>> neuron.reconstruct_spikes()
    >>> recon_scaled = neuron.get_reconstruction_scaled()
    >>> pct_below = 100 * np.sum(recon_scaled < 0) / len(recon_scaled)
    >>> pct_above = 100 * np.sum(recon_scaled > 1.0) / len(recon_scaled)

    See Also
    --------
    get_reconstruction_r2 : R² using ca.data
    get_mae : MAE using ca.data
    """
    if self.asp is None:
        raise ValueError("No spike data available. Call reconstruct_spikes() first.")

    defaults_requested = t_rise is None and t_off is None
    if defaults_requested and self._reconstructed_scaled is not None:
        return self._reconstructed_scaled

    # Kinetics the neuron would use on its own (optimised, else default).
    own_rise = self.t_rise if self.t_rise is not None else self.default_t_rise
    own_off = self.t_off if self.t_off is not None else self.default_t_off

    rise = own_rise if t_rise is None else t_rise
    off = own_off if t_off is None else t_off

    recon = Neuron.get_restored_calcium(self.asp.data, rise, off)
    # The scaler expects a column vector; flatten back afterwards.
    scaled = self.ca_scaler.transform(recon.reshape(-1, 1)).reshape(-1)

    # Only results for the neuron's own kinetics are worth remembering.
    if rise == own_rise and off == own_off:
        self._reconstructed_scaled = scaled
    return scaled
def get_mae(self):
    """Return the mean absolute error of the calcium reconstruction.

    ``MAE = mean(|Ca_observed - Ca_reconstructed|)`` where the
    reconstruction is produced by ``Neuron.get_restored_calcium`` from
    ``self.asp.data`` using the optimised kinetics when set and the
    default kinetics otherwise. The value is cached in ``self.mae``.

    Returns
    -------
    float
        Non-negative error in the units of ``self.ca.data``; lower is
        better.

    Raises
    ------
    ValueError
        If no cached value exists and ``self.asp`` is None.

    Notes
    -----
    MAE weights all errors equally, whereas RMSE emphasises large ones;
    an RMSE much larger than the MAE therefore points at a few large
    deviations.
    """
    if self.mae is not None:
        return self.mae

    if self.asp is None:
        raise ValueError(
            "Amplitude spikes required for MAE calculation. Call reconstruct_spikes() first."
        )

    rise = self.t_rise if self.t_rise is not None else self.default_t_rise
    decay = self.t_off if self.t_off is not None else self.default_t_off
    modelled = Neuron.get_restored_calcium(self.asp.data, rise, decay)

    self.mae = np.mean(np.abs(self.ca.data - modelled))
    return self.mae
def get_event_rmse(self, n_mad=4.0, use_detected_events=True):
    """Return the reconstruction RMSE restricted to event frames.

    Parameters
    ----------
    n_mad : float, default=4.0
        Threshold multiplier for the fallback event mask. Only used when
        detected events are not used.
    use_detected_events : bool, default=True
        When True and ``self.events`` is set, event frames are those
        where ``self.events.data > 0``. Otherwise the fallback mask is
        used.

    Returns
    -------
    float
        ``sqrt(mean((Ca - Recon)**2))`` over the event frames, in the
        units of ``self.ca.data``. Lower is better.

    Raises
    ------
    ValueError
        If ``self.asp`` is None or the event mask selects no frame.

    Notes
    -----
    The fallback mask marks frames where ``self.ca.scdata`` exceeds
    ``median + n_mad * MAD`` (MAD multiplied by ``MAD_SCALE_FACTOR``).
    The threshold is evaluated on the scaled trace, while the residuals
    are taken on ``self.ca.data``.
    """
    if self.asp is None:
        raise ValueError(
            "Amplitude spikes required for Event RMSE calculation. Call reconstruct_spikes() first."
        )

    rise = self.t_rise if self.t_rise is not None else self.default_t_rise
    decay = self.t_off if self.t_off is not None else self.default_t_off
    modelled = Neuron.get_restored_calcium(self.asp.data, rise, decay)

    if use_detected_events and self.events is not None:
        in_event = self.events.data > 0
    else:
        # Legacy path: robust threshold on the scaled calcium trace.
        scaled = self.ca.scdata
        centre = np.median(scaled)
        spread = np.median(np.abs(scaled - centre)) * MAD_SCALE_FACTOR
        in_event = scaled > centre + n_mad * spread

    if np.sum(in_event) == 0:
        raise ValueError(
            "No events detected. Consider lowering n_mad parameter or check wavelet detection."
        )

    diff = self.ca.data[in_event] - modelled[in_event]
    return float(np.sqrt(np.mean(diff**2)))
def get_event_mae(self, n_mad=4.0, use_detected_events=True):
    """Return the reconstruction MAE restricted to event frames.

    Parameters
    ----------
    n_mad : float, default=4.0
        Threshold multiplier for the fallback event mask. Only used when
        detected events are not used.
    use_detected_events : bool, default=True
        When True and ``self.events`` is set, event frames are those
        where ``self.events.data > 0``. Otherwise the fallback mask is
        used.

    Returns
    -------
    float
        ``mean(|Ca - Recon|)`` over the event frames, in the units of
        ``self.ca.data``. Lower is better.

    Raises
    ------
    ValueError
        If ``self.asp`` is None or the event mask selects no frame.

    Notes
    -----
    The fallback mask marks frames where ``self.ca.scdata`` exceeds
    ``median + n_mad * MAD`` (MAD multiplied by ``MAD_SCALE_FACTOR``).
    Compared with the event RMSE, this metric does not square the errors
    and is therefore less dominated by a few large deviations.
    """
    if self.asp is None:
        raise ValueError(
            "Amplitude spikes required for Event MAE calculation. Call reconstruct_spikes() first."
        )

    rise = self.t_rise if self.t_rise is not None else self.default_t_rise
    decay = self.t_off if self.t_off is not None else self.default_t_off
    modelled = Neuron.get_restored_calcium(self.asp.data, rise, decay)

    if use_detected_events and self.events is not None:
        in_event = self.events.data > 0
    else:
        # Legacy path: robust threshold on the scaled calcium trace.
        scaled = self.ca.scdata
        centre = np.median(scaled)
        spread = np.median(np.abs(scaled - centre)) * MAD_SCALE_FACTOR
        in_event = scaled > centre + n_mad * spread

    if np.sum(in_event) == 0:
        raise ValueError(
            "No events detected. Consider lowering n_mad parameter or check wavelet detection."
        )

    abs_diff = np.abs(self.ca.data[in_event] - modelled[in_event])
    return float(np.mean(abs_diff))
def get_event_count(self, n_mad=None):
    """Return the number of detected spike events.

    Counts the non-zero entries of ``self.asp.data`` and caches the
    result in ``self.event_count``.

    Parameters
    ----------
    n_mad : float, optional
        Not used by this method; accepted only so the signature lines up
        with the other event-based metrics. Default is None.

    Returns
    -------
    int
        Number of non-zero entries in the amplitude spike array.

    Raises
    ------
    ValueError
        If no cached count exists and ``self.asp`` is None.

    Notes
    -----
    ``self`` is a required positional parameter like in every other
    instance method; it previously carried a ``None`` default, which let
    an unbound call fail with an opaque ``AttributeError`` on ``None``.
    """
    if self.event_count is None:
        if self.asp is None:
            raise ValueError(
                "Spike reconstruction required for event counting. Call reconstruct_spikes() first."
            )
        # int(...) so the cached value is a plain Python integer.
        self.event_count = int(np.count_nonzero(self.asp.data))
    return self.event_count
def get_baseline_noise_std(self, n_mad=3.0):
    """Return the standard deviation of the reconstruction residual.

    The residual is ``self.ca.data - self.reconstructed.data``; its
    standard deviation serves as the noise scale for the normalised
    metrics (NMAE, NRMSE).

    Parameters
    ----------
    n_mad : float, default=3.0
        Not used by this implementation; kept for API compatibility.

    Returns
    -------
    float
        ``std(Ca_original - Ca_reconstructed)``.

    Raises
    ------
    ValueError
        If ``self.asp`` is None or ``self.reconstructed`` is None.
    """
    if self.asp is None:
        raise ValueError(
            "Spike reconstruction required for noise estimation. Call reconstruct_spikes() first."
        )

    modelled = self.reconstructed
    if modelled is None:
        raise ValueError("Reconstruction failed or returned None")

    return float(np.std(self.ca.data - modelled.data))
def get_event_snr(self, n_mad=4.0):
    """Return the threshold-based event SNR in decibels.

    Frames whose calcium value exceeds ``median + n_mad * MAD`` are
    treated as events, all remaining frames as baseline, and the result
    is ``20 * log10(mean(events) / std(baseline))``.

    Parameters
    ----------
    n_mad : float, default=4.0
        Number of (scaled) MAD units above the median that a frame must
        exceed to count as an event.

    Returns
    -------
    float
        Event SNR in dB; higher means events stand out more clearly.

    Raises
    ------
    ValueError
        If the calcium trace is missing or empty, no frame exceeds the
        threshold, fewer than 10 baseline frames remain, the baseline
        standard deviation is zero, or the mean event amplitude is not
        positive (a dB value would be undefined).

    Notes
    -----
    The MAD is multiplied by ``MAD_SCALE_FACTOR`` before thresholding.

    NOTE(review): later in the class body the assignment
    ``get_event_snr = get_wavelet_snr`` rebinds this name, so this
    implementation appears to be shadowed at class-creation time.
    """
    if self.ca is None or len(self.ca.data) == 0:
        raise ValueError("Calcium signal data required for event SNR")

    trace = self.ca.data
    median = np.median(trace)
    mad = np.median(np.abs(trace - median)) * MAD_SCALE_FACTOR
    event_mask = trace > median + n_mad * mad

    event_data = trace[event_mask]
    baseline_data = trace[~event_mask]

    if len(event_data) == 0:
        raise ValueError("No events detected. Consider lowering n_mad parameter.")
    if len(baseline_data) < 10:
        raise ValueError("Insufficient baseline data for noise estimation.")

    event_mean = np.mean(event_data)
    baseline_std = np.std(baseline_data)
    if baseline_std == 0:
        raise ValueError("Baseline noise std is zero, cannot compute SNR")

    # log10 of a non-positive ratio would silently yield NaN or -inf;
    # "not > 0" also rejects a NaN mean.
    if not event_mean > 0:
        raise ValueError("Mean event amplitude is not positive, cannot compute SNR in dB")

    return float(20 * np.log10(event_mean / baseline_std))
def get_wavelet_snr(self):
    """Return the event-region based signal-to-noise ratio.

    Thin caching wrapper: the value is computed once by
    ``_calc_wavelet_snr()`` and stored in ``self.wavelet_snr``.

    Returns
    -------
    float
        Event SNR; higher values indicate better signal quality.

    Raises
    ------
    ValueError
        Propagated from ``_calc_wavelet_snr()`` (e.g. no event regions
        available, too few events, or zero baseline noise).

    Notes
    -----
    Event regions must exist before calling this method, for example
    after ``neuron.reconstruct_spikes(..., create_event_regions=True)``.

    See Also
    --------
    get_snr : Unified SNR interface
    get_event_snr : Alias for this method

    Examples
    --------
    >>> neuron = Neuron(cell_id=0, ca=calcium_data, sp=None)
    >>> neuron.reconstruct_spikes(method='threshold', create_event_regions=True)
    >>> snr = neuron.get_wavelet_snr()
    >>> print(f"Signal quality (SNR): {snr:.2f}")
    """
    if self.wavelet_snr is not None:
        return self.wavelet_snr

    self.wavelet_snr = self._calc_wavelet_snr()
    return self.wavelet_snr
# Alias for method-agnostic naming
get_event_snr = get_wavelet_snr

def _calc_wavelet_snr(self):
    """Compute the event SNR from detected event regions.

    The event mask is taken from ``self.events`` when available,
    otherwise it is assembled from the ``start``/``end`` bounds of
    ``self.threshold_events`` or, failing that, ``self.wvt_ridges``.

    Returns
    -------
    float
        ``(median(event peaks) - median(baseline)) / MAD(baseline)``,
        with the baseline MAD computed using ``scale="normal"``.

    Raises
    ------
    ValueError
        If no event source exists, the mask is empty, fewer than 10
        baseline frames remain, the baseline noise is zero, fewer than 3
        events are found, or no peak amplitude can be extracted.

    Notes
    -----
    Each contiguous run of True values in the mask counts as one event
    and contributes the maximum of the calcium trace within that run.
    """
    if self.events is not None and self.events.data is not None:
        events_mask = self.events.data.astype(bool)
    else:
        if hasattr(self, "threshold_events") and self.threshold_events:
            regions = self.threshold_events
        elif self.wvt_ridges:
            regions = self.wvt_ridges
        else:
            raise ValueError(
                "No event regions detected. "
                "Call reconstruct_spikes(create_event_regions=True) first, "
                "or use detect_events_threshold() to detect events."
            )
        # Paint every region onto a boolean frame mask, clipping the end.
        events_mask = np.zeros(len(self.ca.data), dtype=bool)
        for region in regions:
            first, last = int(region.start), min(int(region.end), len(self.ca.data))
            events_mask[first:last] = True

    ca = self.ca.data

    if not np.any(events_mask):
        raise ValueError("No events in event mask")

    # Baseline statistics come from every frame outside an event.
    baseline_values = ca[~events_mask]
    if len(baseline_values) < 10:
        raise ValueError(
            f"Insufficient baseline frames ({len(baseline_values)}). "
            "Need at least 10 baseline frames."
        )

    baseline_median = np.median(baseline_values)
    baseline_noise = median_abs_deviation(baseline_values, scale="normal")
    if baseline_noise == 0:
        raise ValueError("Baseline noise is zero, cannot compute SNR")

    # +1 / -1 steps in the mask mark where events begin / end.
    steps = np.diff(events_mask.astype(int))
    event_starts = np.where(steps == 1)[0] + 1
    event_ends = np.where(steps == -1)[0] + 1

    # Events touching either border have no step there; add them by hand.
    if events_mask[0]:
        event_starts = np.concatenate([[0], event_starts])
    if events_mask[-1]:
        event_ends = np.concatenate([event_ends, [len(events_mask)]])

    n_events = min(len(event_starts), len(event_ends))
    if n_events < 3:
        raise ValueError(
            f"Too few events detected ({n_events}). "
            "Need at least 3 events for reliable SNR calculation."
        )

    # One peak amplitude per non-empty event segment.
    segments = (
        ca[begin:stop]
        for begin, stop in zip(event_starts[:n_events], event_ends[:n_events])
    )
    event_amplitudes = np.array([np.max(seg) for seg in segments if len(seg) > 0])
    if len(event_amplitudes) == 0:
        raise ValueError("No valid event amplitudes extracted")

    signal_strength = np.median(event_amplitudes) - baseline_median
    return float(signal_strength / baseline_noise)
def get_nmae(self, n_mad=3.0):
    """Return the normalised MAE, ``get_mae() / get_baseline_noise_std()``.

    Dividing by the noise estimate expresses the reconstruction error in
    "noise units", which makes neurons with different noise levels
    comparable.

    Parameters
    ----------
    n_mad : float, default=3.0
        Forwarded to ``get_baseline_noise_std``.

    Returns
    -------
    float
        Dimensionless ratio; lower is better.

    Raises
    ------
    ValueError
        If the noise estimate is zero. The two underlying getters may
        raise ``ValueError`` themselves.
    """
    error = self.get_mae()
    noise = self.get_baseline_noise_std(n_mad=n_mad)
    if noise == 0:
        raise ValueError("Baseline noise std is zero, cannot normalize MAE")
    return float(error / noise)
def get_nrmse(self, n_mad=3.0):
    """Return the normalised RMSE, ``get_noise_ampl() / get_baseline_noise_std()``.

    Like the normalised MAE, but built on the RMSE and therefore more
    sensitive to a few large reconstruction errors.

    Parameters
    ----------
    n_mad : float, default=3.0
        Forwarded to ``get_baseline_noise_std``.

    Returns
    -------
    float
        Dimensionless ratio; lower is better.

    Raises
    ------
    ValueError
        If the noise estimate is zero. The two underlying getters may
        raise ``ValueError`` themselves.
    """
    error = self.get_noise_ampl()
    noise = self.get_baseline_noise_std(n_mad)
    if noise == 0:
        raise ValueError("Baseline noise std is zero, cannot normalize RMSE")
    return float(error / noise)
def get_kinetics(
    self, method="direct", fps=20, use_cached=True, update_reconstruction=True, **kwargs
):
    """Return optimised calcium kinetics, running the optimiser at most once.

    Caching wrapper around ``optimize_kinetics()``. The optimiser's
    result is stored in ``self._kinetics_info`` and handed back on later
    calls.

    Parameters
    ----------
    method : str, optional
        Forwarded to ``optimize_kinetics``. Default: 'direct'.
    fps : float, optional
        Sampling rate forwarded to ``optimize_kinetics``. Default: 20.
    use_cached : bool, optional
        If True and a previous result exists, return it without
        re-running the optimiser, regardless of the other arguments.
        If False, always re-run. Default: True.
    update_reconstruction : bool, optional
        Forwarded to ``optimize_kinetics``. Default: True.
    **kwargs : dict, optional
        Extra keyword arguments forwarded to ``optimize_kinetics``.

    Returns
    -------
    dict
        Whatever ``optimize_kinetics`` returned (its contents are
        defined by that method).

    Examples
    --------
    >>> neuron = Neuron(cell_id=0, ca=calcium_trace, sp=None, fps=20)
    >>> neuron.reconstruct_spikes(method='wavelet')
    >>> result = neuron.get_kinetics(fps=20)
    >>> cached = neuron.get_kinetics()   # no second optimisation

    Notes
    -----
    The cache is not keyed on the arguments: with ``use_cached=True`` a
    call with a different ``fps`` or ``method`` still returns the first
    result.
    """
    previous = getattr(self, "_kinetics_info", None)
    if use_cached and previous is not None:
        return previous

    fresh = self.optimize_kinetics(
        method=method,
        fps=fps,
        update_reconstruction=update_reconstruction,
        **kwargs,
    )
    self._kinetics_info = fresh
    return fresh
def get_t_off(self):
    """Return the calcium decay time constant (deprecated).

    .. deprecated:: 0.5.0
        Use :meth:`get_kinetics` instead.

    Always emits a ``DeprecationWarning``. If ``self.t_off`` is not set
    yet, it is obtained from ``_fit_t_off()``, which also supplies the
    value stored in ``self.noise_ampl``.

    Returns
    -------
    float
        The stored or freshly fitted ``self.t_off``.

    Raises
    ------
    ValueError
        Propagated from ``_fit_t_off()`` when no spike data is available.

    Notes
    -----
    Migration:

    >>> kinetics = neuron.get_kinetics()
    >>> t_off = kinetics['t_off_frames']
    """
    import warnings

    deprecation_text = (
        "get_t_off() is deprecated and will be removed in v0.6.0. "
        "Use get_kinetics() instead for better joint optimization of t_rise and t_off. "
        "Example: kinetics = neuron.get_kinetics(); t_off = kinetics['t_off_frames']"
    )
    # stacklevel=2 attributes the warning to the caller of get_t_off().
    warnings.warn(deprecation_text, DeprecationWarning, stacklevel=2)

    if self.t_off is None:
        fitted_t_off, fit_error = self._fit_t_off()
        self.t_off = fitted_t_off
        self.noise_ampl = fit_error
    return self.t_off
def get_noise_ampl(self):
    """Return the RMSE between observed and reconstructed calcium.

    The reconstruction is obtained from ``get_reconstructed()`` and the
    result is cached in ``self.noise_ampl``.

    Returns
    -------
    float
        ``sqrt(mean((Ca - Recon)**2))`` over the whole trace.

    Raises
    ------
    ValueError
        If no cached value exists and both ``self.asp`` and ``self.sp``
        are None, or if ``get_reconstructed()`` returns None.

    Notes
    -----
    Unlike the MAD, this measures what is left of the signal after the
    spike-driven reconstruction has been subtracted.
    """
    if self.noise_ampl is not None:
        return self.noise_ampl

    # Either spike representation is enough to attempt a reconstruction.
    if self.asp is None and self.sp is None:
        raise ValueError(
            "Spike reconstruction required for noise amplitude. Call reconstruct_spikes() first."
        )

    modelled = self.get_reconstructed()
    if modelled is None:
        raise ValueError("Reconstruction failed or returned None")

    diff = self.ca.data - modelled.data
    self.noise_ampl = float(np.sqrt(np.mean(diff**2)))
    return self.noise_ampl
def _fit_t_off(self):
    """Fit the calcium decay constant by minimizing ``Neuron.ca_mse_error``.

    The optimizer starts at ``self.default_t_off``, keeps the rise time fixed
    at ``self.default_t_rise`` and bounds ``t_off`` from below by
    ``1.1 * default_t_rise``. Amplitude spikes (``self.asp``) are used when
    they contain any non-zero value; otherwise binary spikes (``self.sp``).

    Returns
    -------
    tuple
        ``(t_off, error)``: ``t_off`` is the accepted decay constant and
        ``error`` is the objective value at the optimizer's solution
        (``res.fun``). The error always refers to the raw fitted value, even
        when ``t_off`` is subsequently replaced or capped.

    Raises
    ------
    ValueError
        If neither usable amplitude spikes nor binary spikes are available.

    Notes
    -----
    A fitted value not exceeding ``default_t_rise`` is replaced by
    ``default_t_off``; a value above ``5 * default_t_off`` is capped there.
    Both cases, and a non-converged optimization, are logged as warnings.
    """
    asp_has_events = self.asp is not None and np.sum(np.abs(self.asp.data)) > 0
    if asp_has_events:
        spike_data = self.asp.data
    elif self.sp is not None:
        spike_data = self.sp.data
    else:
        raise ValueError("Spike data required for t_off fitting")

    initial_guess = np.array([self.default_t_off])
    lower_bound = self.default_t_rise * 1.1
    res = minimize(
        Neuron.ca_mse_error,
        initial_guess,
        args=(self.ca.data, spike_data, self.default_t_rise),
        bounds=[(lower_bound, None)],
    )
    opt_t_off = res.x[0]
    noise_amplitude = res.fun

    logger = logging.getLogger(__name__)
    if opt_t_off <= self.default_t_rise:
        # A decay faster than the rise is not a usable kernel: fall back to default
        logger.warning(
            f"Optimization failed for neuron {self.cell_id}: "
            f"fitted t_off ({opt_t_off:.2f}) <= t_rise ({self.default_t_rise:.2f}). "
            f"Using default t_off={self.default_t_off}"
        )
        opt_t_off = self.default_t_off
    elif opt_t_off > self.default_t_off * 5:
        logger.warning(
            f"Calculated t_off={int(opt_t_off)} for neuron {self.cell_id} "
            f"is suspiciously high, check signal quality. "
            f"t_off has been automatically lowered to {self.default_t_off * 5}"
        )
        opt_t_off = self.default_t_off * 5

    if not res.success:
        logger.warning(
            f"Optimization did not converge for neuron {self.cell_id}: {res.message}. "
            f"Using fitted value {opt_t_off:.2f} with caution."
        )

    return (opt_t_off, noise_amplitude)
def get_shuffled_calcium(self, method="roll_based", return_array=True, seed=None, **kwargs):
    """Return a surrogate (shuffled) version of the calcium signal.

    Dispatches to the private ``_shuffle_calcium_data_<method>`` backend,
    forwarding ``seed`` and any extra keyword arguments.

    Parameters
    ----------
    method : {'roll_based', 'waveform_based', 'chunks_based'}, optional
        Shuffling strategy:

        - 'roll_based': circular shift of the trace
        - 'waveform_based': shuffle spikes, then rebuild calcium from them
        - 'chunks_based': split the trace into chunks and reorder them

        Default is 'roll_based'.
    return_array : bool, optional
        If True (default) return the backend's array; otherwise wrap it in a
        continuous ``TimeSeries`` named ``neuron_<cell_id>_shuffled_ca``.
    seed : int, optional
        Random seed forwarded to the backend.
    **kwargs
        Backend-specific options, e.g. ``n`` (number of chunks) for
        'chunks_based' or ``shift`` (explicit shift in frames) for
        'roll_based'.

    Returns
    -------
    ndarray or TimeSeries
        Shuffled calcium signal.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names. Backends may raise
        further errors of their own.
    """
    valid_methods = ["roll_based", "waveform_based", "chunks_based"]
    if method not in valid_methods:
        raise ValueError(f"Invalid method '{method}'. Must be one of {valid_methods}")

    backend = getattr(self, f"_shuffle_calcium_data_{method}")
    shuffled_data = backend(seed=seed, **kwargs)

    if return_array:
        return shuffled_data
    # TimeSeries is imported at module level; no function-scope import needed
    return TimeSeries(
        data=shuffled_data,
        discrete=False,
        name=f"neuron_{self.cell_id}_shuffled_ca",
    )
def _shuffle_calcium_data_waveform_based(self, seed=None, **kwargs):
    """Build surrogate calcium from ISI-shuffled spikes plus the residual.

    The calcium predicted from the current spikes is subtracted from the
    observed trace to obtain a background residual; the spike train is then
    shuffled with :meth:`_shuffle_spikes_data_isi_based`, convolved again
    with the same kernel and added back to that background.

    Parameters
    ----------
    seed : int, optional
        Random seed forwarded to the spike shuffling step.
    **kwargs
        Ignored; accepted for compatibility with the shuffle interface.

    Returns
    -------
    ndarray
        Shuffled calcium signal.

    Raises
    ------
    ValueError
        If neither usable amplitude spikes nor binary spikes are available.

    Warns
    -----
    UserWarning
        If kinetics are missing and :meth:`get_kinetics` raises
        ``ValueError`` or ``AttributeError``.
    """

    def amplitude_spikes_usable():
        # Evaluated on demand: get_kinetics() below may alter self.asp
        return self.asp is not None and np.sum(np.abs(self.asp.data)) > 0

    if amplitude_spikes_usable():
        spike_data = self.asp.data
    elif self.sp is not None:
        spike_data = self.sp.data
    else:
        raise ValueError("Spike data required for waveform-based shuffling")

    # Background extraction depends on the kernel, so try to obtain fitted kinetics
    kinetics_missing = self.t_rise is None or self.t_off is None
    if kinetics_missing:
        try:
            self.get_kinetics(fps=self.fps)
        except (ValueError, AttributeError):
            import warnings

            warnings.warn(
                f"Kinetics optimization failed for neuron {self.cell_id}. "
                "Using default kinetics for waveform-based shuffling. "
                "Results may be less accurate.",
                UserWarning,
            )

    # Fitted kinetics when present, defaults otherwise
    opt_t_rise = self.default_t_rise if self.t_rise is None else self.t_rise
    opt_t_off = self.default_t_off if self.t_off is None else self.t_off

    n_samples = len(self.ca.data)
    conv = Neuron.get_restored_calcium(spike_data, opt_t_rise, opt_t_off)
    background = self.ca.data - conv[:n_samples]

    pspk = self._shuffle_spikes_data_isi_based(seed=seed)
    shuffled_source = pspk
    if amplitude_spikes_usable():
        amp_indices = np.nonzero(self.asp.data)[0]
        amplitudes = self.asp.data[amp_indices]
        pspk_scaled = pspk.astype(float)
        pspk_indices = np.nonzero(pspk)[0]
        if len(pspk_indices) == len(amplitudes):
            # One-to-one transfer of the detected amplitudes
            pspk_scaled[pspk_indices] = amplitudes
        elif len(pspk_indices) > 0:
            # Counts differ: use the mean amplitude for every shuffled event
            pspk_scaled[pspk_indices] = np.mean(amplitudes)
        shuffled_source = pspk_scaled

    psconv = Neuron.get_restored_calcium(shuffled_source, opt_t_rise, opt_t_off)
    return psconv[: len(self.ca.data)] + background


def _shuffle_calcium_data_chunks_based(self, n=100, seed=None, **kwargs):
    """Shuffle calcium by splitting it into chunks and permuting their order.

    Parameters
    ----------
    n : int, optional
        Number of chunks passed to ``np.array_split``; validated with
        ``check_positive``. Default is 100.
    seed : int, optional
        If given, seeds NumPy's global random state before shuffling.
    **kwargs
        Ignored; accepted for compatibility with the shuffle interface.

    Returns
    -------
    ndarray
        Concatenation of the reordered chunks (same length as the input).

    Notes
    -----
    ``np.array_split`` yields chunks of unequal size when the signal length
    is not divisible by ``n``. Samples inside a chunk keep their order.
    """
    check_positive(n=n)
    if seed is not None:
        np.random.seed(seed)

    pieces = np.array_split(self.ca.data, n)
    order = np.arange(len(pieces))
    np.random.shuffle(order)
    reordered = [pieces[position] for position in order]
    return np.concatenate(reordered)


def _shuffle_calcium_data_roll_based(self, shift=None, seed=None, **kwargs):
    """Shuffle calcium by a circular shift (``np.roll``).

    Parameters
    ----------
    shift : int, optional
        Shift in frames. If None, drawn with ``np.random.randint`` from
        ``[int(3 * t_off), n_frames - int(3 * t_off))``, where ``t_off`` is
        ``self.t_off`` when set and ``self.default_t_off`` otherwise.
    seed : int, optional
        Seeds NumPy's global random state; only applied when ``shift`` is
        None.
    **kwargs
        Ignored; accepted for compatibility with the shuffle interface.

    Returns
    -------
    ndarray
        Circularly shifted calcium signal.

    Raises
    ------
    ValueError
        If the signal is too short for a random shift, if ``shift`` is not
        an integer, or if ``shift`` lies outside ``[0, n_frames - 1]``.
    """
    # Only the shift margin depends on t_off, so the default value is adequate here
    opt_t_off = self.default_t_off if self.t_off is None else self.t_off

    if shift is None:
        if seed is not None:
            np.random.seed(seed)
        margin = int(3 * opt_t_off)
        min_shift = margin
        max_shift = self.n_frames - margin
        if min_shift >= max_shift:
            raise ValueError(
                f"Signal too short for roll-based shuffling. Need at least "
                f"{2 * int(3 * opt_t_off)} frames, but have {self.n_frames}"
            )
        shift = np.random.randint(min_shift, max_shift)
    elif not isinstance(shift, (int, np.integer)):
        raise ValueError(f"shift must be integer, got {type(shift).__name__}")

    shift_in_range = 0 <= shift < self.n_frames
    if not shift_in_range:
        raise ValueError(f"shift must be in range [0, {self.n_frames - 1}], got {shift}")

    return np.roll(self.ca.data, shift)
def get_shuffled_spikes(self, method="isi_based", return_array=True, seed=None, **kwargs):
    """Return a surrogate (shuffled) spike train.

    Dispatches to the private ``_shuffle_spikes_data_<method>`` backend,
    forwarding ``seed`` and any extra keyword arguments.

    Parameters
    ----------
    method : {'isi_based'}, optional
        Shuffling strategy; only 'isi_based' is accepted. Default is
        'isi_based'.
    return_array : bool, optional
        If True (default) return the backend's array; otherwise wrap it in a
        discrete ``TimeSeries`` named ``neuron_<cell_id>_shuffled_sp``.
    seed : int, optional
        Random seed forwarded to the backend.
    **kwargs
        Additional arguments forwarded to the backend.

    Returns
    -------
    ndarray or TimeSeries
        Shuffled spike train.

    Raises
    ------
    AttributeError
        If ``self.sp`` is None. This is checked before ``method``.
    ValueError
        If ``method`` is not recognized.
    """
    if self.sp is None:
        raise AttributeError("Unable to shuffle spikes without spikes data")

    valid_methods = ["isi_based"]
    if method not in valid_methods:
        raise ValueError(f"Invalid method '{method}'. Must be one of {valid_methods}")

    backend = getattr(self, f"_shuffle_spikes_data_{method}")
    shuffled_data = backend(seed=seed, **kwargs)

    if not return_array:
        from ..information.info_base import TimeSeries

        # Spike trains are wrapped as discrete series
        return TimeSeries(
            data=shuffled_data,
            discrete=True,
            name=f"neuron_{self.cell_id}_shuffled_sp",
        )
    return shuffled_data
@property
def reconstructed(self):
    """Reconstructed calcium signal (read-only convenience property).

    Equivalent to calling :meth:`get_reconstructed` with no arguments, so
    the cached default reconstruction is returned when one exists.

    Returns
    -------
    TimeSeries or None
        Whatever :meth:`get_reconstructed` returns; None when no spike data
        is available.

    See Also
    --------
    get_reconstructed : Accepts ``force_reconstruction`` and custom parameters.
    """
    return self.get_reconstructed()
def get_reconstructed(self, force_reconstruction=False, **kwargs):
    """Return the calcium signal reconstructed from spikes.

    The default reconstruction (no keyword overrides) is computed lazily and
    cached in ``self._reconstructed``. Passing ``force_reconstruction=True``
    recomputes it and refreshes the cache, so later accesses (including the
    ``reconstructed`` property) see the updated signal. Reconstructions that
    use custom parameters are returned but never cached.

    Parameters
    ----------
    force_reconstruction : bool, optional
        If True, ignore any cached result and recompute. Default is False.
    **kwargs : dict, optional
        Custom reconstruction parameters. Passing any keyword argument marks
        the call as custom (uncached). Recognized keys:

        - t_rise_frames : float
            Rise time in frames. Defaults to ``self.t_rise``, or
            ``self.default_t_rise`` when that is None.
        - t_off_frames : float
            Decay time in frames. Defaults to ``self.t_off``, or
            ``self.default_t_off`` when that is None.
        - spike_data : array-like
            Spike data to reconstruct from. Defaults to ``self.asp.data``,
            falling back to ``self.sp.data``.

    Returns
    -------
    TimeSeries or None
        Reconstructed calcium named ``neuron_<cell_id>_reconstructed``, or
        None if no ``spike_data`` was passed and both ``self.asp`` and
        ``self.sp`` are None.

    Warns
    -----
    UserWarning
        When binary spikes (``self.sp``) are used because amplitude spikes
        (``self.asp``) are not available.
    """
    has_custom_params = bool(kwargs)

    # Serve from cache only for a plain, non-forced default request
    if not has_custom_params and not force_reconstruction:
        if self._reconstructed is not None:
            return self._reconstructed

    spike_data = kwargs.get("spike_data", None)
    if spike_data is None:
        # Prefer amplitude spikes; binary spikes remain as a backward-compatible fallback
        if self.asp is not None:
            spike_data = self.asp.data
        elif self.sp is not None:
            import warnings

            warnings.warn(
                "Using binary spikes (sp) instead of amplitude spikes (asp). "
                "Reconstruction will use uniform amplitudes. "
                "Call reconstruct_spikes() to get amplitude-based reconstruction.",
                UserWarning,
            )
            spike_data = self.sp.data
        else:
            return None

    t_rise_frames = kwargs.get("t_rise_frames", None)
    if t_rise_frames is None:
        t_rise_frames = self.t_rise if self.t_rise is not None else self.default_t_rise

    t_off_frames = kwargs.get("t_off_frames", None)
    if t_off_frames is None:
        t_off_frames = self.t_off if self.t_off is not None else self.default_t_off

    reconstructed_data = Neuron.get_restored_calcium(spike_data, t_rise_frames, t_off_frames)
    result = TimeSeries(
        reconstructed_data,
        discrete=False,
        name=f"neuron_{self.cell_id}_reconstructed",
    )

    if not has_custom_params:
        # Default reconstruction (first computation or forced refresh): update the cache
        self._reconstructed = result
    return result
def _shuffle_spikes_data_isi_based(self, seed=None, **kwargs):
    """Shuffle spikes while preserving the inter-spike-interval multiset.

    The inter-spike intervals of ``self.sp.data`` are permuted, a random
    start frame is drawn, and the event values (also permuted) are placed at
    the cumulative positions of the shuffled intervals.

    Parameters
    ----------
    seed : int, optional
        If given, seeds NumPy's global random state before shuffling.
    **kwargs
        Ignored; accepted so the method matches the other shuffle backends,
        since ``get_shuffled_spikes`` forwards its extra keyword arguments.

    Returns
    -------
    ndarray
        Float array of length ``self.n_frames`` holding the shuffled events.
        If the spike train has no events, a copy of ``self.sp.data`` is
        returned unchanged.

    Notes
    -----
    - The ISI distribution is preserved, but not the ISI order.
    - The start frame is drawn from ``[0, max(1, n_frames - span - 1))``,
      where ``span`` is the distance between the first and last event.
    - Events that would fall at or beyond ``n_frames`` are dropped.
    """
    if seed is not None:
        np.random.seed(seed)

    spikes = self.sp.data
    event_inds = np.where(spikes != 0)[0]
    if len(event_inds) == 0:
        # Return a copy so callers cannot modify the stored spike train through the result
        return spikes.copy()

    nfr = self.n_frames
    pseudo_spikes = np.zeros(nfr)
    event_vals = spikes[event_inds].copy()

    # np.where yields ascending indices, so the span is last minus first
    event_range = event_inds[-1] - event_inds[0]
    max_start = max(1, nfr - event_range - 1)
    first_random_pos = np.random.choice(max_start)

    interspike_intervals = np.diff(event_inds)
    order = np.arange(len(interspike_intervals))
    np.random.shuffle(order)
    disordered_interspike_intervals = interspike_intervals[order]

    pseudo_event_inds = np.cumsum(
        np.insert(disordered_interspike_intervals, 0, first_random_pos)
    )
    # Drop any event pushed past the end of the recording
    pseudo_event_inds = pseudo_event_inds[pseudo_event_inds < nfr]

    event_vals = event_vals[: len(pseudo_event_inds)]
    np.random.shuffle(event_vals)
    pseudo_spikes[pseudo_event_inds] = event_vals
    return pseudo_spikes
def optimize_kinetics(
    self,
    method="direct",
    fps=20,
    update_reconstruction=True,
    max_event_dur_multiplier=4,
    detection_method="auto",
    **kwargs,
):
    """Optimize kinetics and optionally re-run event detection.

    Runs :meth:`_optimize_kinetics_direct`. When the result reports
    ``optimized=True``, ``self.t_rise`` and ``self.t_off`` are set to the
    measured values converted from seconds to frames, and, if requested,
    :meth:`reconstruct_spikes` is invoked again so the detected events match
    the new kinetics.

    Parameters
    ----------
    method : str, optional
        Optimization method; only 'direct' is supported. Default: 'direct'.
    fps : float, optional
        Sampling rate in frames per second. Default: 20.
    update_reconstruction : bool, optional
        If True (default), re-run spike detection after a successful
        optimization.
    max_event_dur_multiplier : float, optional
        Used for wavelet re-detection:
        ``max_event_dur = t_rise + multiplier * t_off`` (seconds).
        Default: 4.
    detection_method : {'auto', 'wavelet', 'threshold'}, optional
        Detector used for re-detection. 'auto' selects threshold detection
        when ``self.threshold_events`` exists and is non-empty, otherwise
        wavelet detection. Default: 'auto'.
    **kwargs : dict
        Optional overrides read with the following defaults:

        - measurement: ``asp`` (None), ``wvt_ridges`` (None),
          ``max_frames_forward`` (``int(MAX_FRAMES_FORWARD_SEC * fps)``),
          ``max_frames_back`` (``int(MAX_FRAMES_BACK_SEC * fps)``),
          ``min_events`` (5), ``aggregation`` ('median'), ``min_r2`` (0.8)
        - re-detection, both detectors: ``iterative`` (False), ``n_iter``
          (3), ``adaptive_thresholds`` (True)
        - threshold detector only: ``n_mad`` (4.0),
          ``min_duration_frames`` (3), ``merge_gap_frames`` (2)
        - wavelet detector only: ``event_mask_expansion_sec`` (5.0)

    Returns
    -------
    dict
        The result dictionary of :meth:`_optimize_kinetics_direct`, whose
        't_rise' and 't_off' entries are in seconds.

    Raises
    ------
    ValueError
        If ``method`` is not 'direct'. An unrecognized ``detection_method``
        also raises, but only when re-detection is actually reached
        (optimization succeeded and ``update_reconstruction`` is True).
    """
    if method != "direct":
        raise ValueError(f"Only 'direct' method is supported, got '{method}'")

    option = kwargs.get

    # Search windows are specified in seconds and scaled to the frame rate
    forward_default = int(MAX_FRAMES_FORWARD_SEC * fps)
    backward_default = int(MAX_FRAMES_BACK_SEC * fps)

    result = self._optimize_kinetics_direct(
        fps=fps,
        asp=option("asp", None),
        wvt_ridges=option("wvt_ridges", None),
        max_frames_forward=option("max_frames_forward", forward_default),
        max_frames_back=option("max_frames_back", backward_default),
        min_events=option("min_events", 5),
        aggregation=option("aggregation", "median"),
        min_r2=option("min_r2", 0.8),
    )

    # Leave the instance untouched unless both parameters were measured
    if not result.get("optimized", False):
        return result

    # Results are in seconds; instance kinetics are stored in frames
    self.t_rise = result["t_rise"] * fps
    self.t_off = result["t_off"] * fps

    if not update_reconstruction:
        return result

    if detection_method not in ("auto", "threshold", "wavelet"):
        raise ValueError(
            f"detection_method must be 'auto', 'wavelet', or 'threshold', got '{detection_method}'"
        )
    if detection_method == "auto":
        # Reuse threshold detection only if it has already produced events
        prior_events = getattr(self, "threshold_events", None)
        use_threshold = prior_events is not None and len(prior_events) > 0
    else:
        use_threshold = detection_method == "threshold"

    shared_options = {
        "iterative": option("iterative", False),
        "n_iter": option("n_iter", 3),
        "adaptive_thresholds": option("adaptive_thresholds", True),
        "create_event_regions": True,
    }

    if use_threshold:
        self.reconstruct_spikes(
            method="threshold",
            n_mad=option("n_mad", 4.0),
            min_duration_frames=option("min_duration_frames", 3),
            merge_gap_frames=option("merge_gap_frames", 2),
            **shared_options,
        )
    else:
        # Event duration window (seconds) derived from the measured kinetics
        min_event_dur = 0.5
        max_event_dur = result["t_rise"] + max_event_dur_multiplier * result["t_off"]
        self.reconstruct_spikes(
            method="wavelet",
            fps=fps,
            min_event_dur=min_event_dur,
            max_event_dur=max_event_dur,
            event_mask_expansion_sec=option("event_mask_expansion_sec", 5.0),
            **shared_options,
        )

    return result
def _measure_t_off_from_peak(self, signal, peak_idx, fps, max_frames=100, min_r2=0.8):
    """Estimate the decay constant by a log-linear fit forward from a peak.

    The window ``signal[peak_idx : peak_idx + max_frames]`` is normalized by
    its maximum, samples at or below 1% of that maximum are discarded, and a
    straight line is fitted to the logarithm of the remaining samples
    against time in seconds. The decay constant is ``-1 / slope``.

    Parameters
    ----------
    signal : ndarray
        Calcium signal.
    peak_idx : int
        Index of the peak in ``signal``.
    fps : float
        Frames per second.
    max_frames : int, optional
        Maximum number of frames to look forward. Default: 100.
    min_r2 : float, optional
        Minimum R² of the log-linear fit. Default: 0.8.

    Returns
    -------
    float or None
        Decay constant in seconds, or None when the window is too short, the
        window maximum is not positive, too few samples remain, the fit
        quality is below ``min_r2``, the slope is not negative, or the
        estimate falls outside the open interval (0.1, 30) seconds.
    """
    decay_start = peak_idx
    decay_end = min(len(signal), peak_idx + max_frames)

    # Minimum lengths are defined in seconds and scaled to the frame rate
    min_decay_frames = int(MIN_DECAY_FRAMES_SEC * fps)
    min_valid_points = int(MIN_VALID_POINTS_SEC * fps)

    if decay_end - decay_start < min_decay_frames:
        return None

    decay_signal = signal[decay_start:decay_end]
    window_max = np.max(decay_signal)
    if window_max <= 0:
        return None

    decay_signal = decay_signal / window_max
    valid = decay_signal > 0.01
    if np.sum(valid) < min_valid_points:
        return None

    log_y = np.log(decay_signal[valid])
    t = np.arange(len(decay_signal))[valid] / fps

    slope, intercept = np.polyfit(t, log_y, 1)

    # Goodness of fit in log space; a constant log_y (ss_tot == 0) counts as R² = 0
    y_pred = slope * t + intercept
    ss_res = np.sum((log_y - y_pred) ** 2)
    ss_tot = np.sum((log_y - np.mean(log_y)) ** 2)
    r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
    if r2 < min_r2:
        return None

    if slope >= 0:
        return None
    tau = -1 / slope
    # Accept only estimates inside the plausible range (seconds)
    if 0.1 < tau < 30:
        return tau
    return None


def _measure_t_rise_derivative(self, signal, peak_idx, fps, max_frames_back=30):
    """Estimate the rise constant as amplitude divided by maximum slope.

    The window ``signal[peak_idx - max_frames_back : peak_idx + 1]`` is
    smoothed with a Savitzky-Golay filter (polynomial order 2) when an odd
    window of at least 5 samples fits; otherwise the raw window is used.
    The estimate is ``(peak - baseline) / max(d signal / dt)``, where the
    baseline is the median of the first samples of the smoothed window and
    the peak is its last sample.

    Parameters
    ----------
    signal : ndarray
        Calcium signal.
    peak_idx : int
        Index of the peak in ``signal``.
    fps : float
        Frames per second.
    max_frames_back : int, optional
        Maximum number of frames to look backward. Default: 30.

    Returns
    -------
    float or None
        Rise constant in seconds, or None when the window is too short, its
        maximum is not positive, the maximum derivative or the amplitude is
        too small, or the estimate falls outside the open interval (0.01, 1)
        seconds.
    """
    rise_end = peak_idx + 1
    rise_start = max(0, peak_idx - max_frames_back)

    # Minimum lengths and the filter window are defined in seconds
    min_rise_frames = int(MIN_RISE_FRAMES_SEC * fps)
    min_valid_points = int(MIN_VALID_POINTS_SEC * fps)
    savgol_window = max(5, int(SAVGOL_WINDOW_SEC * fps))
    # Savitzky-Golay requires an odd window length
    if savgol_window % 2 == 0:
        savgol_window += 1

    if rise_end - rise_start < min_rise_frames:
        return None

    rise_signal = signal[rise_start:rise_end]
    if np.max(rise_signal) <= 0:
        return None

    if len(rise_signal) < min_valid_points:
        return None

    try:
        # The window must be odd and no longer than the data
        actual_window = min(savgol_window, len(rise_signal))
        if actual_window % 2 == 0:
            actual_window -= 1
        if actual_window >= 5:
            smoothed = savgol_filter(rise_signal, actual_window, 2)
        else:
            smoothed = rise_signal
    except (ValueError, TypeError):
        # Smoothing is an optional refinement; fall back to the raw window
        smoothed = rise_signal

    derivative = np.gradient(smoothed) * fps
    max_deriv = derivative[np.argmax(derivative)]

    peak_val = smoothed[-1]
    baseline = np.percentile(smoothed[: min(min_valid_points, len(smoothed))], 50)
    amplitude = peak_val - baseline

    # Guard against dividing by a vanishing derivative
    if max_deriv > 1e-6 and amplitude > 0:
        tau = amplitude / max_deriv
        if 0.01 < tau < 1:
            return tau
    return None


def _optimize_kinetics_direct(
    self,
    fps=20,
    asp=None,
    wvt_ridges=None,
    max_frames_forward=100,
    max_frames_back=30,
    min_events=5,
    aggregation="median",
    min_r2=0.8,
):
    """Measure t_rise and t_off directly from detected events and aggregate.

    For every event with valid boundaries the peak is located as the maximum
    of the calcium signal between ``event.start`` and ``event.end``
    (inclusive). The decay constant is measured with
    :meth:`_measure_t_off_from_peak` and the rise constant with
    :meth:`_measure_t_rise_derivative`. Successful per-event measurements
    are combined with the median or the mean. No iterative optimization is
    involved.

    Parameters
    ----------
    fps : float, optional
        Sampling rate in frames per second. Default: 20.
    asp : array-like or TimeSeries, optional
        Accepted for interface compatibility; this method does not read it.
    wvt_ridges : list, optional
        Events with ``.start`` and ``.end`` attributes (frame indices). If
        None, ``self.threshold_events`` is used when present and non-empty,
        otherwise ``self.wvt_ridges``.
    max_frames_forward : int, optional
        Maximum frames to look forward for the t_off measurement.
        Default: 100.
    max_frames_back : int, optional
        Maximum frames to look backward for the t_rise measurement.
        Default: 30.
    min_events : int, optional
        Minimum number of successful measurements required per parameter.
        Default: 5.
    aggregation : {'median', 'mean'}, optional
        How per-event measurements are combined. Default: 'median'.
    min_r2 : float, optional
        Minimum R² of the exponential decay fit. Default: 0.8.

    Returns
    -------
    dict
        Always returned, also on failure, with keys:

        - 'optimized': bool, True only if BOTH parameters were measured
        - 'partially_optimized': bool, True if exactly one was measured
        - 'used_defaults': dict, ``{'t_rise': bool, 't_off': bool}``
        - 't_rise', 't_off': float, seconds; a parameter with fewer than
          ``min_events`` measurements falls back to
          ``self.default_t_rise / fps`` or ``self.default_t_off / fps``
        - 't_rise_std', 't_off_std': spread of the measurements (0 when
          there is at most one measurement, and in every failure result)
        - 'n_events_used_rise', 'n_events_used_off': int
        - 'n_events_detected': int
        - 'method': str, always 'direct'
        - 'error': str, present only in failure results (no events, no
          valid event boundaries, or too few measurements for both
          parameters)

    Raises
    ------
    ValueError
        If ``aggregation`` is not 'median' or 'mean'. ``check_positive`` is
        applied to ``fps``, ``max_frames_forward``, ``max_frames_back`` and
        ``min_events`` and is presumably the source of any error for
        non-positive values.

    Notes
    -----
    Missing events are reported through the 'error' key of the returned
    dictionary, not by raising. The public entry point is
    ``optimize_kinetics(method='direct', ...)``.
    """
    check_positive(
        fps=fps,
        max_frames_forward=max_frames_forward,
        max_frames_back=max_frames_back,
        min_events=min_events,
    )
    if aggregation not in ("median", "mean"):
        raise ValueError(f"aggregation must be 'median' or 'mean', got '{aggregation}'")

    # Instance defaults are stored in frames; results are reported in seconds
    default_t_rise = self.default_t_rise / fps
    default_t_off = self.default_t_off / fps
    calcium_signal = np.asarray(self.ca.data)

    def failure(error, n_detected=0, n_rise=0, n_off=0):
        # Common shape of every unsuccessful result: defaults plus an error text
        return {
            "optimized": False,
            "partially_optimized": False,
            "used_defaults": {"t_rise": True, "t_off": True},
            "t_rise": default_t_rise,
            "t_off": default_t_off,
            "t_rise_std": 0,
            "t_off_std": 0,
            "n_events_used_rise": n_rise,
            "n_events_used_off": n_off,
            "n_events_detected": n_detected,
            "method": "direct",
            "error": error,
        }

    # Event source priority: explicit argument, threshold events, wavelet ridges
    if wvt_ridges is not None:
        ridges = wvt_ridges
    elif (
        hasattr(self, "threshold_events")
        and self.threshold_events is not None
        and len(self.threshold_events) > 0
    ):
        ridges = self.threshold_events
    elif self.wvt_ridges is not None and len(self.wvt_ridges) > 0:
        ridges = self.wvt_ridges
    else:
        return failure(
            "No events available. Call reconstruct_spikes() or detect_events_threshold() first."
        )

    n_events_detected = len(ridges)
    if n_events_detected == 0:
        return failure("No events detected")

    # Locate one peak per event, skipping events with unusable boundaries
    n_samples = len(calcium_signal)
    event_positions = []
    for ridge in ridges:
        start_idx = int(ridge.start)
        end_idx = int(ridge.end)
        if end_idx <= start_idx or start_idx < 0 or end_idx >= n_samples:
            continue
        event_segment = calcium_signal[start_idx : end_idx + 1]
        event_positions.append(start_idx + np.argmax(event_segment))
    event_positions = np.array(event_positions)

    if len(event_positions) == 0:
        return failure("No valid event boundaries found", n_detected=n_events_detected)

    t_rise_measurements = []
    t_off_measurements = []
    for peak_idx in event_positions:
        t_off = self._measure_t_off_from_peak(
            calcium_signal, peak_idx, fps, max_frames=max_frames_forward, min_r2=min_r2
        )
        if t_off is not None:
            t_off_measurements.append(t_off)

        t_rise = self._measure_t_rise_derivative(
            calcium_signal, peak_idx, fps, max_frames_back=max_frames_back
        )
        if t_rise is not None:
            t_rise_measurements.append(t_rise)

    n_rise = len(t_rise_measurements)
    n_off = len(t_off_measurements)

    # Which parameters lack enough measurements and therefore use defaults
    used_defaults = {
        "t_rise": n_rise < min_events,
        "t_off": n_off < min_events,
    }
    if all(used_defaults.values()):
        return failure(
            "Insufficient successful measurements",
            n_detected=n_events_detected,
            n_rise=n_rise,
            n_off=n_off,
        )

    aggregate = np.median if aggregation == "median" else np.mean
    t_rise_final = default_t_rise if used_defaults["t_rise"] else aggregate(t_rise_measurements)
    t_off_final = default_t_off if used_defaults["t_off"] else aggregate(t_off_measurements)

    t_rise_std = np.std(t_rise_measurements) if n_rise > 1 else 0
    t_off_std = np.std(t_off_measurements) if n_off > 1 else 0

    # Reaching this point means at least one parameter was measured
    fully_optimized = not any(used_defaults.values())
    partially_optimized = any(used_defaults.values()) and not all(used_defaults.values())

    return {
        "optimized": fully_optimized,
        "partially_optimized": partially_optimized,
        "used_defaults": used_defaults,
        "t_rise": t_rise_final,
        "t_off": t_off_final,
        "t_rise_std": t_rise_std,
        "t_off_std": t_off_std,
        "n_events_used_rise": n_rise,
        "n_events_used_off": n_off,
        "n_events_detected": n_events_detected,
        "method": "direct",
    }
def detect_events_threshold(
    self,
    threshold=None,
    n_mad=4.0,
    min_duration_frames=3,
    merge_gap_frames=2,
    use_scaled=True,
    signal=None,
):
    """Detect calcium events using threshold crossings (fast alternative to wavelet).

    Event boundaries are the points where the signal rises above and then
    falls back to (or below) a threshold. The result is a list of
    SimpleEvent objects that is also stored in ``self.threshold_events``.

    Parameters
    ----------
    threshold : float, optional
        Absolute threshold value. If None, computed as
        ``median(signal) + n_mad * MAD(signal)``, where MAD is the median
        absolute deviation scaled with ``scale="normal"``.
    n_mad : float, optional
        Number of MAD units above the median for the automatic threshold.
        Only used for the threshold when ``threshold`` is None, but it is
        always validated. Default: 4.0.
    min_duration_frames : int, optional
        Minimum number of consecutive above-threshold frames. Shorter runs
        are discarded *before* merging. Default: 3.
    merge_gap_frames : int, optional
        Neighbouring events are merged when the gap between them
        (``start[i] - end[i - 1]``) is at most this many frames.
        Default: 2.
    use_scaled : bool, optional
        If True, use ``self.ca.scdata``; otherwise use ``self.ca.data``.
        Ignored when ``signal`` is given. Default: True.
    signal : ndarray, optional
        Custom signal to run detection on (e.g. a residual during iterative
        detection). Takes precedence over ``use_scaled``.

    Returns
    -------
    list of SimpleEvent
        One event per detected region. ``start`` is the first frame above
        threshold. ``end`` is exclusive: it is the first frame after the
        region that is not above threshold, or ``len(signal)`` when the
        signal is still above threshold at the last frame.
        An empty list is returned when nothing is detected or when the
        signal is empty.

    Raises
    ------
    AttributeError
        If ``use_scaled`` is True, no ``signal`` is given and ``self.ca``
        has no ``scdata`` attribute.
    ValueError
        Presumably raised by ``check_positive`` when ``min_duration_frames``,
        ``merge_gap_frames`` or ``n_mad`` is not positive (confirm against
        ``utils.data.check_positive``).

    Examples
    --------
    >>> # Automatic threshold
    >>> neuron = Neuron(cell_id=0, ca=calcium_data, sp=None)
    >>> events = neuron.detect_events_threshold(n_mad=4.0)

    >>> # The events are already stored in neuron.threshold_events, so they
    >>> # can be used directly for kinetics estimation
    >>> result = neuron.optimize_kinetics(method='direct', fps=20)

    >>> # Manual threshold
    >>> events = neuron.detect_events_threshold(threshold=0.3, use_scaled=True)

    >>> # More sensitive detection (lower threshold)
    >>> events = neuron.detect_events_threshold(n_mad=3.0)

    Notes
    -----
    Algorithm:

    1. Compute threshold (auto or manual)
    2. Upward crossings mark event starts
    3. Downward crossings mark event ends
    4. Discard runs shorter than ``min_duration_frames``
    5. Merge surviving events separated by small gaps

    The whole procedure is a handful of vectorized passes over the signal.
    The ~100-500x speedup over wavelet detection quoted for this method is
    an empirical figure, not a guarantee. Thresholding suits high-SNR data;
    wavelet detection is slower but more sensitive for low-SNR or
    overlapping events.

    See Also
    --------
    optimize_kinetics : Use detected events for kinetics estimation
    reconstruct_spikes : Wavelet-based spike detection (slower but more sensitive)
    """
    check_positive(
        min_duration_frames=min_duration_frames, merge_gap_frames=merge_gap_frames, n_mad=n_mad
    )

    # Pick the signal: an explicit one wins, then scaled data, then raw data.
    if signal is not None:
        signal = np.asarray(signal)
    elif use_scaled:
        if not hasattr(self.ca, "scdata"):
            raise AttributeError(
                "Scaled calcium data not available. "
                "Set use_scaled=False to use raw data, or ensure calcium TimeSeries has scdata."
            )
        signal = np.asarray(self.ca.scdata)
    else:
        signal = np.asarray(self.ca.data)

    # An empty signal cannot contain events; bail out before the edge
    # checks below index into it (above_threshold[0] would raise IndexError).
    if signal.size == 0:
        self.threshold_events = []
        return []

    # Robust automatic threshold: median + n_mad * MAD.
    if threshold is None:
        signal_median = np.median(signal)
        signal_mad = median_abs_deviation(signal, scale="normal")
        threshold = signal_median + n_mad * signal_mad

    above_threshold = signal > threshold

    # diff == 1 marks a 0->1 transition (start), diff == -1 a 1->0 one (end).
    # +1 converts the diff index into the index of the frame after the jump.
    diff = np.diff(above_threshold.astype(int))
    starts = np.where(diff == 1)[0] + 1
    ends = np.where(diff == -1)[0] + 1

    # Runs touching the recording boundaries have no crossing of their own.
    if above_threshold[0]:
        starts = np.concatenate([[0], starts])
    if above_threshold[-1]:
        ends = np.concatenate([ends, [len(signal)]])

    # Defensive: keep starts and ends paired one-to-one.
    n_events = min(len(starts), len(ends))
    starts = starts[:n_events]
    ends = ends[:n_events]

    # Drop runs that are too short (ends are exclusive, so this is a frame count).
    long_enough = (ends - starts) >= min_duration_frames
    starts = starts[long_enough]
    ends = ends[long_enough]

    # Merge neighbours: a boundary between event i-1 and event i survives
    # only when the gap exceeds merge_gap_frames. The first start and the
    # last end are always kept.
    if len(starts) > 1:
        keep_boundary = (starts[1:] - ends[:-1]) > merge_gap_frames
        starts = np.concatenate([starts[:1], starts[1:][keep_boundary]])
        ends = np.concatenate([ends[:-1][keep_boundary], ends[-1:]])

    events = [SimpleEvent(start=s, end=e) for s, e in zip(starts, ends)]

    # Cache for later use (e.g. kinetics optimization).
    self.threshold_events = events

    return events
def _calculate_event_r2(
    calcium_signal, reconstruction, n_mad=4, event_mask=None, wvt_ridges=None, fps=None
):
    """Calculate R² restricted to event regions.

    Parameters
    ----------
    calcium_signal : array-like
        Original calcium signal.
    reconstruction : array-like
        Reconstructed calcium signal, indexable with the same mask as
        ``calcium_signal``.
    n_mad : float, optional
        Currently unused by this function; kept in the signature for
        backward compatibility. Default: 4.
    event_mask : array-like, optional
        Mask of event regions; entries greater than 0 are treated as
        "inside an event". If provided it is used directly.
    wvt_ridges : list, optional
        Objects with ``start``/``end`` attributes. Used to build the event
        mask (via ``events_to_ts_array``) when ``event_mask`` is None.
        Ridges whose start or end is negative are skipped as a whole.
    fps : float, optional
        Frame rate passed to ``events_to_ts_array``. Required when the mask
        is built from ``wvt_ridges``.

    Returns
    -------
    float
        ``1 - SS_residual / SS_total`` computed over event samples only.
        NaN when no ridge has valid boundaries, when the mask selects no
        samples, or when the selected calcium samples have zero variance.

    Raises
    ------
    ValueError
        If ``event_mask`` is None and ``wvt_ridges`` is None or empty, or
        if ``wvt_ridges`` is used without ``fps``.
    """
    if event_mask is None:
        if wvt_ridges is None or len(wvt_ridges) == 0:
            raise ValueError(
                "Either event_mask or wvt_ridges must be provided for event R² calculation. "
                "Run reconstruct_spikes() with create_event_regions=True to populate neuron.events."
            )
        if fps is None:
            raise ValueError("fps required to construct event mask from wvt_ridges")

        # Filter start/end together so the two index lists stay aligned:
        # a ridge with only one valid boundary must not shift the pairing
        # of every ridge that follows it.
        bounds = [
            (int(ridge.start), int(ridge.end))
            for ridge in wvt_ridges
            if ridge.start >= 0 and ridge.end >= 0
        ]
        if not bounds:
            return np.nan

        from .wavelet_event_detection import events_to_ts_array

        st_inds = [start for start, _ in bounds]
        end_inds = [end for _, end in bounds]
        events_array = events_to_ts_array(len(calcium_signal), [st_inds], [end_inds], fps)
        event_mask = events_array[0]

    # Coerce so that list inputs support the comparison and boolean indexing.
    calcium_signal = np.asarray(calcium_signal)
    reconstruction = np.asarray(reconstruction)
    event_mask = np.asarray(event_mask) > 0

    if not event_mask.any():
        return np.nan

    ca_events = calcium_signal[event_mask]
    recon_events = reconstruction[event_mask]

    ss_residual = np.sum((ca_events - recon_events) ** 2)
    ss_total = np.sum((ca_events - np.mean(ca_events)) ** 2)

    # Zero variance inside the events makes R² undefined.
    if ss_total == 0:
        return np.nan

    return 1 - ss_residual / ss_total


_calculate_event_r2 = staticmethod(_calculate_event_r2)