"""K-nearest neighbor estimators for mutual information and entropy.
This module implements the Kraskov-Stögbauer-Grassberger (KSG) estimator
and related k-NN based information theoretic measures.
Important Notes:
- The Local Non-uniformity Correction (LNC) can be unstable when k <= d
(where k is the number of neighbors and d is the total dimensionality).
In such cases, LNC is automatically disabled to prevent numerical errors.
- For best results, use k > d, with k >= d+2 recommended.
- Common good values: k=5 for d<=3, k=10 for d=4-6, k=20 for d>6.
Credits:
Original implementation by Greg Ver Steeg
http://www.isi.edu/~gregv/npeet.html
References:
Kraskov, A., Stögbauer, H., & Grassberger, P. (2004).
Estimating mutual information. Physical Review E, 69(6), 066138.
"""
import numpy as np
import numpy.linalg as la
from numpy import log
from sklearn.neighbors import BallTree, KDTree
from .info_utils import py_fast_digamma
DEFAULT_NN = 5
# UTILITY FUNCTIONS
# Alpha thresholds for the Local Non-uniformity Correction (LNC), looked up
# by (k, d): k is the number of nearest neighbors, d the dimensionality.
# Values from https://github.com/BiuBiuBiLL/NPEET_LNC/blob/master/alpha.xlsx
#
# Each row below holds the alphas of one k for consecutive dimensionalities,
# beginning at "first d" and running up to d = 20.
_ALPHA_LNC_ROWS = (
    # (k, first d, (alpha at first d, alpha at first d + 1, ...))
    (
        2,
        3,
        (
            0.182224, 0.284370, 0.372004, 0.442894, 0.503244, 0.554523,
            0.594569, 0.630903, 0.660295, 0.689290, 0.711052, 0.735075,
            0.751908, 0.767809, 0.782448, 0.795362, 0.806728, 0.817252,
        ),
    ),
    (
        3,
        4,
        (
            0.077830, 0.167277, 0.250141, 0.320280, 0.384474, 0.441996,
            0.489972, 0.532178, 0.568561, 0.603990, 0.636593, 0.660156,
            0.683954, 0.706157, 0.724844, 0.743606, 0.757283,
        ),
    ),
    (
        5,
        6,
        (
            0.023953, 0.067077, 0.123341, 0.180215, 0.239442, 0.297637,
            0.351355, 0.404194, 0.451739, 0.498458, 0.538889, 0.578158,
            0.614937, 0.651598, 0.679500,
        ),
    ),
    (
        10,
        11,
        (
            0.003734, 0.014748, 0.034749, 0.063109, 0.100471, 0.147694,
            0.200196, 0.261374, 0.325363, 0.398082,
        ),
    ),
)

# Flat lookup table: (k, d) -> alpha
ALPHA_LNC_TABLE = {
    (k_val, first_d + offset): alpha_val
    for k_val, first_d, alphas in _ALPHA_LNC_ROWS
    for offset, alpha_val in enumerate(alphas)
}
def get_lnc_alpha(k, d):
    """Look up the LNC alpha threshold for a given k and dimensionality.

    Parameters
    ----------
    k : int
        Number of nearest neighbors. Must be positive.
    d : int
        Dimensionality of the data. Must be positive.

    Returns
    -------
    float
        Alpha value for the LNC correction. Falls back to 0.25 if the
        selected (k, d) pair is missing from the lookup table.

    Raises
    ------
    ValueError
        If k or d is not a positive integer.

    Notes
    -----
    Values come from the lookup table published with
    https://github.com/BiuBiuBiLL/NPEET_LNC

    For (k, d) pairs that are not tabulated:
    - a k that is not tabulated is replaced by the nearest tabulated k
      (the smaller one on ties);
    - a d outside the tabulated range for that k is clamped to the nearest
      tabulated d;
    - a d falling in a gap between tabulated values is linearly interpolated.
    """
    for label, value in (("k", k), ("d", d)):
        if not isinstance(value, (int, np.integer)) or value <= 0:
            raise ValueError(f"{label} must be a positive integer, got {value}")

    # Collect the tabulated dimensionalities for every tabulated k.
    dims_by_k = {}
    for k_tab, d_tab in ALPHA_LNC_TABLE:
        dims_by_k.setdefault(k_tab, []).append(d_tab)

    # Exact k if tabulated, otherwise the closest tabulated k.
    if k in dims_by_k:
        k_use = k
    else:
        k_use = min(sorted(dims_by_k), key=lambda candidate: abs(candidate - k))

    dims = sorted(dims_by_k[k_use])
    smallest_d, largest_d = dims[0], dims[-1]

    # Clamp to the tabulated range.
    if d <= smallest_d:
        return ALPHA_LNC_TABLE.get((k_use, smallest_d), 0.25)
    if d >= largest_d:
        return ALPHA_LNC_TABLE.get((k_use, largest_d), 0.25)

    # Exact hit inside the range.
    if d in dims:
        return ALPHA_LNC_TABLE.get((k_use, d), 0.25)

    # Gap inside the range: interpolate linearly between the two neighbors.
    d_below = max(d_tab for d_tab in dims if d_tab < d)
    d_above = min(d_tab for d_tab in dims if d_tab > d)
    alpha_below = ALPHA_LNC_TABLE.get((k_use, d_below), 0)
    alpha_above = ALPHA_LNC_TABLE.get((k_use, d_above), 0)
    fraction = (d - d_below) / (d_above - d_below)
    return alpha_below + fraction * (alpha_above - alpha_below)
