# Source code for driada.utils.jit

"""JIT compilation utilities for DRIADA.

Provides conditional JIT compilation based on environment settings.
"""

import os

# Global kill switch for Numba, evaluated once when this module is imported.
# Any of "true", "1" or "yes" (compared case-insensitively) disables JIT;
# an unset variable falls back to "False", which disables nothing.
_disable_numba_raw = os.environ.get("DRIADA_DISABLE_NUMBA", "False")
DRIADA_DISABLE_NUMBA = _disable_numba_raw.lower() in {"true", "1", "yes"}

# Try to import numba. ``prange`` is imported together with ``njit`` so that
# the module exposes ``prange`` whether or not numba is installed.
try:
    from numba import njit, prange

    NUMBA_AVAILABLE = True
except ImportError:
    NUMBA_AVAILABLE = False

    # Define dummy decorators
    def njit(*args, **kwargs):
        """Dummy njit decorator when numba is not available.

        Acts as a pass-through decorator that returns the original function
        unchanged when numba is not installed. It accepts the same call
        forms as ``numba.njit``: bare ``@njit``, direct ``njit(func)``,
        options-only ``@njit(cache=True)`` and signature-first
        ``@njit("float64(float64)")``.

        Parameters
        ----------
        *args
            Positional arguments. If the first one is callable it is the
            function to decorate and is returned as-is; otherwise (e.g. a
            signature string) it is ignored and a decorator is returned.
        **kwargs
            Keyword arguments (ignored).

        Returns
        -------
        function or decorator
            The original function, or a decorator that returns the
            function it is applied to.
        """
        # Checking callable() rather than "any positional argument" keeps
        # signature-style calls working: without it, the signature string
        # itself would be returned and then called as a decorator.
        if args and callable(args[0]):
            return args[0]

        def decorator(func):
            """Identity decorator that returns *func* unchanged.

            Parameters
            ----------
            func : callable
                Function to decorate (returned unchanged).

            Returns
            -------
            callable
                The original function without modification.
            """
            return func

        return decorator

    # Without numba, parallel ranges degrade to the built-in serial range.
    prange = range


def conditional_njit(*args, **kwargs):
    """Conditionally apply numba JIT compilation based on environment settings.

    If the DRIADA_DISABLE_NUMBA environment variable was set to 'true', '1',
    or 'yes' when this module was imported, or if numba is not available,
    the original function is returned without JIT compilation. Otherwise
    ``numba.njit`` is applied with the given parameters.

    Parameters
    ----------
    *args
        Positional arguments passed to numba.njit. If the first one is a
        callable, it is treated as the function to decorate directly;
        otherwise (e.g. a signature string) the call is treated as a
        decorator factory.
    **kwargs
        Keyword arguments passed to numba.njit (e.g., parallel=True,
        cache=True).

    Returns
    -------
    decorator or function
        If called with options or a signature: a decorator.
        If called on a function directly: the (possibly JIT-compiled)
        function.

    Notes
    -----
    This decorator allows DRIADA to gracefully handle environments where
    Numba is not installed or where JIT compilation needs to be disabled
    for debugging. The DRIADA_DISABLE_NUMBA environment variable can be set
    to 'true', '1', or 'yes' (case insensitive) to disable JIT compilation
    globally. The variable is read once, at import time, so it must be set
    before this module is imported. The JIT-or-not decision is made when
    the decorator is applied.

    Examples
    --------
    >>> @conditional_njit
    ... def fast_computation(x):
    ...     return x ** 2

    With numba parameters::

        @conditional_njit(parallel=True)
        def parallel_computation(x):
            return x ** 2

    Direct decoration (less common)::

        def my_function(x):
            return x ** 3

        fast_func = conditional_njit(my_function)

    See Also
    --------
    is_jit_enabled : Check if JIT compilation is currently enabled.
    """
    if DRIADA_DISABLE_NUMBA or not NUMBA_AVAILABLE:
        # Bare ``@conditional_njit`` / ``conditional_njit(func)``: hand the
        # function back unchanged. Checking callable() (not merely "any
        # positional argument") keeps signature-style calls such as
        # ``@conditional_njit("float64(float64)")`` working as decorator
        # factories instead of returning the signature string.
        if args and callable(args[0]):
            return args[0]

        def decorator(func):
            """Identity decorator that returns *func* unchanged.

            Parameters
            ----------
            func : callable
                Function to decorate (returned unchanged).

            Returns
            -------
            callable
                The original function without modification.
            """
            return func

        return decorator

    # JIT is active: defer entirely to numba.njit with the caller's
    # arguments, whatever calling convention they used.
    return njit(*args, **kwargs)


def is_jit_enabled():
    """Check if JIT compilation is enabled.

    Determines whether Numba JIT compilation is available and active based
    on installation status and environment settings.

    Returns
    -------
    bool
        True if both conditions are met:

        - Numba is installed and importable
        - The DRIADA_DISABLE_NUMBA environment variable was not set to
          'true', '1', or 'yes' when this module was imported

    Examples
    --------
    >>> is_jit_enabled()  # doctest: +SKIP
    True

    The call above returns ``True`` when Numba is installed and JIT was not
    disabled.

    To disable JIT, set the environment variable *before* DRIADA is
    imported, e.g. from the shell::

        DRIADA_DISABLE_NUMBA=1 python my_script.py

    Changing ``os.environ['DRIADA_DISABLE_NUMBA']`` after import has no
    effect, because the flag is read only once, when the module loads.

    Notes
    -----
    JIT compilation significantly speeds up numerical computations but may
    cause issues during debugging. The DRIADA_DISABLE_NUMBA environment
    variable can be set to 'true', '1', or 'yes' (case insensitive) to
    disable JIT when debugging or if encountering Numba-related errors.

    See Also
    --------
    jit_info : Print detailed JIT status information.
    conditional_njit : Decorator that respects JIT settings.
    """
    return NUMBA_AVAILABLE and not DRIADA_DISABLE_NUMBA


def jit_info():
    """Print a summary of the JIT compilation status.

    Writes one line each to standard output for:

    - whether Numba can be imported,
    - whether JIT was switched off through the DRIADA_DISABLE_NUMBA
      environment variable,
    - the resulting overall JIT status, and
    - the Numba version, when Numba is importable.

    Nothing is returned.

    Examples
    --------
    >>> jit_info()  # doctest: +SKIP
    Numba available: True
    JIT disabled by environment: False
    JIT enabled: True
    Numba version: 0.60.0

    Notes
    -----
    Handy when diagnosing performance problems or checking that JIT
    compilation is active in the current environment.

    See Also
    --------
    is_jit_enabled : Check JIT status programmatically.
    conditional_njit : Decorator that respects JIT settings.
    """
    status_lines = [
        f"Numba available: {NUMBA_AVAILABLE}",
        f"JIT disabled by environment: {DRIADA_DISABLE_NUMBA}",
        f"JIT enabled: {is_jit_enabled()}",
    ]
    if NUMBA_AVAILABLE:
        # Imported lazily: the module only needs numba's version here.
        import numba

        status_lines.append(f"Numba version: {numba.__version__}")
    for status_line in status_lines:
        print(status_line)