Source code for driada.information.info_utils

import numpy as np
from numba import njit
from ..utils.jit import conditional_njit


@conditional_njit
def py_fast_digamma_arr(data):
    """Compute digamma function for an array of values using fast approximation.

    This is a JIT-compiled version that processes arrays efficiently.
    Uses a series expansion approximation that is accurate for x > 5.

    Parameters
    ----------
    data : ndarray
        Input array of positive values. All values must be > 0.

    Returns
    -------
    ndarray
        Array of digamma values corresponding to input data.

    Raises
    ------
    None
        Invalid inputs (x <= 0) return NaN instead of raising exceptions
        due to numba JIT compilation constraints.

    Notes
    -----
    The algorithm uses a recurrence relation to shift x to the range
    where the series expansion is accurate (x > 5), then applies
    an asymptotic expansion with correction terms.

    For x <= 0, the function returns NaN to avoid infinite loops."""
    res = np.zeros(len(data))
    for i, x in enumerate(data):
        # Handle non-positive values to avoid infinite loop
        if x <= 0:
            res[i] = np.nan
            continue
        r = 0
        while x <= 5:
            r -= 1 / x
            x += 1
        f = 1 / (x * x)
        t = f * (
            -1 / 12.0
            + f
            * (
                1 / 120.0
                + f
                * (
                    -1 / 252.0
                    + f
                    * (
                        1 / 240.0
                        + f
                        * (-1 / 132.0 + f * (691 / 32760.0 + f * (-1 / 12.0 + f * 3617 / 8160.0)))
                    )
                )
            )
        )

        res[i] = r + np.log(x) - 0.5 / x + t

    return res


@conditional_njit
def py_fast_digamma(x):
    """Compute digamma function for a single value using fast approximation.

    This is a JIT-compiled scalar version of the digamma (psi) function.
    Uses a series expansion approximation that is accurate for x > 5.

    Parameters
    ----------
    x : float
        Input value. Must be positive (x > 0).

    Returns
    -------
    float
        The digamma function value psi(x).

    Raises
    ------
    None
        Invalid inputs (x <= 0) return NaN instead of raising exceptions
        due to numba JIT compilation constraints.

    Notes
    -----
    The digamma function is the logarithmic derivative of the gamma function:
    psi(x) = d/dx log(Gamma(x)) = Gamma'(x) / Gamma(x)

    The algorithm uses:
    1. Recurrence relation psi(x) = psi(x+1) - 1/x to shift to x > 5
    2. Asymptotic expansion for large x with Bernoulli number corrections

    For x <= 0, the function returns NaN to avoid infinite loops."""
    # Handle non-positive values to avoid infinite loop
    if x <= 0:
        return np.nan

    r = 0
    x = float(x)  # Ensure float type
    while x <= 5:
        r -= 1 / x
        x += 1
    f = 1 / (x * x)
    t = f * (
        -1 / 12.0
        + f
        * (
            1 / 120.0
            + f
            * (
                -1 / 252.0
                + f
                * (
                    1 / 240.0
                    + f * (-1 / 132.0 + f * (691 / 32760.0 + f * (-1 / 12.0 + f * 3617 / 8160.0)))
                )
            )
        )
    )

    res = r + np.log(x) - 0.5 / x + t
    return res


[docs] def binary_mi_score(contingency): """Calculate mutual information for discrete variables from contingency table. Computes the mutual information between two discrete random variables based on their joint probability distribution represented as a contingency table. Parameters ---------- contingency : ndarray of shape (n_classes_x, n_classes_y) Contingency table where element [i, j] contains the count of samples with x=i and y=j. Must contain non-negative values. Returns ------- float Mutual information score in nats (natural log base). Returns 0.0 if either variable has only one class. Raises ------ ValueError If contingency table has wrong dimensions or contains negative values. TypeError If contingency is not array-like or contains non-numeric values. Notes ----- The mutual information is calculated as: MI(X,Y) = sum_ij P(x_i, y_j) * log(P(x_i, y_j) / (P(x_i) * P(y_j))) This implementation: - Handles sparse contingency tables efficiently by only computing over non-zero entries - Returns 0 for degenerate cases (single cluster) - Clips negative values due to numerical errors to 0""" # Input validation contingency = np.asarray(contingency) if contingency.ndim != 2: raise ValueError(f"Contingency table must be 2D, got {contingency.ndim}D") if np.any(contingency < 0): raise ValueError("Contingency table cannot contain negative values") nzx, nzy = np.nonzero(contingency) nz_val = contingency[nzx, nzy] contingency_sum = contingency.sum() pi = np.ravel(contingency.sum(axis=1)) pj = np.ravel(contingency.sum(axis=0)) # Since MI <= min(H(X), H(Y)), any labelling with zero entropy, i.e. containing a # single cluster, implies MI = 0 if pi.size == 1 or pj.size == 1: return 0.0 log_contingency_nm = np.log(nz_val) contingency_nm = nz_val / contingency_sum # Don't need to calculate the full outer product, just for non-zeroes outer = pi.take(nzx).astype(np.int64, copy=False) * pj.take(nzy).astype(np.int64, copy=False) log_outer = -np.log(outer) + np.log(pi.sum()) + np.log(pj.sum()) mi = ( contingency_nm * (log_contingency_nm - np.log(contingency_sum)) + contingency_nm * log_outer ) mi = np.where(np.abs(mi) < np.finfo(mi.dtype).eps, 0.0, mi) return np.clip(mi.sum(), 0.0, None)