def add_noise(x, ampl=1e-10):
    """Jitter data with tiny uniform noise to break exact ties.

    k-nearest neighbor estimators become unstable when several samples share
    identical coordinates. A very small random perturbation removes these
    ties without materially changing the estimate.

    Parameters
    ----------
    x : array-like
        Data to perturb. Must contain only finite values.
    ampl : float, optional
        Noise amplitude; must be non-negative. Default is 1e-10.

    Returns
    -------
    numpy.ndarray
        ``x`` plus noise drawn uniformly from [0, ampl), same shape as ``x``.

    Raises
    ------
    ValueError
        If ``ampl`` is negative or ``x`` contains non-finite values.

    Notes
    -----
    The noise comes from numpy's global random state
    (``np.random.random_sample``); seed it beforehand for reproducibility.

    References
    ----------
    Kraskov, A., Stögbauer, H., & Grassberger, P. (2004). Estimating mutual
    information. Physical Review E, 69(6), 066138.
    """
    data = np.asarray(x)

    if ampl < 0:
        raise ValueError(f"ampl must be non-negative, got {ampl}")
    if not np.isfinite(data).all():
        raise ValueError("x must contain only finite values")

    # One uniform draw in [0, 1) per element, scaled to the requested amplitude.
    jitter = np.random.random_sample(data.shape)
    return data + ampl * jitter
def query_neighbors(tree, x, k):
    """Return the distance from each query point to its k-th nearest neighbor.

    Parameters
    ----------
    tree : BallTree or KDTree
        Pre-built spatial index exposing ``query(X, k=...)``.
    x : ndarray of shape (n_samples, n_features) or (n_samples,)
        Query points; a 1D input is treated as a single feature column.
    k : int
        Which neighbor distance to return (k-th nearest). Must be positive.

    Returns
    -------
    ndarray of shape (n_samples,)
        Distance to the k-th nearest neighbor of each query point.

    Raises
    ------
    ValueError
        If k is not a positive integer.

    Notes
    -----
    ``k + 1`` neighbors are requested and column ``k`` is returned, so that
    the zero-distance match of a query point with itself (column 0, when
    the query points are the indexed points) is skipped.
    """
    if not (isinstance(k, (int, np.integer)) and k > 0):
        raise ValueError(f"k must be a positive integer, got {k}")

    queries = np.asarray(x)
    if queries.ndim == 1:
        queries = queries[:, np.newaxis]

    result = tree.query(queries, k=k + 1)
    distances = result[0]
    return distances[:, k]
