# Source code for driada.experiment.exp_build

import copy
import os
import os.path
import warnings
import numpy as np
import pickle

from .exp_base import Experiment
from ..information.info_base import TimeSeries, MultiTimeSeries, aggregate_multiple_ts
from ..information.time_series_types import analyze_time_series_type
from ..utils.naming import construct_session_name
from ..utils.output import show_output
from .neuron import DEFAULT_FPS, DEFAULT_T_OFF, DEFAULT_T_RISE
from ..gdrive.download import download_gdrive_data, initialize_iabs_router

# Keys under which the neural activity matrix may be stored (lower-case;
# neural keys in the input dict are compared after lower-casing).
NEURAL_DATA_ALIASES = {"calcium", "activations", "neural_data", "activity", "rates"}

# Every neural-data key that must never become a behavioural feature:
# the activity aliases plus spike / reconstruction arrays.
RESERVED_NEURAL_KEYS = {*NEURAL_DATA_ALIASES, "spikes", "sp", "asp", "reconstructions"}

# Metadata entries that are also excluded from behavioural features
# (these are compared with their exact case).
RESERVED_METADATA_KEYS = {"_metadata", "_sync_info"}


def _format_feature_subtype(type_info):
    """Format subtype information for verbose logging.

    Parameters
    ----------
    type_info : TimeSeriesType or None
        Type information from a TimeSeries object.

    Returns
    -------
    str
        Formatted subtype string, e.g., "linear", "circular (360°)", "binary".
        Returns empty string if subtype is None.
    """
    if type_info is None or type_info.subtype is None:
        return ""

    subtype_str = type_info.subtype

    if type_info.is_circular and type_info.circular_period is not None:
        period = type_info.circular_period
        if abs(period - 360) < 1:
            subtype_str = "circular (360)"
        elif abs(period - 2 * 3.14159265) < 0.1:
            subtype_str = "circular (2pi)"
        else:
            subtype_str = f"circular ({period:.1f})"

    return subtype_str


def _select_neural_key(key_mapping):
    """Return the lower-cased neural-data alias present in ``key_mapping``, or None.

    Aliases are tried in a fixed priority order (calcium, activations,
    neural_data, activity, rates), then any other alias in sorted order.
    Iterating the NEURAL_DATA_ALIASES set directly would make the choice
    depend on string-hash randomisation and differ between interpreter runs
    when several aliases are present.
    """
    preferred = ("calcium", "activations", "neural_data", "activity", "rates")
    ordered = [a for a in preferred if a in NEURAL_DATA_ALIASES]
    ordered += sorted(NEURAL_DATA_ALIASES.difference(preferred))
    for alias in ordered:
        if alias in key_mapping:
            return alias
    return None


def _is_garbage(vals):
    """Return True if ``vals`` is empty, all-NaN, or has <= 1 distinct non-NaN value.

    Used to filter out uninformative features from dynamic data.
    """
    arr = np.asarray(vals)
    if arr.size == 0:
        return True
    nan_mask = np.isnan(arr)
    return np.all(nan_mask) or (len(np.unique(arr[~nan_mask])) <= 1)


def _build_bad_frames_mask(bad_frames, n_frames):
    """Return a boolean mask of length ``n_frames``, True at indices listed in ``bad_frames``.

    Indices outside ``range(n_frames)`` are ignored. Membership is tested
    against a set, so the cost is O(n_frames + len(bad_frames)) instead of
    scanning the list once per frame.
    """
    bad_set = set(bad_frames)
    return np.fromiter((i in bad_set for i in range(n_frames)), dtype=bool, count=n_frames)


