Source code for driada.recurrence.population

"""Population-level recurrence graph construction."""

import warnings

import numpy as np
import scipy.sparse as sp


def _reconcile_graph_sizes(items, trim='adaptive', **kwargs):
    """Reconcile different-sized recurrence graphs or sparse matrices.

    Parameters
    ----------
    items : list of RecurrenceGraph or scipy.sparse matrices
        Items to reconcile.  RecurrenceGraph objects are identified by
        having an ``.adj`` attribute.
    trim : {None, 'min', 'adaptive'}
        Strategy for handling different sizes:

        - ``None``: raise ValueError if sizes differ.
        - ``'min'``: trim all items to the smallest common size.
        - ``'adaptive'``: compute per-item embedding loss (= max_size - size),
          remove items whose loss exceeds ``mean + tail_sigma * std``,
          then trim the rest to the new minimum.  Issues a warning when
          items are removed.
    **kwargs
        Extra parameters for the ``'adaptive'`` strategy:

        - ``tail_sigma`` (float, default 3.0): number of standard deviations
          for the outlier threshold.

    Returns
    -------
    items : list
        Reconciled items (same type as input, trimmed copies where needed).
    kept_mask : ndarray of bool
        Boolean mask of length ``len(original_items)``.  True for items
        that were kept (all True when no outliers are removed).
    """
    if not items:
        raise ValueError("items must be non-empty")

    is_rg = [hasattr(it, 'adj') for it in items]
    sizes = np.array([
        it.adj.shape[0] if has_adj else it.shape[0]
        for it, has_adj in zip(items, is_rg)
    ])
    n_items = len(items)
    kept_mask = np.ones(n_items, dtype=bool)

    if len(set(sizes)) == 1:
        return items, kept_mask

    # --- Strategy dispatch ---
    if trim is None:
        raise ValueError(
            f"Items have different sizes ({sorted(set(sizes))}). "
            f"Pass trim='min' or trim='adaptive' to handle this."
        )

    if trim == 'adaptive':
        tail_sigma = kwargs.get('tail_sigma', 3.0)
        max_size = sizes.max()
        losses = max_size - sizes  # 0 for the longest, positive for shorter
        mean_loss = losses.mean()
        std_loss = losses.std()
        if std_loss > 0:
            threshold = mean_loss + tail_sigma * std_loss
            kept_mask = losses <= threshold
            n_removed = n_items - kept_mask.sum()
            if n_removed > 0:
                warnings.warn(
                    f"Adaptive trim: removed {n_removed}/{n_items} items "
                    f"with embedding loss > {threshold:.0f} "
                    f"(mean={mean_loss:.0f}, std={std_loss:.0f}, "
                    f"sigma={tail_sigma}). "
                    f"Removed sizes: {sorted(sizes[~kept_mask])}",
                    stacklevel=3,
                )
            items = [it for it, keep in zip(items, kept_mask) if keep]
            is_rg = [r for r, keep in zip(is_rg, kept_mask) if keep]
            sizes = sizes[kept_mask]

    # trim='min' or post-adaptive: trim all to min size
    min_n = sizes.min()
    from .recurrence_graph import RecurrenceGraph

    trimmed = []
    for it, has_adj in zip(items, is_rg):
        if has_adj:
            if it.adj.shape[0] > min_n:
                trimmed.append(RecurrenceGraph.from_adjacency(
                    it.adj[:min_n, :min_n]))
            else:
                trimmed.append(it)
        else:
            if it.shape[0] > min_n:
                trimmed.append(it[:min_n, :min_n])
            else:
                trimmed.append(it)

    return trimmed, kept_mask