def _count_neighbors_single(tree, x, radii, ind):
"""Count neighbors within radius for a single point (legacy function).
Parameters
----------
tree : BallTree or KDTree
Spatial index tree.
x : ndarray
All data points.
radii : ndarray
Search radius for each point.
ind : int
Index of the point to query. Must be valid index into x.
Returns
-------
int
Number of neighbors within radius (excluding self).
Raises
------
ValueError
If ind is not a valid index or if radii is too short.
Notes
-----
This is a legacy single-point implementation. Use count_neighbors
for efficient vectorized computation."""
# Validate inputs
x = np.asarray(x)
radii = np.asarray(radii)
if not isinstance(ind, (int, np.integer)) or ind < 0 or ind >= len(x):
raise ValueError(f"ind must be valid index in [0, {len(x)}), got {ind}")
if ind >= len(radii):
raise ValueError(f"radii must have at least {ind+1} elements")
dists, indices = tree.query(x[ind : ind + 1], k=DEFAULT_NN, distance_upper_bound=radii[ind])
# Fixed: subtract 1 for self-exclusion, not 2
return len(np.unique(indices[0])) - 1
def count_neighbors(tree, x, radii):
    """Count, for every query point, the indexed points within its radius.

    Parameters
    ----------
    tree : BallTree or KDTree
        Pre-built spatial index exposing ``query_radius``.
    x : ndarray of shape (n_samples, n_features) or (n_samples,)
        Query points; a 1D input is treated as a single feature column.
    radii : ndarray of shape (n_samples,)
        Search radius for each query point.

    Returns
    -------
    ndarray of shape (n_samples,)
        Result of ``tree.query_radius(..., count_only=True)``: the number of
        indexed points within each radius.

    Raises
    ------
    ValueError
        If x and radii have different lengths.

    Notes
    -----
    When the query points are the indexed points themselves, each count
    includes the query point; subtract 1 if self-exclusion is needed.
    """
    queries = np.asarray(x)
    radius_per_point = np.asarray(radii)

    if queries.ndim == 1:
        queries = queries[:, np.newaxis]

    n_queries = len(queries)
    n_radii = len(radius_per_point)
    if n_queries != n_radii:
        raise ValueError(f"x and radii must have same length, got {n_queries} and {n_radii}")

    return tree.query_radius(queries, radius_per_point, count_only=True)
def build_tree(points, lf=5):
    """Build a spatial index tree for k-NN queries.

    Selects between KDTree and BallTree based on dimensionality.

    Parameters
    ----------
    points : ndarray of shape (n_samples, n_features)
        Data points to index. Must have at least 1 sample and 1 feature.
    lf : int, optional
        Leaf size used for tree construction (applied to both tree types).
        Default is 5. Must be positive.

    Returns
    -------
    BallTree or KDTree
        Spatial index tree using the Chebyshev (max) metric. A BallTree is
        returned when ``n_features >= 20``, a KDTree otherwise.

    Raises
    ------
    ValueError
        If points is not 2D, is empty, or if lf is not a positive integer.

    Notes
    -----
    - BallTree is used for high dimensions (>= 20) because KDTree
      performance degrades as dimensionality grows.
    - The Chebyshev metric is used for compatibility with the KSG estimator.
    - The leaf size only affects construction/query speed, not query results.
    """
    points = np.asarray(points)
    if points.ndim != 2:
        raise ValueError(f"points must be 2D array, got shape {points.shape}")
    if points.size == 0:
        raise ValueError("points cannot be empty")
    if not isinstance(lf, (int, np.integer)) or lf <= 0:
        raise ValueError(f"lf must be positive integer, got {lf}")

    # Honor the requested leaf size for both tree types (the BallTree branch
    # previously fell back to the library default).
    tree_cls = BallTree if points.shape[1] >= 20 else KDTree
    return tree_cls(points, metric="chebyshev", leaf_size=lf)
def avgdigamma(points, dvec, lf=30, tree=None):
    """Average digamma of the neighbor counts found within per-point radii.

    Used by the KSG mutual information estimator: for each sample, count the
    points of a marginal space lying within the sample's joint-space k-NN
    distance, then average ``psi(count)`` over all samples.

    Parameters
    ----------
    points : ndarray of shape (n_samples, n_features)
        Data points in the marginal space.
    dvec : ndarray of shape (n_samples,)
        Distance to the k-th neighbor in the joint space, used as the search
        radius for each point.
    lf : int, optional
        Leaf size used when a tree has to be built. Default is 30.
    tree : BallTree or KDTree, optional
        Pre-built tree for ``points``. If None, one is built with build_tree.

    Returns
    -------
    float
        Mean of digamma(neighbor_count) across all points.

    Raises
    ------
    ValueError
        If inputs have incompatible shapes or ``lf`` is not a positive
        integer.
    Exception
        If more than 1% of points have no neighbors within their radius.

    Notes
    -----
    - 1e-15 is subtracted from every radius (on a copy; ``dvec`` itself is
      not modified) so that points lying exactly on the boundary are not
      counted.
    - Points with a zero count are assigned a count of 0.5 before the
      digamma function is applied.
    """
    samples = np.asarray(points)
    radii = np.asarray(dvec)

    if samples.ndim != 2:
        raise ValueError(f"points must be 2D array, got shape {samples.shape}")
    if radii.ndim != 1:
        raise ValueError(f"dvec must be 1D array, got shape {radii.shape}")
    if len(samples) != len(radii):
        raise ValueError(
            f"points and dvec must have same length, got {len(samples)} and {len(radii)}"
        )
    if not isinstance(lf, (int, np.integer)) or lf <= 0:
        raise ValueError(f"lf must be positive integer, got {lf}")

    if tree is None:
        tree = build_tree(samples, lf=lf)

    # Shrink the radii slightly; the subtraction yields a new array.
    shrunk_radii = radii - 1e-15
    counts = count_neighbors(tree, samples, shrunk_radii).astype(float)

    # Tolerate a few empty neighborhoods, but refuse clearly degenerate input.
    empty = counts == 0
    n_empty = int(np.count_nonzero(empty))
    if n_empty / len(counts) > 0.01:
        raise Exception("No neighbours in more than 1% points, check input!")
    if n_empty:
        counts[empty] = 0.5

    # Expectation value <psi(n_x)>
    return np.mean([py_fast_digamma(count) for count in counts])