def load_exp_from_aligned_data(
    data_source,
    exp_params,
    data,
    force_continuous=None,
    feature_types=None,
    bad_frames=None,
    static_features=None,
    verbose=True,
    reconstruct_spikes=None,
    aggregate_features=None,
    n_jobs=-1,
    enable_parallelization=True,
    create_circular_2d=True,
):
    """Create an Experiment object from aligned neural and behavioral data.

    Constructs an Experiment instance from pre-aligned calcium imaging data
    and behavioral variables, automatically determining feature types and
    filtering out constant or invalid features.

    Parameters
    ----------
    data_source : str
        Identifier for the data source (e.g., 'IABS', 'custom'). Used with
        exp_params to construct the experiment name.
    exp_params : dict
        Experiment parameters dictionary. For IABS data source, requires:

        - 'track': experimental paradigm (e.g., 'linear_track')
        - 'animal_id': subject identifier
        - 'session': session identifier

        For other sources, can contain any metadata for experiment naming.
    data : dict
        Dictionary containing aligned data. Neural keys are matched
        case-insensitively:

        - neural data (required): 2D array (neurons x time) stored under one
          of 'calcium', 'activations', 'neural_data', 'activity', 'rates'.
          If several are present, the first one in that order is used.
        - 'spikes': 2D array of spike data (optional)
        - 'asp', 'reconstructions': optional arrays passed to Experiment
        - '_metadata', '_sync_info': optional metadata (exact-case keys);
          '_sync_info' is stored as ``metadata['sync_info']``.
        - Other keys: behavioral variables as 1D or 2D arrays. 1D arrays
          (time,) are treated as single time series; 2D arrays
          (components, time) are treated as MultiTimeSeries.
    force_continuous : list, optional
        **Deprecated.** Use ``feature_types`` instead. List of feature names
        to force as continuous. Converted to ``feature_types={f: 'continuous'}``
        internally if ``feature_types`` is not provided.
    feature_types : dict[str, str], optional
        Map of feature names to type strings, overriding auto-detection. See
        ``TimeSeries._create_type_from_string`` for valid strings. When
        given, it also acts as a circular whitelist: unlisted auto-circular
        features are overridden to linear.
    bad_frames : list, optional
        List of frame indices to mark as bad/invalid. These frames will be
        masked in the resulting Experiment object. Useful for removing motion
        artifacts or recording gaps. Indices outside the recording are ignored.
    static_features : dict, optional
        Static experimental parameters. The dict is copied, never modified.
        Common keys:

        - 't_rise_sec': calcium rise time (default: DEFAULT_T_RISE)
        - 't_off_sec': calcium decay time (default: DEFAULT_T_OFF)
        - 'fps': frame rate in Hz (default: DEFAULT_FPS; replaced by
          ``metadata['fps']`` when metadata provides one and fps is unset or
          equal to the default)
        - Any other experiment-specific constants
    verbose : bool, default=True
        Whether to print progress and feature information.
    reconstruct_spikes : str, bool, or None, default=None
        **DEPRECATED**: Load the experiment first, then call
        ``exp.reconstruct_all_neurons()`` separately for better control.
        If provided (for backward compatibility):

        - 'wavelet': wavelet-based detection (old batch method)
        - False/None: no reconstruction (recommended)

        New workflow (recommended)::

            exp = load_exp_from_aligned_data(data_source, exp_params, data)
            exp.reconstruct_all_neurons(method='wavelet', n_iter=3)
    aggregate_features : dict, optional
        Dictionary mapping tuples of feature keys to combined names. Allows
        pre-specifying which features should be combined into MultiTimeSeries
        before Experiment building. This is useful for deterministic data
        hash generation. Format: ``{(key1, key2, ...): "combined_name", ...}``,
        for example::

            aggregate_features = {
                ("x", "y"): "position",  # Combine x, y into 2D MultiTimeSeries
                ("speed", "direction"): "velocity",
            }

        The component features remain available as individual features in
        addition to the combined MultiTimeSeries.
    n_jobs : int, default=-1
        Number of parallel jobs for neuron construction and other parallel
        operations. Use -1 for all available cores, 1 to disable
        parallelization.
    enable_parallelization : bool, default=True
        Enable parallel processing for neuron construction and hash
        computation. Set to False to use sequential processing (useful for
        debugging).
    create_circular_2d : bool, default=True
        If True, automatically create `_2d` versions of circular features
        (detected via type_info.is_circular) as (cos, sin) MultiTimeSeries.
        Original features are preserved. E.g., 'headdirection' -> also
        creates 'headdirection_2d'. This improves MI estimation accuracy for
        circular variables like head direction.

    Returns
    -------
    Experiment
        Initialized Experiment object with processed data.

    Raises
    ------
    TypeError
        If data or exp_params are not dictionaries.
    ValueError
        If data is empty, neural data is missing, or an aggregation
        component is not 1D.

    Side Effects
    ------------
    - Prints feature information if verbose=True
    - Works on a deep copy of ``data``; neither ``data`` nor
      ``static_features`` is modified.

    Notes
    -----
    - Features with <= 1 unique non-NaN values are filtered as "garbage"
    - Feature types (discrete/continuous) are automatically determined by
      checking if values appear to be categorical (few unique values) or
      continuous
    - Case-insensitive key matching for neural data and 'spikes'
    - The experiment name is constructed using construct_session_name()
    - Bad frames create a boolean mask; indices beyond data length are ignored
    - Scalar values (0D arrays) are ignored with a warning - use
      static_features instead
    - Non-numeric features (strings, objects) are ignored with a warning
    - 2D arrays are automatically converted to MultiTimeSeries objects

    Examples
    --------
    >>> # Basic usage with minimal data
    >>> np.random.seed(42)  # For reproducibility
    >>> data = {
    ...     'calcium': np.random.rand(50, 1000),  # 50 neurons, 1000 frames
    ...     'position': np.linspace(0, 100, 1000),  # Linear track position
    ...     'speed': np.random.rand(1000) * 10,  # Random speeds
    ...     'trial_type': np.repeat([0, 1, 0, 1], 250)  # Discrete variable
    ... }
    >>> exp_params = {
    ...     'track': 'linear_track',
    ...     'animal_id': 'mouse01',
    ...     'session': 'day1'
    ... }
    >>> exp = load_exp_from_aligned_data('IABS', exp_params, data, verbose=False)
    >>> exp.signature
    'Exp linear_track_mouse01_day1'
    >>> sorted(exp.dynamic_features.keys())
    ['position', 'speed', 'trial_type']

    >>> # Force discrete variable to be continuous
    >>> exp2 = load_exp_from_aligned_data(
    ...     'IABS', exp_params, data,
    ...     force_continuous=['trial_type'],
    ...     bad_frames=[10, 11, 12],  # Mark frames as bad
    ...     static_features={'fps': 30.0},  # Override default fps
    ...     verbose=False
    ... )
    >>> exp2.static_features['fps']
    30.0
    >>> exp2.dynamic_features['trial_type'].discrete  # Should be False due to force_continuous
    False
    """
    # Deprecation warning for reconstruct_spikes parameter
    if reconstruct_spikes is not None and reconstruct_spikes is not False:
        warnings.warn(
            "The 'reconstruct_spikes' parameter is deprecated. "
            "Load the experiment first, then call exp.reconstruct_all_neurons() separately. "
            "Example: exp = load_exp_from_aligned_data(...); exp.reconstruct_all_neurons(method='wavelet')",
            DeprecationWarning,
            stacklevel=2,
        )

    # Validate inputs
    if not isinstance(data, dict):
        raise TypeError(f"data must be a dictionary, got {type(data).__name__}")
    if not data:
        raise ValueError("data dictionary cannot be empty")
    if not isinstance(exp_params, dict):
        raise TypeError(f"exp_params must be a dictionary, got {type(exp_params).__name__}")

    # None defaults avoid shared mutable default arguments.
    force_continuous = [] if force_continuous is None else force_continuous
    bad_frames = [] if bad_frames is None else bad_frames

    expname = construct_session_name(data_source, exp_params)
    adata = copy.deepcopy(data)
    key_mapping = {key.lower(): key for key in adata.keys()}

    if verbose:
        print(f"Building experiment {expname}...")

    neural_key = _select_neural_key(key_mapping)
    if neural_key is None:
        raise ValueError(
            f"No neural data found. Use one of these keys: {sorted(NEURAL_DATA_ALIASES)}"
        )
    calcium = adata.pop(key_mapping[neural_key])

    # Optional neural arrays (case-insensitive keys)
    spikes = adata.pop(key_mapping["spikes"]) if "spikes" in key_mapping else None
    asp = adata.pop(key_mapping["asp"]) if "asp" in key_mapping else None
    reconstructions = (
        adata.pop(key_mapping["reconstructions"]) if "reconstructions" in key_mapping else None
    )

    # Extract metadata (merge _sync_info into it if present). Values stored in
    # NPZ files arrive as 0-d object arrays, hence the .item() unwrapping.
    metadata = None
    if "_metadata" in adata:
        metadata = adata.pop("_metadata")
        if hasattr(metadata, 'item'):
            metadata = metadata.item()
    if "_sync_info" in adata:
        sync_info = adata.pop("_sync_info")
        if hasattr(sync_info, 'item'):
            sync_info = sync_info.item()
        if metadata is None:
            metadata = {}
        metadata['sync_info'] = sync_info

    filt_dyn_features = {}

    # Process feature aggregations first. Component features are NOT consumed:
    # they remain available as individual features below.
    if aggregate_features:
        for component_keys, combined_name in aggregate_features.items():
            missing = [k for k in component_keys if k not in adata]
            if missing:
                if verbose:
                    warnings.warn(f"Skipping aggregation '{combined_name}': missing keys {missing}")
                continue
            ts_list = []
            for i, key in enumerate(component_keys):
                arr = np.asarray(adata[key])
                if arr.ndim != 1:
                    raise ValueError(f"Aggregation component '{key}' must be 1D, got {arr.ndim}D")
                ts_list.append(TimeSeries(arr, discrete=False, name=f"{combined_name}_{i}"))
            # aggregate_multiple_ts adds noise to break degeneracy
            filt_dyn_features[combined_name] = aggregate_multiple_ts(*ts_list, name=combined_name)

    dyn_features = adata.copy()

    # Deprecation bridge: convert force_continuous to feature_types
    if force_continuous and not feature_types:
        warnings.warn(
            "force_continuous is deprecated. Use feature_types={'name': 'linear'} instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        feature_types = {f: 'continuous' for f in force_continuous}
    feature_types = feature_types or {}

    for f, vals in dyn_features.items():
        # Skip reserved keys (case-insensitive for neural keys)
        if f.lower() in RESERVED_NEURAL_KEYS or f in RESERVED_METADATA_KEYS:
            continue

        vals_array = np.asarray(vals)

        if vals_array.ndim == 0:
            if verbose:
                print(
                    f"Warning: Ignoring scalar value '{f}' found in NPZ file. "
                    f"Scalar values should be provided via static_features parameter."
                )
            continue

        if vals_array.dtype.kind in ["U", "S", "O"]:  # Unicode, bytes, or object
            if verbose:
                print(
                    f"Warning: Ignoring non-numeric feature '{f}' with dtype {vals_array.dtype}. "
                    f"Only numeric features are supported."
                )
            continue

        if _is_garbage(vals):
            continue

        if vals_array.ndim == 1:
            if f in feature_types:
                filt_dyn_features[f] = TimeSeries(vals_array, ts_type=feature_types[f], name=f)
            else:
                # Let TimeSeries auto-detect the type
                filt_dyn_features[f] = TimeSeries(vals_array, name=f)
        elif vals_array.ndim == 2:
            # 2D -> MultiTimeSeries (each row is a component); matches
            # Experiment.__init__ behavior.
            ts_list = [
                TimeSeries(vals_array[i, :], discrete=False, name=f"{f}_{i}")
                for i in range(vals_array.shape[0])
            ]
            filt_dyn_features[f] = MultiTimeSeries(ts_list, name=f)
        elif verbose:
            print(f"Warning: Skipping feature '{f}' with unsupported {vals_array.ndim}D shape")

    if verbose:
        print("behaviour variables:")
        print()
        for f, ts in filt_dyn_features.items():
            dtype = "discrete" if ts.discrete else "continuous"
            if isinstance(ts, MultiTimeSeries):
                type_info = ts.ts_list[0].type_info if ts.ts_list else None
                subtype_str = _format_feature_subtype(type_info)
                dim_str = f"multi-dimensional ({ts.n_dim}D)"
                if subtype_str:
                    print(f"'{f}' {dtype} {dim_str} {subtype_str}")
                else:
                    print(f"'{f}' {dtype} {dim_str}")
            else:
                subtype_str = _format_feature_subtype(ts.type_info)
                if subtype_str:
                    print(f"'{f}' {dtype} {subtype_str}")
                else:
                    print(f"'{f}' {dtype}")

    # check for constant features
    constfeats = set(dyn_features.keys()) - set(filt_dyn_features.keys())
    if len(constfeats) != 0 and verbose:
        print(f"features {constfeats} dropped as constant or empty")

    auto_continuous = [fn for fn, ts in filt_dyn_features.items() if not ts.discrete]
    if verbose:
        print(f"features {auto_continuous} automatically determined as continuous")
        print()

    # Compare forced types with auto-detection and enforce the circular
    # whitelist (only when feature_types was provided).
    if feature_types:
        for f, ts in list(filt_dyn_features.items()):
            if not isinstance(ts, TimeSeries) or ts.discrete:
                continue
            if f in feature_types:
                # Warn when forced type disagrees with auto-detection
                auto_type = analyze_time_series_type(ts.data, name=f)
                auto_sub = auto_type.subtype or auto_type.primary_type
                forced_sub = ts.type_info.subtype or ts.type_info.primary_type
                if auto_sub != forced_sub:
                    warnings.warn(
                        f"Feature '{f}' type overridden: auto-detected "
                        f"'{auto_sub}' (conf={auto_type.confidence:.2f}) "
                        f"-> forced '{forced_sub}'",
                        UserWarning,
                        stacklevel=2,
                    )
            elif ts.type_info and ts.type_info.is_circular:
                # Auto-detected circular but not whitelisted — override to linear
                warnings.warn(
                    f"Feature '{f}' auto-detected as circular but not in "
                    f"feature_types. Overriding to linear.",
                    UserWarning,
                    stacklevel=2,
                )
                filt_dyn_features[f] = TimeSeries(ts.data, ts_type='linear', name=f)

    signature = f"Exp {expname}"

    # Fill default static features on a copy so the caller's dict is untouched.
    default_static_features = {
        "t_rise_sec": DEFAULT_T_RISE,
        "t_off_sec": DEFAULT_T_OFF,
        "fps": DEFAULT_FPS,
    }
    static_features = dict(static_features) if static_features is not None else {}
    for sf, default_value in default_static_features.items():
        static_features.setdefault(sf, default_value)

    # Auto-set fps from metadata if not explicitly set to a non-default value
    if metadata is not None and 'fps' in metadata:
        if 'fps' not in static_features or static_features['fps'] == DEFAULT_FPS:
            static_features['fps'] = metadata['fps']

    exp = Experiment(
        signature,
        calcium,
        spikes,
        exp_params,
        static_features,
        filt_dyn_features,
        reconstruct_spikes=reconstruct_spikes,
        # bad_frames_mask: True = bad frame to remove, False = good frame to keep
        bad_frames_mask=_build_bad_frames_mask(bad_frames, calcium.shape[1]),
        verbose=verbose,
        n_jobs=n_jobs,
        enable_parallelization=enable_parallelization,
        asp=asp,
        reconstructions=reconstructions,
        metadata=metadata,
        create_circular_2d=create_circular_2d,
    )
    return exp
def _load_iabs_aligned_data(
    expname, root, data_path, force_reload, via_pydrive, gauth, router_source, verbose
):
    """Locate (downloading from cloud if needed) and read IABS aligned data.

    Returns
    -------
    tuple
        (aligned_data dict, load_log or None if nothing was downloaded).
    """
    if data_path is None:
        data_path = os.path.join(root, expname, "Aligned data", f"{expname} syn data.npz")
    if verbose:
        print(f"Path to data: {data_path}")

    data_exists = os.path.exists(data_path)
    if verbose:
        if data_exists:
            print("Aligned data for experiment construction found successfully")
        else:
            print("Failed to locate aligned data for experiment construction")

    load_log = None
    if force_reload or not data_exists:
        if verbose:
            print("Loading data from cloud storage...")
        data_router, _ = initialize_iabs_router(root=root, router_source=router_source)
        success, load_log = download_gdrive_data(
            data_router,
            expname,
            data_pieces=["Aligned data"],
            via_pydrive=via_pydrive,
            tdir=root,
            gauth=gauth,
        )
        if not success:
            print("=========== BEGINNING OF LOADING LOG ============")
            show_output(load_log)
            print("=========== END OF LOADING LOG ============")
            raise FileNotFoundError(f"Cannot download {expname}, see loading log above")

    # Context manager closes the underlying NPZ (zip) file handle.
    with np.load(data_path) as npz:
        aligned_data = dict(npz)
    return aligned_data, load_log


def _load_local_aligned_data(data_source, data_path, verbose):
    """Read aligned data for a non-IABS source from a local NPZ file and validate it."""
    if data_path is None:
        raise ValueError(
            f"For data source '{data_source}', you must provide the 'data_path' parameter "
            "pointing to your NPZ data file."
        )
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Data file not found: {data_path}")
    if verbose:
        print(f"Loading data from: {data_path}")

    try:
        # NOTE(security): allow_pickle=True unpickles object arrays, which can
        # execute arbitrary code. Only load NPZ files from trusted sources.
        with np.load(data_path, allow_pickle=True) as npz:
            aligned_data = dict(npz)
    except Exception as e:
        raise ValueError(f"Failed to load NPZ file: {e}") from e

    # Case-insensitive, matching how load_exp_from_aligned_data looks keys up.
    if not any(key.lower() in NEURAL_DATA_ALIASES for key in aligned_data):
        raise ValueError(
            f"NPZ file must contain neural data under one of: {sorted(NEURAL_DATA_ALIASES)}"
        )
    return aligned_data


def load_experiment(
    data_source,
    exp_params,
    force_rebuild=False,
    force_reload=False,
    via_pydrive=True,
    gauth=None,
    root="DRIADA data",
    exp_path=None,
    data_path=None,
    force_continuous=None,
    feature_types=None,
    bad_frames=None,
    static_features=None,
    reconstruct_spikes="wavelet",
    save_to_pickle=False,
    verbose=True,
    router_source=None,
):
    """Load or create an Experiment object with automatic caching and cloud support.

    This function provides a high-level interface for loading experiments with
    smart caching, automatic cloud data download (for IABS data), and pickle
    serialization. It first checks for cached experiments, then loads from
    local data files, and finally downloads from cloud storage if needed.

    Parameters
    ----------
    data_source : str
        Data source identifier. 'IABS' enables automatic cloud download.
        Other sources (e.g., 'MyLab') require data_path parameter pointing to
        a local NPZ file.
    exp_params : dict
        Experiment parameters dictionary. See load_exp_from_aligned_data for
        required fields based on data_source.
    force_rebuild : bool, default=False
        If True, rebuild experiment from data files even if pickle cache
        exists. The existing pickle is ignored completely.
    force_reload : bool, default=False
        If True, re-download data from cloud even if local files exist. Also
        bypasses pickle cache (similar to force_rebuild).
    via_pydrive : bool, default=True
        Use PyDrive for Google Drive access. If False, uses alternative method.
    gauth : GoogleAuth object, optional
        Pre-authenticated GoogleAuth object for Drive access. If None, will
        create new authentication.
    root : str, default='DRIADA data'
        Root directory for storing experiments and data. Created if missing.
    exp_path : str, optional
        Custom path for experiment pickle file. If None, uses standard naming:
        ``{root}/{expname}/Exp {expname}.pickle``
    data_path : str, optional
        Path to NPZ data file. Required for non-IABS data sources. For IABS,
        if None, uses standard naming:
        ``{root}/{expname}/Aligned data/{expname} syn data.npz``
    force_continuous : list, optional
        **Deprecated.** See load_exp_from_aligned_data.
    feature_types : dict[str, str], optional
        Feature type overrides. See load_exp_from_aligned_data.
    bad_frames : list, optional
        Frame indices to mark as bad. See load_exp_from_aligned_data.
    static_features : dict, optional
        Static experimental parameters. See load_exp_from_aligned_data.
    reconstruct_spikes : str or bool, default='wavelet'
        Spike reconstruction method. See load_exp_from_aligned_data.
    save_to_pickle : bool, default=False
        Whether to save the experiment to ``exp_path`` after creation.
    verbose : bool, default=True
        Print progress messages.
    router_source : str, pandas.DataFrame, or None, optional
        Source of the router data for IABS experiments:

        - None: Downloads from URL in config.py (default behavior)
        - str: Direct Google Sheets export URL
        - pandas.DataFrame: Pre-loaded router DataFrame

        Only used when data_source='IABS' and downloading from cloud.

    Returns
    -------
    tuple
        exp : Experiment
            The loaded or created Experiment object.
        load_log : list or None
            Cloud download log if data was downloaded, None otherwise. Always
            None for local loads or pickle loads.

    Raises
    ------
    ValueError
        If root exists but is not a directory, if data_source is not 'IABS'
        and no data_path is provided, if the NPZ file cannot be read, or if it
        contains no neural data key.
    FileNotFoundError
        If data file not found and cannot be downloaded.

    Side Effects
    ------------
    - Creates root directory if it doesn't exist
    - Downloads data from cloud for IABS source (if needed)
    - Saves pickle file if save_to_pickle=True and building from data
    - Prints progress messages if verbose=True

    Notes
    -----
    Loading priority:

    1. If pickle exists and not force_rebuild/reload: load from pickle
    2. If local data exists and not force_reload: load from data file
    3. If IABS source: attempt cloud download
    4. Otherwise: raise error

    For IABS data, expects cloud structure with 'Aligned data' containing npz
    files with calcium and behavioral data. The function returns a tuple
    (exp, load_log) to maintain backward compatibility, even though load_log
    is often None.

    Examples
    --------
    >>> # Load IABS data with custom router URL
    >>> url = "https://docs.google.com/spreadsheets/d/.../export?format=xlsx"
    >>> exp, _ = load_experiment(  # doctest: +SKIP
    ...     'IABS',
    ...     {'track': 'linear', 'animal_id': 'CA1_01', 'session': '1'},
    ...     router_source=url,
    ...     verbose=False
    ... )

    >>> # Load external lab data from NPZ file
    >>> import tempfile
    >>> import numpy as np
    >>>
    >>> # Create test data file
    >>> with tempfile.NamedTemporaryFile(delete=False, suffix='.npz') as f:
    ...     temp_data = f.name
    >>> test_data = {
    ...     'calcium': np.random.rand(30, 500),
    ...     'position': np.random.rand(500) * 100
    ... }
    >>> np.savez(temp_data, **test_data)
    >>>
    >>> # Load from local file
    >>> exp, _ = load_experiment(
    ...     'MyLab',
    ...     {'name': 'test_exp'},
    ...     data_path=temp_data,
    ...     verbose=False
    ... )
    >>> exp.signature
    'Exp test_exp'
    >>> exp.n_cells
    30
    >>>
    >>> # Force rebuild even if pickle exists
    >>> with tempfile.TemporaryDirectory() as tmpdir:
    ...     exp2, _ = load_experiment(
    ...         'MyLab',
    ...         {'name': 'rebuild_test'},
    ...         data_path=temp_data,
    ...         root=tmpdir,
    ...         force_rebuild=True,
    ...         save_to_pickle=True,
    ...         verbose=False
    ...     )
    >>> exp2.n_cells
    30
    >>>
    >>> # Cleanup
    >>> import os
    >>> os.unlink(temp_data)
    """
    if os.path.exists(root) and not os.path.isdir(root):
        raise ValueError("Root must be a folder!")
    os.makedirs(root, exist_ok=True)

    expname = None
    if exp_path is None:
        expname = construct_session_name(data_source, exp_params)
        exp_path = os.path.join(root, expname, f"Exp {expname}.pickle")

    if os.path.exists(exp_path) and not force_rebuild and not force_reload:
        return load_exp_from_pickle(exp_path, verbose=verbose), None

    # None defaults avoid shared mutable default arguments.
    force_continuous = [] if force_continuous is None else force_continuous
    bad_frames = [] if bad_frames is None else bad_frames

    if data_source == "IABS":
        # expname is required for the default data path and the download,
        # even when the caller supplied a custom exp_path.
        if expname is None:
            expname = construct_session_name(data_source, exp_params)
        aligned_data, load_log = _load_iabs_aligned_data(
            expname, root, data_path, force_reload, via_pydrive, gauth, router_source, verbose
        )
    else:
        # Support for external (non-IABS) data sources loading from local files
        aligned_data = _load_local_aligned_data(data_source, data_path, verbose)
        load_log = None  # No load_log for external data sources

    Exp = load_exp_from_aligned_data(
        data_source,
        exp_params,
        aligned_data,
        force_continuous=force_continuous,
        feature_types=feature_types,
        static_features=static_features,
        verbose=verbose,
        bad_frames=bad_frames,
        reconstruct_spikes=reconstruct_spikes,
    )

    if save_to_pickle:
        # save_exp_to_pickle creates missing parent directories.
        save_exp_to_pickle(Exp, exp_path, verbose=verbose)

    return Exp, load_log
def save_exp_to_pickle(exp, path, verbose=True):
    """Serialize an Experiment object to ``path`` using :mod:`pickle`.

    Parameters
    ----------
    exp : Experiment
        Experiment to serialize. Its ``signature`` attribute is used in the
        confirmation message.
    path : str
        Destination file. Missing parent directories are created first.
    verbose : bool, default=True
        If True, print a confirmation line after writing.

    Raises
    ------
    PermissionError
        If no write permission for the path.
    OSError
        If path is invalid or other OS-related errors.

    Examples
    --------
    >>> # Create a test experiment
    >>> import tempfile
    >>> import os
    >>> from driada.experiment import load_demo_experiment
    >>> exp = load_demo_experiment(verbose=False)
    >>>
    >>> # Save experiment to temporary file
    >>> with tempfile.NamedTemporaryFile(delete=False, suffix='.pkl') as f:
    ...     temp_path = f.name
    >>> save_exp_to_pickle(exp, temp_path)  # doctest: +ELLIPSIS
    Experiment Exp demo saved to ...
    >>> # Save without verbose output
    >>> save_exp_to_pickle(exp, temp_path, verbose=False)
    >>>
    >>> # Cleanup
    >>> os.unlink(temp_path)

    Notes
    -----
    Uses the pickle module's default protocol.
    """
    target_dir, _ = os.path.split(path)
    # A bare filename has no directory component; os.makedirs("") would fail.
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    with open(path, "wb") as sink:
        pickle.dump(exp, sink)

    if not verbose:
        return
    print(f"Experiment {exp.signature} saved to {path}\n")
def load_exp_from_pickle(path, verbose=True):
    """Load an Experiment object from a pickle file.

    After unpickling, unnamed objects from older pickles are given names:

    - dynamic features that have an empty/None ``name`` get their dict key;
    - components of MultiTimeSeries features without a name get
      ``"{feature_key}_{index}"``;
    - neuron ``ca`` / ``sp`` / ``asp`` series without a name get
      ``"neuron_{i}_ca"`` / ``"neuron_{i}_sp"`` / ``"neuron_{i}_asp"``.

    Parameters
    ----------
    path : str
        Path to the pickle file.
    verbose : bool, default=True
        Whether to print load confirmation.

    Returns
    -------
    Experiment
        The loaded Experiment object.

    Raises
    ------
    FileNotFoundError
        If the pickle file doesn't exist.
    PermissionError
        If no read permission for the file.
    OSError
        If path is invalid or other OS-related errors.

    Examples
    --------
    >>> # Create and save a test experiment first
    >>> import tempfile
    >>> from driada.experiment import load_demo_experiment
    >>> test_exp = load_demo_experiment(verbose=False)
    >>> with tempfile.NamedTemporaryFile(delete=False, suffix='.pkl') as f:
    ...     temp_path = f.name
    >>> save_exp_to_pickle(test_exp, temp_path, verbose=False)
    >>>
    >>> # Load experiment from file
    >>> exp = load_exp_from_pickle(temp_path)  # doctest: +ELLIPSIS
    Experiment Exp demo loaded from ...
    >>> # Load without verbose output
    >>> exp = load_exp_from_pickle(temp_path, verbose=False)
    >>> exp.signature
    'Exp demo'
    >>>
    >>> # Cleanup
    >>> import os
    >>> os.unlink(temp_path)

    Notes
    -----
    Uses Python's pickle module for deserialization. Unpickling can execute
    arbitrary code, so only load files from trusted sources (e.g. files
    written by save_exp_to_pickle).
    """
    # SECURITY: pickle.load executes code embedded in the file; never call
    # this on untrusted input.
    with open(path, "rb") as f:
        exp = pickle.load(f)
    if verbose:
        print(f"Experiment {exp.signature} loaded from {path}\n")

    def _name_missing(obj):
        # True when obj has no ``name`` attribute or it is None / "".
        name = getattr(obj, "name", None)
        return name is None or name == ""

    # Backward compatibility: assign names to unnamed dynamic features
    for feat_id, feat_obj in getattr(exp, "dynamic_features", {}).items():
        if hasattr(feat_obj, "name") and _name_missing(feat_obj):
            feat_obj.name = str(feat_id)  # Use feature key as name
        # For MultiTimeSeries, also name components if they lack names
        if isinstance(feat_obj, MultiTimeSeries):
            for i, component in enumerate(getattr(feat_obj, "ts_list", ())):
                if _name_missing(component):
                    component.name = f"{feat_id}_{i}"

    # Backward compatibility: assign names to neuron signals if missing
    for i, neuron in enumerate(getattr(exp, "neurons", ())):
        for attr in ("ca", "sp", "asp"):
            series = getattr(neuron, attr, None)
            if series is not None and _name_missing(series):
                series.name = f"neuron_{i}_{attr}"

    return exp
def _resolve_demo_data_path(relative_path):
    """Return a usable path to the bundled demo NPZ file.

    ``relative_path`` is tried against the current working directory first
    (the historical behaviour). If it does not exist there, each ancestor
    directory of this module is tried in turn, so the demo also loads when the
    working directory is not the repository root. If nothing is found, the
    relative path is returned unchanged and the subsequent load reports the
    missing file.
    """
    if os.path.exists(relative_path):
        return relative_path
    current = os.path.dirname(os.path.abspath(__file__))
    while True:
        candidate = os.path.join(current, relative_path)
        if os.path.exists(candidate):
            return candidate
        parent = os.path.dirname(current)
        if parent == current:  # reached the filesystem root
            return relative_path
        current = parent


def load_demo_experiment(name="demo", verbose=False):
    """Load a demonstration experiment for documentation and testing.

    This is a convenience function for loading sample data in documentation
    examples and tests. It loads a synthetically generated calcium imaging
    dataset with behavioral data.

    Parameters
    ----------
    name : str, default='demo'
        Name identifier for the demo experiment. This becomes part of the
        experiment's signature. Common values:

        - 'demo': Basic demonstration
        - 'test': For unit tests
        - Any descriptive name for specific examples
    verbose : bool, default=False
        Whether to print loading messages.

    Returns
    -------
    Experiment
        A loaded Experiment object with:

        - 50 neurons
        - 10000 time points
        - Sample behavioral features (position, speed, etc.)
        - No spike reconstruction (for speed)

    Examples
    --------
    >>> from driada.experiment import load_demo_experiment
    >>>
    >>> # Basic usage
    >>> exp = load_demo_experiment()
    >>> print(f"Loaded {exp.n_cells} neurons, {exp.n_frames} frames")
    Loaded 50 neurons, 10000 frames
    >>>
    >>> # With custom name
    >>> exp = load_demo_experiment('pca_analysis')
    >>> print(exp.signature)
    Exp pca_analysis
    >>>
    >>> # Access data
    >>> calcium_data = exp.calcium.data  # (50, 10000) array
    >>> position = exp.position  # MultiTimeSeries with x,y coordinates

    Notes
    -----
    The demo data is 'examples/example_data/sample_recording.npz'. It is
    looked up relative to the current working directory first, then relative
    to each ancestor directory of this module.

    Loading goes through load_experiment with its default ``root``
    ('DRIADA data'), so that directory is created in the current working
    directory if missing, and a cached pickle found there is used instead
    of the NPZ file.

    See Also
    --------
    load_experiment : Full experiment loading with all options
    ~driada.experiment.synthetic.generators.generate_synthetic_exp :
        Generate synthetic data with custom properties
    """
    data_path = _resolve_demo_data_path("examples/example_data/sample_recording.npz")
    exp, _ = load_experiment(
        "MyLab",
        {"name": name},
        data_path=data_path,
        reconstruct_spikes=False,
        verbose=verbose,
        save_to_pickle=False,
    )
    return exp