[docs] def population_recurrence_graph(recurrence_graphs, method='joint', threshold=1.0, binarize_threshold=None, trim='adaptive', **trim_kwargs): """Combine per-neuron recurrence graphs into a population graph. Parameters ---------- recurrence_graphs : list of RecurrenceGraph Individual recurrence graphs. method : {'joint', 'mean'} Combination strategy: - ``'joint'``: Thresholded element-wise AND (JRP). A point (i,j) is kept if at least ``threshold`` fraction of neurons recur there. ``threshold=1.0`` is strict AND, ``threshold=0.5`` is majority voting. - ``'mean'``: Average recurrence matrices. Values in [0, 1]. Optionally binarize with ``binarize_threshold``. threshold : float, default=1.0 For method='joint': fraction of neurons that must recur. binarize_threshold : float, optional For method='mean': if given, binarize the averaged matrix. trim : {None, 'min', 'adaptive'}, default='adaptive' How to handle graphs of different sizes. See :func:`_reconcile_graph_sizes` for details. **trim_kwargs Extra parameters passed to :func:`_reconcile_graph_sizes` (e.g., ``tail_sigma=3.0`` for adaptive mode). Returns ------- RecurrenceGraph Population-level recurrence graph (via ``from_adjacency``). Raises ------ ValueError If graphs have different sizes and ``trim`` is None, or method is unknown. """ if not recurrence_graphs: raise ValueError("recurrence_graphs must be non-empty") valid_methods = ('joint', 'mean') if method not in valid_methods: raise ValueError(f"Unknown method '{method}'. Choose from {valid_methods}.") recurrence_graphs, _ = _reconcile_graph_sizes( recurrence_graphs, trim=trim, **trim_kwargs) n_graphs = len(recurrence_graphs) n = recurrence_graphs[0].adj.shape[0] # Sum all adjacency matrices summed = sp.csr_matrix((n, n), dtype=float) for rg in recurrence_graphs: summed = summed + rg.adj.astype(float) if method == 'joint': min_count = threshold * n_graphs summed_coo = summed.tocoo() mask = summed_coo.data >= min_count adj = sp.csr_matrix( (np.ones(mask.sum()), (summed_coo.row[mask], summed_coo.col[mask])), shape=(n, n), ) else: # mean adj = summed / n_graphs if binarize_threshold is not None: adj_coo = adj.tocoo() mask = adj_coo.data >= binarize_threshold adj = sp.csr_matrix( (np.ones(mask.sum()), (adj_coo.row[mask], adj_coo.col[mask])), shape=(n, n), ) from .recurrence_graph import RecurrenceGraph return RecurrenceGraph.from_adjacency(adj)
[docs] def pairwise_jaccard_sparse(matrices, trim='adaptive', trim_to_min=None, **trim_kwargs): """Compute pairwise Jaccard similarity for sparse binary matrices. Uses sparse matrix multiplication to compute all pairwise intersection counts in a single operation, avoiding O(N^2) Python loops. Parameters ---------- matrices : list of scipy.sparse matrices or RecurrenceGraph objects Square binary adjacency matrices. RecurrenceGraph objects are accepted (uses ``.adj``). trim : {None, 'min', 'adaptive'}, default='adaptive' How to handle matrices of different sizes. See :func:`_reconcile_graph_sizes` for details. trim_to_min : bool, optional **Deprecated in 1.2, will be removed in 2.0.** Use ``trim='min'`` instead. If passed, overrides ``trim``. **trim_kwargs Extra parameters passed to :func:`_reconcile_graph_sizes` (e.g., ``tail_sigma=3.0`` for adaptive mode). Returns ------- jaccard : ndarray of shape (N, N) Symmetric matrix of Jaccard indices. Diagonal is 1.0 for non-empty matrices, 0.0 for empty ones. When ``trim='adaptive'`` removes items, N < len(input). kept_mask : ndarray of bool Boolean mask of length ``len(input)``. True for items that were kept after reconciliation. Raises ------ ValueError If list is empty, or if matrices have different shapes and ``trim`` is None. """ if not matrices: raise ValueError("matrices must be non-empty") # Deprecated trim_to_min alias if trim_to_min is not None: warnings.warn( "trim_to_min is deprecated since 1.2 and will be removed in 2.0. " "Use trim='min' or trim='adaptive' instead.", FutureWarning, stacklevel=2, ) trim = 'min' if trim_to_min else None # Accept RecurrenceGraph objects adjs = [m.adj if hasattr(m, 'adj') else m for m in matrices] adjs, kept_mask = _reconcile_graph_sizes(adjs, trim=trim, **trim_kwargs) n_matrices = len(adjs) n = adjs[0].shape[0] n_sq = n * n # Flatten each (n, n) matrix into a row of length n^2, stack into (N, n^2) rows_list = [] for m in adjs: coo = sp.coo_matrix(m) linear_idx = coo.row.astype(np.int64) * n + coo.col.astype(np.int64) rows_list.append(sp.csr_matrix( (np.ones(len(linear_idx)), (np.zeros(len(linear_idx), dtype=int), linear_idx)), shape=(1, n_sq), )) X = sp.vstack(rows_list, format='csr') # X @ X.T gives all pairwise intersection counts intersections = (X @ X.T).toarray().astype(float) sizes = np.diag(intersections).copy() # Jaccard = intersection / (|A| + |B| - intersection) unions = sizes[:, None] + sizes[None, :] - intersections with np.errstate(divide='ignore', invalid='ignore'): jaccard = np.where(unions > 0, intersections / unions, 0.0) return jaccard, kept_mask