# CONTINUOUS ESTIMATORS
def nonparam_entropy_c(x, k=DEFAULT_NN, base=np.e):
    """The classic Kozachenko-Leonenko k-nearest neighbor continuous entropy estimator.

    Estimates differential entropy for continuous variables using k-nearest
    neighbor distances. This is the foundation for KSG mutual information
    estimation.

    Parameters
    ----------
    x : array-like
        Continuous data, shape (n_samples,) for 1D or (n_samples, n_features)
        for multivariate. Each row is a sample, columns are features.
    k : int, default=5
        Number of nearest neighbors to use. Common values:
        - k = 4-5 for most applications
        - k = 3-10 for low dimensions (d ≤ 3)
        - k = 10-20 for higher dimensions (d > 3)
        Must satisfy 0 < k < n_samples. Higher k reduces variance but
        increases bias.
    base : float, default=np.e
        Logarithm base for entropy calculation. Use np.e for nats,
        2 for bits, or 10 for dits. Must be positive and different from 1.

    Returns
    -------
    float
        Differential entropy estimate in units determined by base.
        Can be negative for continuous distributions.

    Raises
    ------
    ValueError
        If x is empty, k is invalid, or base is not positive or equals 1.

    Notes
    -----
    The Kozachenko-Leonenko estimator (max-norm form) is:
        H(X) = ψ(n) - ψ(k) + d*log(2) + d*<log(ε_k)>
    where:
    - ψ is the digamma function
    - n is the number of samples
    - d is the dimensionality
    - ε_k is the distance to the k-th nearest neighbor
    - <·> denotes average over all samples

    Small noise is added (via add_noise) to break ties for discrete-valued
    continuous data, so results depend on numpy's global random state.

    References
    ----------
    Kraskov, A., Stögbauer, H., & Grassberger, P. (2004). Estimating mutual
    information. Physical Review E, 69(6), 066138. (Recommends k ≈ 4)

    Examples
    --------
    >>> # Entropy of standard normal (theoretical: 0.5*log(2πe) ≈ 1.42 nats)
    >>> np.random.seed(42)
    >>> x = np.random.randn(10000)
    >>> h = nonparam_entropy_c(x)
    >>> print(f"Estimated entropy: {h:.3f} nats")
    Estimated entropy: 1.422 nats
    >>> # Entropy in bits
    >>> h_bits = nonparam_entropy_c(x, base=2)
    >>> print(f"h_bits = {h_bits:.3f} bits")
    h_bits = 2.051 bits
    """
    # Validate inputs
    x = np.asarray(x)
    if x.size == 0:
        raise ValueError("x cannot be empty")
    if x.ndim == 1:
        x = x.reshape(-1, 1)
    n_elements, n_features = x.shape

    if not isinstance(k, (int, np.integer)) or k <= 0:
        raise ValueError(f"k must be positive integer, got {k}")
    if k >= n_elements:
        raise ValueError(f"k must be less than n_samples, got k={k} with n_samples={n_elements}")
    if base <= 0:
        raise ValueError(f"base must be positive, got {base}")
    if base == 1:
        # log(1) == 0 would turn the final unit conversion into a division by zero.
        raise ValueError(f"base must not be equal to 1, got {base}")

    x = add_noise(x)
    tree = build_tree(x)
    knn_dist = query_neighbors(tree, x, k)

    # ψ(n) - ψ(k) + d*log(2): the sample-independent part of the estimator.
    const = py_fast_digamma(n_elements) - py_fast_digamma(k) + n_features * log(2)
    return (const + n_features * np.log(knn_dist).mean()) / log(base)
def nonparam_cond_entropy_cc(x, y, k=DEFAULT_NN, base=np.e):
    """Kozachenko-Leonenko k-NN estimate of the conditional entropy H(X|Y).

    Applies the chain rule H(X|Y) = H(X,Y) - H(Y), estimating both terms
    with nonparam_entropy_c.

    Parameters
    ----------
    x : array-like
        Variable whose conditional entropy is computed, shape (n_samples,) or
        (n_samples, n_features_x). A 1D input is reshaped to one column.
    y : array-like
        Conditioning variable, shape (n_samples,) or (n_samples, n_features_y).
        Must have the same number of samples as x.
    k : int, default=5
        Number of nearest neighbors, forwarded to nonparam_entropy_c.
    base : float, default=np.e
        Logarithm base, forwarded to nonparam_entropy_c. Use np.e for nats,
        2 for bits.

    Returns
    -------
    float
        Conditional entropy H(X|Y) in units determined by base. Being a
        difference of differential entropies, it can be negative.

    Raises
    ------
    ValueError
        If x or y are empty or have different numbers of samples. Invalid
        ``k`` or ``base`` values are rejected by nonparam_entropy_c.

    Notes
    -----
    H(X,Y) is estimated first, then H(Y). Each estimate adds its own small
    random noise to handle discrete or repeated values.

    See Also
    --------
    nonparam_entropy_c : Computes unconditional entropy H(X)
    nonparam_mi_cc : Computes mutual information I(X;Y)
    """
    target = np.asarray(x)
    given = np.asarray(y)

    if target.size == 0 or given.size == 0:
        raise ValueError("x and y cannot be empty")

    # Promote 1D inputs to single-column 2D arrays.
    if target.ndim == 1:
        target = target.reshape(-1, 1)
    if given.ndim == 1:
        given = given.reshape(-1, 1)

    n_target, n_given = len(target), len(given)
    if n_target != n_given:
        raise ValueError(f"x and y must have same number of samples, got {n_target} and {n_given}")

    # Joint samples: columns of x followed by columns of y.
    joint = np.concatenate((target, given), axis=-1)

    h_joint = nonparam_entropy_c(joint, k=k, base=base)
    h_given = nonparam_entropy_c(given, k=k, base=base)
    return h_joint - h_given
# MUTUAL INFORMATION ESTIMATORS (CONTINUOUS-CONTINUOUS)
def nonparam_mi_cc(
    x,
    y,
    z=None,
    k=DEFAULT_NN,
    base=np.e,
    alpha="auto",
    lf=5,
    precomputed_tree_x=None,
    precomputed_tree_y=None,
):
    """Kraskov-Stögbauer-Grassberger (KSG) mutual information estimator.

    Estimates mutual information between continuous variables using k-nearest
    neighbors. Can compute conditional MI when z is provided: I(X;Y|Z).

    Parameters
    ----------
    x : array-like
        First variable, shape (n_samples,) or (n_samples, n_features_x).
    y : array-like
        Second variable, shape (n_samples,) or (n_samples, n_features_y).
        Must have same number of samples as x.
    z : array-like, optional
        Conditioning variable for conditional MI: I(X;Y|Z).
        Shape (n_samples,) or (n_samples, n_features_z).
    k : int, default=5
        Number of nearest neighbors. Common values:
        - k = 4-5 for most applications
        - Use larger k for higher dimensions
        Must satisfy 0 < k < n_samples.
    base : float, default=np.e
        Logarithm base. Use np.e for nats, 2 for bits, 10 for dits.
        Must be positive and different from 1.
    alpha : float or "auto", default="auto"
        Local Non-uniformity Correction (LNC) parameter. Only used for the
        unconditional estimate (z is None).
        - "auto": alpha is taken from get_lnc_alpha(k, d)
        - number > 0: manual alpha value
        - 0, or any other non-numeric value: correction disabled
        LNC is skipped (with a warning) when k ≤ total dimensionality.
    lf : int, default=5
        Leaf size for tree construction.
    precomputed_tree_x : BallTree/KDTree, optional
        Pre-built tree for x, used for the unconditional estimate only.
    precomputed_tree_y : BallTree/KDTree, optional
        Pre-built tree for y, used for the unconditional estimate only.

    Returns
    -------
    float
        Mutual information estimate in units determined by base.
        Non-negative up to estimation error (the value is not clamped).

    Raises
    ------
    ValueError
        If arrays have different lengths, k is invalid, or base is not
        positive or equals 1.

    Warns
    -----
    UserWarning
        If LNC was requested (alpha > 0 after resolving "auto") for an
        unconditional estimate but k <= dimensionality, in which case the
        correction is skipped.

    Notes
    -----
    Uses the KSG estimator algorithm 1:
        I(X;Y) = ψ(k) - <ψ(n_x + 1) + ψ(n_y + 1)> + ψ(n)
    where:
    - ψ is the digamma function
    - n_x, n_y are the number of neighbors in X, Y spaces
    - <·> denotes average over all samples
    - n is the total number of samples

    The conditional estimate is
        I(X;Y|Z) = ψ(k) + <ψ(n_z) - ψ(n_xz) - ψ(n_yz)>.

    Small noise is added to x and y (not to z) to break ties.

    References
    ----------
    Kraskov, A., Stögbauer, H., & Grassberger, P. (2004). Estimating mutual
    information. Physical Review E, 69(6), 066138.
    Gao, S., Ver Steeg, G., & Galstyan, A. (2015). Efficient Estimation of Mutual
    Information for Strongly Dependent Variables. AISTATS, PMLR 38:277-286. (LNC correction)

    Examples
    --------
    >>> # MI between correlated Gaussians
    >>> np.random.seed(42)
    >>> x = np.random.randn(1000)
    >>> y = x + np.random.randn(1000) * 0.5
    >>> mi = nonparam_mi_cc(x, y)
    >>> print(f"MI = {mi:.3f} nats")
    MI = 0.760 nats
    >>> # Conditional MI: I(X;Y|Z)
    >>> z = np.random.randn(1000)
    >>> cmi = nonparam_mi_cc(x, y, z=z)
    >>> print(f"cmi = {cmi:.3f} nats")
    cmi = 0.756 nats
    """
    import warnings

    # Validate inputs
    if not isinstance(k, (int, np.integer)) or k <= 0:
        raise ValueError(f"k must be positive integer, got {k}")
    if base <= 0:
        raise ValueError(f"base must be positive, got {base}")
    if base == 1:
        # log(1) == 0 would turn the final unit conversion into a division by zero.
        raise ValueError(f"base must not be equal to 1, got {base}")

    x = np.asarray(x)
    y = np.asarray(y)
    if x.size == 0 or y.size == 0:
        raise ValueError("x and y cannot be empty")
    if len(x) != len(y):
        raise ValueError(f"Arrays should have same length, got {len(x)} and {len(y)}")
    n_samples = len(x)
    if k >= n_samples:
        raise ValueError(f"k must be less than n_samples, got k={k} with n_samples={n_samples}")

    x = add_noise(x.reshape(n_samples, -1))
    y = add_noise(y.reshape(n_samples, -1))

    blocks = [x, y]
    if z is not None:
        z = np.asarray(z)
        if z.ndim == 0 or z.shape[0] != n_samples:
            raise ValueError(
                f"z must have the same number of samples as x ({n_samples}), got shape {z.shape}"
            )
        z = z.reshape(n_samples, -1)
        blocks.append(z)
    points = np.hstack(blocks)
    n_dims = points.shape[1]  # Total dimensionality of the joint space

    # Resolve the LNC threshold. LNC only applies to the unconditional
    # estimate, so nothing is looked up (or warned about) when z is given.
    lnc_alpha = 0
    if z is None:
        if isinstance(alpha, str) and alpha == "auto":
            alpha = get_lnc_alpha(k, n_dims)
        # np.number covers NumPy scalars such as np.float32, which
        # lnc_correction accepts as well.
        if isinstance(alpha, (int, float, np.number)) and alpha > 0:
            if k <= n_dims:
                # LNC is unstable for k <= d, so skip it.
                warnings.warn(
                    f"LNC correction disabled: k={k} <= dimensionality={n_dims}. "
                    f"LNC requires k > d for stability. Consider using k >= {n_dims+1}.",
                    UserWarning,
                    stacklevel=2,
                )
            else:
                lnc_alpha = alpha

    # Find nearest neighbors in joint space (Chebyshev / max-norm tree)
    tree = build_tree(points, lf=lf)
    dvec = query_neighbors(tree, points, k)

    if z is None:
        psi_x = avgdigamma(x, dvec, tree=precomputed_tree_x, lf=lf)
        psi_y = avgdigamma(y, dvec, tree=precomputed_tree_y, lf=lf)
        psi_k = py_fast_digamma(k)
        psi_n = py_fast_digamma(n_samples)
        if lnc_alpha > 0:
            psi_n += lnc_correction(tree, points, k, lnc_alpha)
        return (-psi_x - psi_y + psi_k + psi_n) / log(base)

    xz = np.c_[x, z]
    yz = np.c_[y, z]
    psi_xz = avgdigamma(xz, dvec, lf=lf)
    psi_yz = avgdigamma(yz, dvec, lf=lf)
    psi_z = avgdigamma(z, dvec, lf=lf)
    psi_k = py_fast_digamma(k)
    return (-psi_xz - psi_yz + psi_z + psi_k) / log(base)
def lnc_correction(tree, points, k, alpha):
    """Local Non-uniformity Correction for KSG mutual information estimator.

    For every point, compares the volume of a PCA-aligned bounding box of its
    k nearest neighbors with the volume of the axis-aligned bounding box, and
    accumulates the log-volume ratio whenever the PCA box is smaller by more
    than the factor ``alpha``.

    Parameters
    ----------
    tree : sklearn.neighbors.KDTree or BallTree
        Pre-built tree over ``points`` (joint space).
    points : ndarray of shape (n_samples, n_features)
        The joint space data points (typically concatenated X and Y variables
        for mutual information estimation). Must have at least k+1 samples.
    k : int
        Number of nearest neighbors used in the KSG estimator. Must be
        positive; should be > d where d is the dimensionality, with
        k >= d+2 recommended for stability.
    alpha : float
        Threshold for detecting local non-uniformity: the correction is
        applied when Volume_PCA / Volume_axis < alpha. Larger values apply
        the correction more often. Must be non-negative; 0 disables the
        correction.

    Returns
    -------
    float
        The correction term, i.e. the sum of accepted log-volume differences
        divided by n_samples. Non-negative whenever alpha <= 1, because every
        accepted term then exceeds -log(alpha) >= 0.

    Raises
    ------
    ValueError
        If inputs have invalid shapes or values.

    Notes
    -----
    The algorithm works by:
    1. Finding the k nearest neighbors of every point (the query returns
       k + 1 indices, which include the point itself)
    2. Computing PCA on these points after mean-centering
    3. Comparing the log-volume of the PCA-aligned bounding box with that of
       the axis-aligned box
    4. Accumulating the difference if the PCA box is smaller by factor alpha

    Warning: Can be unstable when k <= d. In practice, the correction is often
    disabled (alpha=0) when k is too small relative to dimensionality.

    See Also
    --------
    get_lnc_alpha : Get optimal alpha value from lookup table
    nonparam_mi_cc : Main MI estimator that uses this correction with alpha="auto"

    References
    ----------
    Gao et al. (2015). Efficient estimation of mutual information for strongly
    dependent variables. AISTATS.
    """
    # Validate inputs
    points = np.asarray(points)
    if points.ndim != 2:
        raise ValueError(f"points must be 2D array, got shape {points.shape}")
    if not isinstance(k, (int, np.integer)) or k <= 0:
        raise ValueError(f"k must be positive integer, got {k}")
    if not isinstance(alpha, (int, float, np.number)) or alpha < 0:
        raise ValueError(f"alpha must be non-negative number, got {alpha}")
    if points.shape[0] <= k:
        raise ValueError(f"Need at least k+1 points, got {points.shape[0]} points with k={k}")

    # Early return if alpha is 0 (correction disabled)
    if alpha == 0:
        return 0.0

    n_sample = points.shape[0]
    log_alpha = np.log(alpha)

    # Query all neighborhoods in one batched call instead of once per point.
    neighbor_indices = tree.query(points, k=k + 1, return_distance=False)

    total = 0.0
    for knn in neighbor_indices:
        # Center the neighborhood on its mean.
        # NOTE(review): an earlier revision centered on the query point
        # itself; confirm mean-centering is the intended box definition.
        local = points[knn]
        local = local - local.mean(axis=0)

        # Principal axes of the neighborhood; eigh is used because the
        # covariance matrix is symmetric.
        covr = local.T @ local / k
        try:
            _, axes = la.eigh(covr)
        except la.LinAlgError:
            # Best effort: a point whose decomposition fails contributes nothing.
            continue

        # Log-volume of the PCA-aligned bounding box
        log_vol_pca = np.log(np.abs(local @ axes).max(axis=0)).sum()
        # Log-volume of the axis-aligned bounding box
        log_vol_axis = np.log(np.abs(local).max(axis=0)).sum()

        # Local non-uniformity check: Volume_PCA / Volume_axis < alpha
        if log_vol_pca < log_vol_axis + log_alpha:
            total += log_vol_axis - log_vol_pca

    return total / n_sample
# MIXED CONTINUOUS-DISCRETE ESTIMATORS
def nonparam_mi_cd(x_continuous, y_discrete, k=DEFAULT_NN, base=np.e):
    """
    Mutual information between a continuous and a discrete variable.

    Computes I(X;Y) = H(X) - H(X|Y), where both entropies are k-nearest
    neighbor estimates obtained from nonparam_entropy_c and H(X|Y) is the
    probability-weighted sum of the entropies of X within each category of Y.

    Parameters
    ----------
    x_continuous : array_like
        Continuous variable data of shape (n_samples,) or (n_samples, n_features).
        Must contain finite values.
    y_discrete : array_like
        Discrete variable data of shape (n_samples,) (a single column of
        shape (n_samples, 1) is also accepted). Values should be discrete
        categories (integers or strings).
    k : int, optional
        Number of nearest neighbors to use. Default is 5. Must satisfy
        0 < k < n_samples.
    base : float, optional
        Logarithm base. Default is e (natural logarithm). Must be positive
        and different from 1.

    Returns
    -------
    float
        Mutual information in units determined by base. Negative estimates
        are clamped to 0.

    Raises
    ------
    ValueError
        If inputs have incompatible shapes or invalid values.

    Notes
    -----
    Categories with fewer than k+1 samples are skipped (they contribute
    nothing to H(X|Y)), which may introduce bias for small sample sizes.

    Noise is added to ``x_continuous`` here and again inside each
    nonparam_entropy_c call; both perturbations are of negligible size.
    """
    # Validate inputs
    if not isinstance(k, (int, np.integer)) or k <= 0:
        raise ValueError(f"k must be positive integer, got {k}")
    if base <= 0:
        raise ValueError(f"base must be positive, got {base}")
    if base == 1:
        # log(1) == 0 would make the entropy unit conversion divide by zero.
        raise ValueError(f"base must not be equal to 1, got {base}")

    x_continuous = np.asarray(x_continuous)
    y_discrete = np.asarray(y_discrete)
    if x_continuous.size == 0 or y_discrete.size == 0:
        raise ValueError("x_continuous and y_discrete cannot be empty")
    if x_continuous.ndim == 1:
        x_continuous = x_continuous.reshape(-1, 1)

    # Boolean masking below needs a 1D label vector; accept a single column.
    if y_discrete.ndim == 2 and y_discrete.shape[1] == 1:
        y_discrete = y_discrete[:, 0]
    elif y_discrete.ndim != 1:
        raise ValueError(f"y_discrete must be 1D, got shape {y_discrete.shape}")

    if len(x_continuous) != len(y_discrete):
        raise ValueError(
            f"Arrays should have same length, got {len(x_continuous)} and {len(y_discrete)}"
        )
    if k >= len(x_continuous):
        raise ValueError(
            f"k must be less than n_samples, got k={k} with n_samples={len(x_continuous)}"
        )

    n_samples = len(x_continuous)

    # Add small noise to continuous variables to break ties
    x_continuous = add_noise(x_continuous)

    # H(X): entropy of the continuous variable
    h_x = nonparam_entropy_c(x_continuous, k=k, base=base)

    # H(X|Y): probability-weighted entropy of X within each category
    h_x_given_y = 0.0
    for y_val in np.unique(y_discrete):
        mask = y_discrete == y_val
        n_in_category = int(np.count_nonzero(mask))
        if n_in_category <= k:
            # Too few samples for a k-NN entropy estimate; skip the category.
            continue
        p_y = n_in_category / n_samples
        h_x_given_y += p_y * nonparam_entropy_c(x_continuous[mask], k=k, base=base)

    mi = h_x - h_x_given_y
    return float(max(0.0, mi))  # MI is non-negative
# Discrete-continuous variant (argument order swapped)
def nonparam_mi_dc(x_discrete, y_continuous, k=DEFAULT_NN, base=np.e):
    """
    Mutual information between a discrete and a continuous variable.

    Mirror image of nonparam_mi_cd: mutual information is symmetric, so the
    arguments are swapped and the work is delegated to that function.

    Parameters
    ----------
    x_discrete : array_like
        Discrete variable data of shape (n_samples,). Values should be discrete
        categories (integers or strings).
    y_continuous : array_like
        Continuous variable data of shape (n_samples,) or (n_samples, n_features).
        Should contain finite values.
    k : int, optional
        Number of nearest neighbors to use. Default is 5. Must be positive.
    base : float, optional
        Logarithm base. Default is e (natural logarithm). Must be positive.

    Returns
    -------
    float
        Whatever nonparam_mi_cd returns for the swapped arguments: the mutual
        information in units determined by base.

    Notes
    -----
    Validation and error behavior are those of nonparam_mi_cd; see that
    function for implementation details.
    """
    mutual_information = nonparam_mi_cd(
        x_continuous=y_continuous,
        y_discrete=x_discrete,
        k=k,
        base=base,
    )
    return mutual_information