Visualization Utilities

Visualization utilities for DRIADA

This module provides reusable visualization functions for embedding comparisons, trajectory plots, and component interpretation in dimensionality reduction analyses.

driada.utils.visual.plot_embedding_comparison(embeddings, features=None, feature_names=None, methods=None, with_trajectory=True, compute_metrics=True, trajectory_kwargs=None, figsize=None, scatter_size=2, save_path=None, dpi=150)[source]

Create comprehensive embedding comparison figure with behavioral features and trajectories.

Parameters:
  • embeddings (dict) – Dictionary mapping method names to embedding arrays (n_samples, n_components). Arrays must be 2D with at least 2 components.

  • features (dict, optional) – Dictionary mapping feature names to feature arrays. Arrays must have the same length as the embeddings. Features with ‘angle’ or ‘direction’ in the name are treated as circular and plotted with the ‘hsv’ colormap; all others use ‘viridis’. The first two features are shown as rows 1 and 2.

  • feature_names (dict, optional) – Dictionary mapping feature keys to display names

  • methods (list of str, optional) – List of methods to plot (if None, uses all keys in embeddings)

  • with_trajectory (bool, default True) – Whether to include trajectory visualization as a third row

  • compute_metrics (bool, default True) – Whether to compute and display metrics (density contours, percentiles)

  • trajectory_kwargs (dict, optional) – Additional keyword arguments for trajectory plotting

  • figsize (tuple, optional) – Figure size (width, height). If None, computed based on number of methods

  • scatter_size (float, default 2) – Marker size for scatter points

  • save_path (str, optional) – Path to save the figure

  • dpi (int, default DEFAULT_DPI (150)) – DPI resolution for the saved figure

Returns:

fig – The generated figure

Return type:

matplotlib.figure.Figure

Raises:

ValueError – If embeddings are not 2D arrays with at least 2 components, or if feature arrays have mismatched lengths

Notes

Methods not found in embeddings dict are silently skipped. KDE computation failures are caught and contours are omitted.
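The input rules above can be checked before any plotting happens. The NumPy sketch below is a hypothetical pre-flight check, not DRIADA code: `validate_inputs` and `pick_colormap` are illustrative helpers that only mirror the documented rules (2D arrays with at least 2 components, feature arrays matching the embedding length, ‘hsv’ for feature names containing ‘angle’ or ‘direction’). The random arrays are stand-in data.

```python
import numpy as np


def validate_inputs(embeddings, features=None):
    """Hypothetical check mirroring the documented ValueError conditions."""
    n_samples = None
    for method, emb in embeddings.items():
        emb = np.asarray(emb)
        # Each embedding must be 2D with at least 2 components
        if emb.ndim != 2 or emb.shape[1] < 2:
            raise ValueError(f"{method}: need shape (n_samples, >=2), got {emb.shape}")
        if n_samples is None:
            n_samples = emb.shape[0]
    # Every feature array must match the embedding length
    for name, values in (features or {}).items():
        if len(values) != n_samples:
            raise ValueError(f"feature {name!r}: {len(values)} samples, expected {n_samples}")
    return n_samples


def pick_colormap(feature_name):
    """'hsv' for circular features (name contains 'angle' or 'direction'), else 'viridis'."""
    if "angle" in feature_name or "direction" in feature_name:
        return "hsv"
    return "viridis"


rng = np.random.default_rng(0)
n = 500
embeddings = {"PCA": rng.normal(size=(n, 2)), "UMAP": rng.normal(size=(n, 3))}
features = {
    "head_direction": rng.uniform(0, 2 * np.pi, n),  # circular feature
    "speed": rng.gamma(2.0, size=n),                 # linear feature
}

n_samples = validate_inputs(embeddings, features)
cmaps = {name: pick_colormap(name) for name in features}
```

With inputs shaped like this, a call such as `plot_embedding_comparison(embeddings, features=features, feature_names={'head_direction': 'Head direction', 'speed': 'Speed'})` matches the documented signature.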

driada.utils.visual.plot_trajectories(embeddings, methods=None, trajectory_kwargs=None, figsize=None, save_path=None, dpi=150)[source]

Create figure showing trajectories in embedding space for multiple methods.

Parameters:
  • embeddings (dict) – Dictionary mapping method names to embedding arrays. Arrays must be 2D with at least 2 components.

  • methods (list of str, optional) – List of methods to plot (if None, uses all keys in embeddings)

  • trajectory_kwargs (dict, optional) – Keyword arguments for trajectory plotting

  • figsize (tuple, optional) – Figure size

  • save_path (str, optional) – Path to save the figure

  • dpi (int, default DEFAULT_DPI (150)) – DPI resolution for the saved figure

Returns:

fig – The generated figure

Return type:

matplotlib.figure.Figure

Raises:

ValueError – If embeddings are not 2D arrays with at least 2 components

driada.utils.visual.plot_component_interpretation(mi_matrices, feature_names, methods=None, n_components=None, metadata=None, compute_metrics=True, figsize=None, save_path=None, dpi=150)[source]

Create figure showing mutual information between embedding components and features.

Parameters:
  • mi_matrices (dict) – Dictionary mapping method names to MI matrices (n_features, n_components). MI values should be non-negative.

  • feature_names (list of str) – Names of features for y-axis labels

  • methods (list of str, optional) – List of methods to plot (if None, uses all keys in mi_matrices)

  • n_components (int, optional) – Number of components to show (default: the smaller of 5 and the number of available components)

  • metadata (dict, optional) – Dictionary of metadata for each method (e.g., explained variance for PCA)

  • compute_metrics (bool, default True) – Whether to show additional metrics (e.g., explained variance)

  • figsize (tuple, optional) – Figure size

  • save_path (str, optional) – Path to save the figure

  • dpi (int, default DEFAULT_DPI (150)) – DPI resolution for the saved figure

Returns:

fig – The generated figure

Return type:

matplotlib.figure.Figure

Raises:

ValueError – If MI matrices are not 2D or contain negative values
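The function takes precomputed MI matrices of shape (n_features, n_components). As an illustration of how such a matrix can be built, the sketch below uses a simple binned (histogram) plug-in MI estimator. This estimator is swapped in for illustration only; it is not DRIADA's estimator, it is biased upward for small samples, and `binned_mi` and `mi_matrix` are hypothetical helpers. The synthetic data is stand-in input.

```python
import numpy as np


def binned_mi(x, y, bins=16):
    """Plug-in mutual information (in nats) estimated from a 2D histogram."""
    joint, _, _ = np.histogram2d(x, y, bins=bins)
    pxy = joint / joint.sum()
    px = pxy.sum(axis=1, keepdims=True)  # marginal of x, shape (bins, 1)
    py = pxy.sum(axis=0, keepdims=True)  # marginal of y, shape (1, bins)
    nz = pxy > 0                         # skip empty cells (0 * log 0 = 0)
    mi = np.sum(pxy[nz] * np.log(pxy[nz] / (px @ py)[nz]))
    return float(max(mi, 0.0))           # clip tiny negative rounding error


def mi_matrix(embedding, features, bins=16):
    """Return an (n_features, n_components) array of MI values and the row names."""
    names = list(features)
    out = np.zeros((len(names), embedding.shape[1]))
    for i, name in enumerate(names):
        for j in range(embedding.shape[1]):
            out[i, j] = binned_mi(features[name], embedding[:, j], bins)
    return out, names


rng = np.random.default_rng(1)
t = rng.uniform(0, 1, 2000)
# Component 0 tracks t; component 1 is unrelated noise
emb = np.column_stack([t + 0.05 * rng.normal(size=2000), rng.normal(size=2000)])
mi, names = mi_matrix(emb, {"position": t})
n_show = min(5, emb.shape[1])  # the documented default for n_components
```

The resulting `mi` is non-negative by construction, which satisfies the documented ValueError check, and could be passed as `plot_component_interpretation({'my_method': mi}, names)`.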

driada.utils.visual.plot_embeddings_grid(embeddings, labels=None, methods=None, scenarios=None, metrics=None, colormap='viridis', figsize=None, n_cols=4, save_path=None, dpi=150)[source]

Create grid of embeddings for multiple methods and scenarios.

Parameters:
  • embeddings (dict of dict) – Nested dictionary: {method: {scenario: embedding_array}}. Arrays must be 2D with at least 2 components.

  • labels (array or dict, optional) – Color labels for points. Can be array (same for all) or dict matching structure

  • methods (list, optional) – Methods to plot (default: all in embeddings)

  • scenarios (list, optional) – Scenarios to plot (default: all available)

  • metrics (dict, optional) – Nested dict of metrics: {method: {scenario: {metric_name: value}}}. At most 2 metrics shown per subplot.

  • colormap (str) – Colormap for scatter plots

  • figsize (tuple, optional) – Figure size

  • n_cols (int) – Number of columns in grid

  • save_path (str, optional) – Path to save figure

  • dpi (int, default DEFAULT_DPI (150)) – DPI resolution for the saved figure

Returns:

fig – The generated figure, or None if no valid embeddings to plot

Return type:

matplotlib.figure.Figure or None

Raises:

ValueError – If embeddings are not 2D or label lengths mismatch
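The nested inputs are easiest to see with a concrete structure. The sketch below builds `{method: {scenario: embedding_array}}` and a matching metrics dict, then runs a hypothetical check of the documented error conditions. `resolve_labels` and `check_grid` are illustrative helpers, not DRIADA API, and reading a labels dict as `{method: {scenario: labels}}` is an assumption based on "dict matching structure". The metric names and values are placeholders that only show the nesting.

```python
import numpy as np


def resolve_labels(labels, method, scenario):
    """Hypothetical: an array applies to every panel; a dict is read as {method: {scenario: labels}}."""
    if isinstance(labels, dict):
        return labels.get(method, {}).get(scenario)
    return labels


def check_grid(embeddings, labels=None):
    """Return the (method, scenario) panels that pass the documented checks."""
    panels = []
    for method, per_scenario in embeddings.items():
        for scenario, emb in per_scenario.items():
            emb = np.asarray(emb)
            if emb.ndim != 2 or emb.shape[1] < 2:
                raise ValueError(f"{method}/{scenario}: not a 2D embedding")
            lab = resolve_labels(labels, method, scenario)
            if lab is not None and len(lab) != emb.shape[0]:
                raise ValueError(f"{method}/{scenario}: {len(lab)} labels for {emb.shape[0]} points")
            panels.append((method, scenario))
    return panels


rng = np.random.default_rng(2)
n = 300
embeddings = {
    "PCA": {"clean": rng.normal(size=(n, 2)), "noisy": rng.normal(size=(n, 2))},
    "Isomap": {"clean": rng.normal(size=(n, 2))},
}
# Placeholder values; at most 2 metrics are shown per subplot
metrics = {"PCA": {"clean": {"score_a": 0.5, "score_b": 0.5}}}
labels = np.arange(n)  # one label array shared by all panels

panels = check_grid(embeddings, labels)
scenarios = sorted({s for per_scenario in embeddings.values() for s in per_scenario})
```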

driada.utils.visual.plot_neuron_selectivity_summary(selectivity_counts, total_neurons, colors=None, figsize=(8, 6), save_path=None, dpi=150)[source]

Create bar plot summarizing neuron selectivity categories.

Parameters:
  • selectivity_counts (dict) – Dictionary mapping category names to counts. Counts should be non-negative integers with sum <= total_neurons.

  • total_neurons (int) – Total number of neurons. Must be positive.

  • colors (dict, optional) – Dictionary mapping category names to colors

  • figsize (tuple) – Figure size

  • save_path (str, optional) – Path to save figure

  • dpi (int, default DEFAULT_DPI (150)) – DPI resolution for the saved figure

Returns:

fig – The generated figure

Return type:

matplotlib.figure.Figure

Raises:

ValueError – If total_neurons <= 0 or counts are invalid
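The count rules can be stated as a small check. The sketch below is a hypothetical helper, not DRIADA code. It enforces the documented constraints (positive total, non-negative integer counts, sum not exceeding the total) and, as an illustrative extra, converts counts to percentages of `total_neurons`. The counts are made up.

```python
import numbers


def selectivity_percentages(selectivity_counts, total_neurons):
    """Hypothetical check of the documented rules; returns % of total per category."""
    if total_neurons <= 0:
        raise ValueError("total_neurons must be positive")
    for name, count in selectivity_counts.items():
        # numbers.Integral also accepts NumPy integer types
        if not isinstance(count, numbers.Integral) or count < 0:
            raise ValueError(f"{name!r}: count must be a non-negative integer")
    if sum(selectivity_counts.values()) > total_neurons:
        raise ValueError("category counts exceed total_neurons")
    return {name: 100.0 * count / total_neurons for name, count in selectivity_counts.items()}


# Made-up counts for illustration
counts = {"place": 12, "head direction": 8, "mixed": 5}
pct = selectivity_percentages(counts, total_neurons=100)
```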

driada.utils.visual.plot_component_selectivity_heatmap(selectivity_matrix, methods, n_components_per_method=None, figsize=None, save_path=None, dpi=150)[source]

Create heatmap showing neuron selectivity to embedding components.

Parameters:
  • selectivity_matrix (ndarray) – Matrix of shape (n_neurons, total_components) with MI values. Must be 2D with non-negative values.

  • methods (list of str) – List of DR method names. Cannot be empty.

  • n_components_per_method (dict, optional) – Number of components for each method. If None, assumes equal distribution across methods.

  • figsize (tuple, optional) – Figure size (width, height)

  • save_path (str, optional) – Path to save figure

  • dpi (int, default DEFAULT_DPI (150)) – DPI resolution for the saved figure

Returns:

fig – The generated figure

Return type:

matplotlib.figure.Figure

Raises:

ValueError – If selectivity_matrix is not 2D, contains negative values, methods list is empty, or component counts don’t match matrix
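The columns of `selectivity_matrix` are grouped by method, so each method owns a contiguous block of components. The sketch below shows one way to compute those column blocks. `component_slices` is a hypothetical helper, not DRIADA API. Two points are assumptions of this sketch: methods occupy columns in the order given, and an equal split that does not divide evenly is treated as an error.

```python
import numpy as np


def component_slices(n_total, methods, n_components_per_method=None):
    """Map each method name to its column slice in selectivity_matrix."""
    if not methods:
        raise ValueError("methods cannot be empty")
    if n_components_per_method is None:
        # Equal distribution across methods (assumed to require an even split)
        if n_total % len(methods):
            raise ValueError("columns do not split evenly across methods")
        n_components_per_method = {m: n_total // len(methods) for m in methods}
    if sum(n_components_per_method[m] for m in methods) != n_total:
        raise ValueError("component counts don't match matrix width")
    slices, start = {}, 0
    for m in methods:  # methods assumed to appear left to right in this order
        slices[m] = slice(start, start + n_components_per_method[m])
        start += n_components_per_method[m]
    return slices


rng = np.random.default_rng(3)
selectivity_matrix = rng.random((50, 6))  # 50 neurons, 6 components in total
slices = component_slices(selectivity_matrix.shape[1], ["PCA", "UMAP"])
pca_block = selectivity_matrix[:, slices["PCA"]]
uneven = component_slices(6, ["PCA", "UMAP"], {"PCA": 2, "UMAP": 4})
```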

driada.utils.visual.compute_circular_coordinates(embedding)[source]

Convert 2D embedding to circular coordinates (angles).

Parameters:

embedding (ndarray) – 2D embedding array of shape (n_samples, 2)

Returns:

angles – Angles in radians in the range [0, 2π], of shape (n_samples,)

Return type:

ndarray
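A common way to turn a 2D ring embedding into angles is `arctan2` followed by wrapping into [0, 2π). The NumPy sketch below shows that idea. It is not compute_circular_coordinates itself: `circular_coordinates` is a hypothetical helper, and centering the points on their mean before taking angles is an assumption of this sketch. That choice works when samples cover the ring fairly evenly, which is why the example uses evenly spaced angles.

```python
import numpy as np


def circular_coordinates(embedding, center=True):
    """Angle of each 2D point, in radians, mapped into [0, 2*pi)."""
    emb = np.asarray(embedding, dtype=float)
    if emb.ndim != 2 or emb.shape[1] != 2:
        raise ValueError(f"expected shape (n_samples, 2), got {emb.shape}")
    if center:
        # Assumption of this sketch: measure angles around the point cloud's mean
        emb = emb - emb.mean(axis=0)
    return np.mod(np.arctan2(emb[:, 1], emb[:, 0]), 2 * np.pi)


# Evenly spaced points on a unit ring whose center is offset from the origin
true = np.linspace(0, 2 * np.pi, 8, endpoint=False) + 0.2
ring = np.column_stack([np.cos(true), np.sin(true)]) + np.array([3.0, -2.0])
angles = circular_coordinates(ring)
```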

driada.utils.visual.visualize_circular_manifold(embeddings, true_angles, method_names, save_path=None, dpi=150)[source]

Visualize circular manifold extraction from different DR methods.

Creates a figure with two rows:

  • Top row: 2D embeddings colored by true head direction

  • Bottom row: true vs. reconstructed angle scatter plots

Parameters:
  • embeddings (list of ndarray) – List of 2D embedding arrays, each of shape (n_samples, 2)

  • true_angles (ndarray) – Ground truth angles in radians, shape (n_samples,)

  • method_names (list of str) – Names of DR methods for plot titles

  • save_path (str, optional) – Path to save figure

  • dpi (int, default DEFAULT_DPI (150)) – DPI resolution for the saved figure

Returns:

fig – The generated figure

Return type:

matplotlib.figure.Figure
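When reading the bottom-row scatter plots, keep in mind that an embedding of a ring is only determined up to a rotation and a reflection. Reconstructed angles can therefore lie on a shifted or mirrored diagonal even when the ring is recovered perfectly. The NumPy sketch below fits the best constant offset for both orientations and then measures the mean absolute circular error. `wrap` and `aligned_angle_error` are hypothetical helpers, not part of DRIADA, and visualize_circular_manifold is not documented to compute this quantity.

```python
import numpy as np


def wrap(a):
    """Wrap angles into [-pi, pi)."""
    return (a + np.pi) % (2 * np.pi) - np.pi


def aligned_angle_error(true_angles, recon_angles):
    """Mean absolute circular error after the best rotation and orientation."""
    best = np.inf
    for sign in (1.0, -1.0):  # try both orientations of the ring
        diff = wrap(sign * recon_angles - true_angles)
        offset = np.angle(np.mean(np.exp(1j * diff)))  # circular mean of the difference
        best = min(best, float(np.mean(np.abs(wrap(diff - offset)))))
    return best


rng = np.random.default_rng(4)
true_angles = rng.uniform(0, 2 * np.pi, 1000)
recon = np.mod(1.3 - true_angles, 2 * np.pi)  # a reflected and rotated copy
err = aligned_angle_error(true_angles, recon)
```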

Advanced visualization utilities and plotting functions.

Plotting Utilities

Plotting utilities for DRIADA.

Provides functions for creating publication-quality figures with consistent styling.

driada.utils.plot.make_beautiful(ax, spine_width=4, tick_width=4, tick_length=8, tick_pad=15, tick_labelsize=26, label_size=30, title_size=30, legend_fontsize=18, dpi=None, lowercase_labels=True, legend_frameon=False, legend_loc='auto', legend_offset=0.15, legend_ncol=None, tight_layout=True, remove_origin_tick=False, panel_size=None, panel_units='cm', style=None)[source]

Apply publication-quality styling to a matplotlib axis with optional auto-scaling.

Parameters:
  • ax (matplotlib.axes.Axes) – The axis to style.

  • spine_width (float, optional) – Width of visible spines (default: 4).

  • tick_width (float, optional) – Width of tick marks (default: 4).

  • tick_length (float, optional) – Length of tick marks (default: 8).

  • tick_pad (float, optional) – Padding between ticks and labels (default: 15).

  • tick_labelsize (int, optional) – Font size for tick labels (default: 26).

  • label_size (int, optional) – Font size for axis labels (default: 30).

  • title_size (int, optional) – Font size for title (default: 30).

  • legend_fontsize (int, optional) – Font size for legend (default: 18).

  • dpi (int, optional) – DPI for the figure. If provided, sets the figure’s DPI.

  • lowercase_labels (bool, optional) – Whether to convert all labels and legend text to lowercase (default: True). This includes axis labels, title, tick labels, and legend entries.

  • legend_frameon (bool, optional) – Whether to draw frame around legend (default: False).

  • legend_loc (str, optional) – Legend location (default: ‘auto’). One of:
      - ‘auto’: use matplotlib’s automatic placement
      - ‘above’: place the legend above the plot, spanning the full x-axis width
      - ‘below’: place the legend below the plot, spanning the full x-axis width
      - any valid matplotlib location string (e.g., ‘upper right’, ‘center left’)

  • legend_offset (float, optional) – Vertical offset for ‘above’ and ‘below’ legend positions (default: 0.15). Positive values move the legend further from the plot.

  • legend_ncol (int, optional) – Number of columns for legend entries (default: None, auto-determined). For ‘above’ and ‘below’, defaults to number of legend entries (single row).

  • tight_layout (bool, optional) – Whether to remove extra margins on both axes (default: True).

  • remove_origin_tick (bool, optional) – Whether to remove tick labels at the origin (0,0) to avoid overlap (default: False).

  • panel_size (tuple of float, optional) – Physical size (width, height) of the panel. When combined with a StylePreset, enables automatic scaling. If None, no scaling is applied (default: None).

  • panel_units ({'cm', 'inches'}, default 'cm') – Units for panel_size (default: ‘cm’).

  • style (StylePreset, optional) – StylePreset instance to use for styling. If provided with panel_size, automatically applies scaled styling. This is the recommended approach for multi-panel figures. If None, uses the individual parameters (spine_width, tick_width, etc.) directly (default: None).

Returns:

The styled axis.

Return type:

matplotlib.axes.Axes

Notes

This function applies a consistent publication-quality style to matplotlib axes by:

  • Hiding the top and right spines

  • Setting spine and tick widths

  • Configuring font sizes for all text elements

  • Setting the figure DPI if requested

The function modifies the axis in-place and returns it for convenience.
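Most of the steps listed in the Notes map onto ordinary matplotlib calls. The sketch below is an approximation written for illustration, not the body of make_beautiful. `style_axis` is a hypothetical helper that covers only spines, ticks, and font sizes, using the same default values as the signature above. It does not handle legends, lowercasing, DPI, or StylePreset scaling. It requires matplotlib and selects the non-interactive Agg backend so it runs headless.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend, no window needed
import matplotlib.pyplot as plt


def style_axis(ax, spine_width=4, tick_width=4, tick_length=8, tick_pad=15,
               tick_labelsize=26, label_size=30, title_size=30):
    """Approximate the spine/tick/font steps from the Notes with plain matplotlib."""
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)        # hide top and right spines
    for side in ("left", "bottom"):
        ax.spines[side].set_linewidth(spine_width)
    ax.tick_params(width=tick_width, length=tick_length, pad=tick_pad,
                   labelsize=tick_labelsize)
    ax.xaxis.label.set_size(label_size)           # axis label font sizes
    ax.yaxis.label.set_size(label_size)
    ax.title.set_size(title_size)
    return ax                                     # styled in place, returned for chaining


fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 4, 9])
ax.set_xlabel("time")
style_axis(ax)
plt.close(fig)
```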

Examples

>>> import matplotlib.pyplot as plt
>>> from driada.utils.plot import make_beautiful
>>> fig, ax = plt.subplots()
>>> _ = ax.plot([1, 2, 3], [1, 4, 9])
>>> _ = make_beautiful(ax)  # Apply styling
>>> plt.show()
>>> # With custom styling
>>> fig, ax = plt.subplots()
>>> _ = ax.plot([1, 2, 3], [1, 4, 9])
>>> _ = make_beautiful(ax, spine_width=2, tick_labelsize=14)

See also

create_default_figure

Create figure with default styling applied.

PanelLayout

Layout manager for multi-panel figures with precise dimensions.

StylePreset

Style presets with automatic scaling.

driada.utils.plot.create_default_figure(figsize=(16, 12), nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, **style_kwargs)[source]

Create a figure with default publication-quality styling.

Parameters:
  • figsize (tuple of float, optional) – Figure size as (width, height) in inches (default: (16, 12)).

  • nrows (int, optional) – Number of rows in subplot grid (default: 1).

  • ncols (int, optional) – Number of columns in subplot grid (default: 1).

  • sharex (bool, optional) – Whether to share x-axis among subplots (default: False).

  • sharey (bool, optional) – Whether to share y-axis among subplots (default: False).

  • squeeze (bool, optional) – If True, extra dimensions are squeezed out from returned axes (default: True).

  • **style_kwargs (dict) – Additional keyword arguments passed to make_beautiful().

Return type:

Tuple[Figure, Any]

Returns:

  • fig (matplotlib.figure.Figure) – The created figure.

  • axes (matplotlib.axes.Axes or array of Axes) – The styled axis/axes. If nrows=ncols=1 and squeeze=True, returns single Axes. Otherwise returns array of Axes.

Notes

This is a convenience function that combines matplotlib’s subplots() with automatic application of publication-quality styling via make_beautiful(). All axes in the figure receive the same styling.

Examples

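>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from driada.utils.plot import create_default_figure
>>> # Single subplot with default styling
>>> fig, ax = create_default_figure()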
>>> _ = ax.plot([1, 2, 3], [1, 4, 9])
>>> plt.show()
>>> # Multiple subplots with custom styling
>>> fig, axes = create_default_figure(nrows=2, ncols=2, figsize=(20, 16),
...                                   spine_width=2, tick_labelsize=14)
>>> for ax in axes.flat:
...     _ = ax.plot(np.random.randn(100))

See also

make_beautiful

Apply styling to existing axes.

matplotlib.pyplot.subplots

Base function for creating subplots.

driada.utils.plot.plot_mat(mat, figsize=(12, 12), ax=None, with_cbar=True, cmap='viridis', aspect='auto', **imshow_kwargs)[source]

Plot a matrix as an image with optional colorbar.

Parameters:
  • mat (np.ndarray) – 2D array to plot.

  • figsize (tuple of float, optional) – Figure size if creating new figure (default: (12, 12)).

  • ax (matplotlib.axes.Axes, optional) – Existing axis to plot on. If None, creates new figure.

  • with_cbar (bool, optional) – Whether to add a colorbar (default: True).

  • cmap (str, optional) – Colormap name (default: ‘viridis’).

  • aspect (str, optional) – Aspect ratio setting (default: ‘auto’).

  • **imshow_kwargs (dict) – Additional keyword arguments passed to ax.imshow().

Return type:

Tuple[Optional[Figure], Axes]

Returns:

  • fig (matplotlib.figure.Figure or None) – The figure (None if ax was provided).

  • ax (matplotlib.axes.Axes) – The axis with the plot.

Raises:

ValueError – If mat is not a 2D array.

Notes

This function is a convenience wrapper around matplotlib’s imshow for visualizing 2D matrices. It handles figure creation and colorbar addition automatically.

The function returns both figure and axis to allow further customization. If an existing axis is provided, the figure return value will be None.

Examples

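>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from driada.utils.plot import plot_mat
>>> # Plot a random matrix
>>> mat = np.random.randn(10, 10)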
>>> fig, ax = plot_mat(mat)
>>> _ = ax.set_title('Random Matrix')
>>> plt.show()
>>> # Plot on existing axis without colorbar
>>> fig, ax = plt.subplots()
>>> _, ax = plot_mat(mat, ax=ax, with_cbar=False, cmap='coolwarm')

See also

matplotlib.pyplot.imshow

Base function for displaying images.

matplotlib.pyplot.colorbar

Function for adding colorbars.

Publication Utilities

Publication-ready figure framework for DRIADA.

This module provides a comprehensive framework for creating publication-quality multi-panel figures with precise physical dimensions and consistent styling.

Philosophy: SAME PHYSICAL SIZE ACROSS ALL PANELS

By default, all text and lines maintain the SAME physical size (in cm/inches) across all panels regardless of their dimensions. When printed and measured with a ruler, a 10pt font will be exactly 10pt in every panel, whether the panel is 4×4 cm or 12×12 cm.

This ensures professional, consistent appearance in publications.
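The fixed-size rule comes down to unit arithmetic. A point is 1/72 inch (the convention matplotlib uses), so a font size in points has the same physical length on paper whatever the panel size. A larger panel simply spans more pixels at a given DPI. The stdlib sketch below shows the conversions. The helper names are illustrative, not DRIADA API.

```python
CM_PER_INCH = 2.54
POINTS_PER_INCH = 72.0  # typographic points, as used by matplotlib


def points_to_cm(points):
    """Physical length on paper of a size given in points."""
    return points / POINTS_PER_INCH * CM_PER_INCH


def points_to_pixels(points, dpi):
    """Rendered size in pixels; depends on DPI only, not on panel size."""
    return points / POINTS_PER_INCH * dpi


def panel_pixels(size_cm, dpi):
    """Pixel dimensions of a panel with the given physical size in cm."""
    width, height = size_cm
    return round(width / CM_PER_INCH * dpi), round(height / CM_PER_INCH * dpi)


label_cm = points_to_cm(10)           # about 0.353 cm on paper
label_px = points_to_pixels(10, 300)  # about 41.7 px at 300 dpi, in any panel
small = panel_pixels((4, 4), 300)     # (472, 472)
large = panel_pixels((8, 8), 300)     # (945, 945)
```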

Key Features

  • Precise physical dimensions (cm or inches) for each subplot

  • Configurable DPI for different output targets (300 for publication, 150 for draft)

  • Fixed physical sizing: Same font/line sizes across all panels (DEFAULT)

  • Optional area-based scaling for advanced use cases

  • Flexible layout system (grid or custom positioning)

  • Support for external plots (from R, MATLAB, etc.)

  • Panel labeling utilities (A, B, C, …)

  • Seamless integration with driada.utils.plot.make_beautiful()

Quick Start

>>> import numpy as np
>>> from driada.utils.publication import PanelLayout, StylePreset
>>>
>>> # Create layout with different panel sizes
>>> layout = PanelLayout(units='cm', dpi=300, spacing={'wspace': 1.5})
>>> layout.add_panel('A', size=(4, 4))  # Small panel
>>> layout.add_panel('B', size=(8, 8))  # Large panel
>>> layout.set_grid(rows=1, cols=2)
>>>
>>> # Create figure - all panels get SAME physical font/line sizes
>>> style = StylePreset.publication_default()
>>> fig, axes = layout.create_figure(style=style)
>>>
>>> # Plot data - fonts will be identical physical size in both panels!
>>> x = np.linspace(0, 10, 200)
>>> y = np.sin(x)
>>> _ = axes['A'].plot(x, y)
>>> _ = axes['A'].set_xlabel('Time (s)')  # 10pt font
>>> _ = axes['B'].plot(x, y)
>>> _ = axes['B'].set_xlabel('Time (s)')  # Also 10pt font (same physical size)
>>>
>>> # Save at specified DPI
>>> fig.savefig('figure.pdf', dpi=layout.dpi, bbox_inches='tight')

See also

make_beautiful

Enhanced with auto-scaling support for single-panel figures.

class driada.utils.publication.PanelLayout(units='cm', dpi=300, spacing=None)[source]

Bases: object

Manages layout and creation of publication-ready multi-panel figures.

This class handles precise physical dimensions for each subplot and generates matplotlib figures with correct sizing and spacing.

Parameters:
  • units ({'cm', 'inches'}, default 'cm') – Physical units for panel dimensions and spacing

  • dpi (int, default 300) – Dots per inch for the figure (300 for publication, 150 for draft, 72 for screen)

  • spacing (dict, optional) – Spacing between panels, in physical units:
      - ‘wspace’: horizontal spacing
      - ‘hspace’: vertical spacing
    Default: {‘wspace’: 0, ‘hspace’: 0}

Examples

>>> # Simple 2-panel layout
>>> layout = PanelLayout(units='cm', dpi=300)
>>> layout.add_panel('A', size=(8, 6))
>>> layout.add_panel('B', size=(8, 6))
>>> layout.set_grid(rows=1, cols=2)
>>> fig, axes = layout.create_figure()
>>> # Complex layout with custom positioning
>>> layout = PanelLayout(units='cm', dpi=300, spacing={'wspace': 1.5, 'hspace': 1.0})
>>> layout.add_panel('A', size=(5, 5), position=(0, 0))
>>> layout.add_panel('B', size=(10, 5), position=(0, 1))
>>> layout.add_panel('C', size=(15, 6), position=(1, 0), col_span=2)
>>> fig, axes = layout.create_figure()

__init__(units='cm', dpi=300, spacing=None)[source]

Initialize PanelLayout.

Parameters:
  • units ({'cm', 'inches'}, default 'cm') – Physical units for panel dimensions and spacing.

  • dpi (int, default 300) – Dots per inch for the figure (300 for publication, 150 for draft, 72 for screen).

  • spacing (dict, optional) – Spacing between panels with keys ‘wspace’ (horizontal) and ‘hspace’ (vertical) in physical units. If None, defaults to {‘wspace’: 0, ‘hspace’: 0}.

add_panel(name, size, position=None, row_span=1, col_span=1, **kwargs)[source]

Add a panel to the layout.

Parameters:
  • name (str) – Identifier for the panel (e.g., ‘A’, ‘B’, ‘C’)

  • size (tuple of float) – (width, height) in units specified by self.units

  • position (tuple of int, optional) – (row, col) grid position. If None, position determined by grid or sequential order

  • row_span (int, default 1) – Number of rows this panel spans

  • col_span (int, default 1) – Number of columns this panel spans

  • **kwargs – Accepted for backward compatibility: ‘rowspan’ and ‘colspan’ as aliases of row_span and col_span

Raises:

ValueError – If panel name already exists, size is invalid, or span is invalid

set_grid(rows, cols)[source]

Set the grid shape for automatic panel positioning.

Parameters:
  • rows (int) – Number of rows in the grid

  • cols (int) – Number of columns in the grid

Return type:

None

get_panel_size(name)[source]

Get the size of a panel in the layout’s units.

Parameters:

name (str) – Panel identifier

Returns:

(width, height) in self.units

Return type:

tuple of float

Raises:

ValueError – If panel name not found

create_figure(style=None)[source]

Create a matplotlib figure with all panels.

Parameters:

style (StylePreset, optional) – Style preset to apply to all panels. If None, no styling is applied.

Return type:

Tuple[Figure, Dict[str, Axes]]

Returns:

  • fig (matplotlib.figure.Figure) – The created figure

  • axes (dict of matplotlib.axes.Axes) – Dictionary mapping panel names to their axes

Raises:

ValueError – If layout validation fails (bounds violations, overlaps, etc.)
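Bounds and overlap validation on a grid can be done by marking every cell a panel occupies, taking its spans into account. The stdlib sketch below illustrates that idea. `check_layout` is a hypothetical helper that works on plain `(row, col, row_span, col_span)` tuples. It is not PanelLayout's validator, which may check more (for example panel sizes).

```python
def check_layout(panels, rows, cols):
    """panels: {name: (row, col, row_span, col_span)}. Raises on bounds or overlap."""
    occupied = {}
    for name, (row, col, row_span, col_span) in panels.items():
        # Bounds: the span must stay inside the rows x cols grid
        if row < 0 or col < 0 or row + row_span > rows or col + col_span > cols:
            raise ValueError(f"panel {name!r} falls outside the {rows}x{cols} grid")
        # Overlap: no grid cell may be claimed by two panels
        for r in range(row, row + row_span):
            for c in range(col, col + col_span):
                if (r, c) in occupied:
                    raise ValueError(
                        f"panels {occupied[(r, c)]!r} and {name!r} overlap at cell {(r, c)}"
                    )
                occupied[(r, c)] = name
    return occupied


# Mirrors the class example: C spans both columns of the second row
ok_panels = {"A": (0, 0, 1, 1), "B": (0, 1, 1, 1), "C": (1, 0, 1, 2)}
occupied = check_layout(ok_panels, rows=2, cols=2)
```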

Examples

>>> layout = PanelLayout(units='cm', dpi=300)
>>> layout.add_panel('A', size=(8, 6), position=(0, 0))
>>> layout.add_panel('B', size=(8, 6), position=(0, 1))
>>> layout.set_grid(rows=1, cols=2)
>>> fig, axes = layout.create_figure()
>>> _ = axes['A'].plot([1, 2, 3], [1, 4, 9])

create_figure_with_subfigures()[source]

Create a matplotlib figure with SubFigure for each panel.

Use this method when panels need internal subplot structure (e.g., stacked traces, side-by-side plots). Each SubFigure maintains the exact physical dimensions specified and can contain its own subplot layout.

Returns:

  • fig (matplotlib.figure.Figure) – The created figure

  • subfigures (dict of matplotlib.figure.SubFigure) – Dictionary mapping panel names to their SubFigure objects

Return type:

Tuple[Figure, Dict[str, plt.SubFigure]]

Examples

>>> layout = PanelLayout(units='cm', dpi=300)
>>> layout.add_panel('A', size=(11.11, 6.67), position=(0, 0))
>>> layout.add_panel('B', size=(11.11, 7.22), position=(1, 0))
>>> layout.set_grid(rows=2, cols=1)
>>>
>>> fig, subfigs = layout.create_figure_with_subfigures()
>>>
>>> # Create 3 stacked subplots within panel A
>>> axs_a = subfigs['A'].subplots(3, 1, sharex=True)
>>> for ax in axs_a:
...     ax.plot(data)
class driada.utils.publication.PanelSpec(name, size, position=None, row_span=1, col_span=1)[source]

Bases: object

Specification for a single panel in a multi-panel figure.

Parameters:
  • name (str) – Identifier for the panel (e.g., ‘A’, ‘B’, ‘C’)

  • size (tuple of float) – (width, height) in user’s preferred units

  • position (tuple of int, optional) – (row, col) grid position. If None, panels are arranged sequentially

  • row_span (int, optional) – Number of rows this panel spans (default: 1)

  • col_span (int, optional) – Number of columns this panel spans (default: 1)

name: str
size: Tuple[float, float]
position: Optional[Tuple[int, int]] = None
row_span: int = 1
col_span: int = 1

__init__(name, size, position=None, row_span=1, col_span=1)

Return type:

None

class driada.utils.publication.StylePreset(name='default', reference_size=(8.0, 8.0), reference_units='cm', base_spine_width=1.5, base_tick_width=1.5, base_tick_length=6, base_tick_pad=8, base_tick_labelsize=8, base_label_size=10, base_title_size=10, base_legend_fontsize=8, base_line_width=1.5, base_marker_size=3.0, base_annotation_fontsize=8.0, legend_frameon=False, lowercase_labels=False, tight_layout=True, scaling_mode='fixed')[source]

Bases: object

Style preset with consistent physical sizing across all panels.

This class defines styling parameters that maintain the SAME physical size (in cm/inches) across all panels regardless of their dimensions. This ensures that when printed, all text and lines have identical physical measurements.

For advanced use cases, area-based scaling can be enabled to maintain visual density across panels of different sizes instead.

Parameters:
  • name (str) – Name of the preset (e.g., ‘nature’, ‘custom’)

  • reference_size (tuple of float) – (width, height) of the reference panel in reference_units. Only used when scaling_mode=‘area’.

  • reference_units ({'cm', 'inches'}, default 'cm') – Units for reference_size

  • base_spine_width (float, default 1.5) – Spine width in points (same physical size on all panels)

  • base_tick_width (float, default 1.5) – Tick width in points (same physical size on all panels)

  • base_tick_length (float, default 6) – Tick length in points (same physical size on all panels)

  • base_tick_pad (float, default 8) – Tick padding in points (same physical size on all panels)

  • base_tick_labelsize (float, default 8) – Tick label font size in points (same physical size on all panels)

  • base_label_size (float, default 10) – Axis label font size in points (same physical size on all panels)

  • base_title_size (float, default 10) – Title font size in points (same physical size on all panels)

  • base_legend_fontsize (float, default 8) – Legend font size in points (same physical size on all panels)

  • base_line_width (float, default 1.5) – Line width for data visualization in points

  • base_marker_size (float, default 3.0) – Marker size for scatter plots in points

  • base_annotation_fontsize (float, default 8.0) – Font size for text annotations (e.g., “neuron 36”) in points

  • scaling_mode ({'fixed', 'area'}, default 'fixed') – Scaling behavior:
      - ‘fixed’: same physical size on all panels (DEFAULT, recommended)
      - ‘area’: scale by sqrt(area_ratio) to preserve visual density

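  • legend_frameon (bool, default False) – Whether to draw a frame around legends

  • lowercase_labels (bool, default False) – Whether to convert labels and legend text to lowercase

  • tight_layout (bool, default True) – Whether to remove extra axis margins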

Examples

>>> import matplotlib.pyplot as plt
>>> from driada.utils.publication import StylePreset
>>> # Create a publication preset with fixed physical sizing
>>> style = StylePreset.publication_default()
>>> fig, (ax, ax2) = plt.subplots(1, 2)
>>> # Apply to any panel - fonts will have same physical size
>>> style.apply_to_axes(ax, (8, 8), 'cm')
>>> style.apply_to_axes(ax2, (4, 4), 'cm')  # Same font size as above!

name: str = 'default'
reference_size: Tuple[float, float] = (8.0, 8.0)
reference_units: Literal['cm', 'inches'] = 'cm'
base_spine_width: float = 1.5
base_tick_width: float = 1.5
base_tick_length: float = 6
base_tick_pad: float = 8
base_tick_labelsize: float = 8
base_label_size: float = 10
base_title_size: float = 10
base_legend_fontsize: float = 8
base_line_width: float = 1.5
base_marker_size: float = 3.0
base_annotation_fontsize: float = 8.0
legend_frameon: bool = False
lowercase_labels: bool = False
tight_layout: bool = True
scaling_mode: Literal['fixed', 'area'] = 'fixed'
calculate_scale_factor(panel_size, panel_units)[source]

Calculate scale factor based on scaling mode.

Parameters:
  • panel_size (tuple of float) – (width, height) of the panel

  • panel_units ({'cm', 'inches'}) – Units for panel_size

Returns:

Scale factor to apply to all visual elements:

  • scaling_mode=‘fixed’: Always returns 1.0 (same physical size on all panels)

  • scaling_mode=‘area’: Returns sqrt(panel_area / reference_area) to preserve visual density across different panel sizes

Return type:

float
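As a worked example of the formula above, take the default reference_size of (8.0, 8.0) cm. In ‘area’ mode a 4 × 4 cm panel gives sqrt(16 / 64) = 0.5, and a 16 × 16 cm panel gives 2.0. The stdlib sketch below re-implements only the documented formula. `scale_factor` is a hypothetical stand-in for calculate_scale_factor. The last line assumes that getters such as get_line_width multiply the base value by this factor. The docstrings say only "scaled according to scaling_mode", so treat that as an assumption.

```python
import math

CM_PER_INCH = 2.54


def scale_factor(panel_size, panel_units, reference_size=(8.0, 8.0),
                 reference_units="cm", scaling_mode="fixed"):
    """Documented rule: 1.0 in 'fixed' mode, sqrt(area ratio) in 'area' mode."""
    if scaling_mode == "fixed":
        return 1.0
    to_cm = {"cm": 1.0, "inches": CM_PER_INCH}
    pw, ph = (v * to_cm[panel_units] for v in panel_size)
    rw, rh = (v * to_cm[reference_units] for v in reference_size)
    return math.sqrt((pw * ph) / (rw * rh))


fixed = scale_factor((4, 4), "cm")                               # 1.0
area_small = scale_factor((4, 4), "cm", scaling_mode="area")     # 0.5
area_large = scale_factor((16, 16), "cm", scaling_mode="area")   # 2.0
# Assumption: scaled value = base value * factor (base_line_width defaults to 1.5)
line_width_small = 1.5 * area_small                              # 0.75 pt
```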

get_line_width(panel_size=None, panel_units='cm')[source]

Get scaled line width for data visualization.

Parameters:
  • panel_size (tuple of float, optional) – (width, height) of the panel. If None, returns base value.

  • panel_units ({'cm', 'inches'}, default 'cm') – Units for panel_size

Returns:

Line width in points, scaled according to scaling_mode

Return type:

float

get_marker_size(panel_size=None, panel_units='cm')[source]

Get scaled marker size for scatter plots.

Parameters:
  • panel_size (tuple of float, optional) – (width, height) of the panel. If None, returns base value.

  • panel_units ({'cm', 'inches'}, default 'cm') – Units for panel_size

Returns:

Marker size in points, scaled according to scaling_mode

Return type:

float

get_annotation_fontsize(panel_size=None, panel_units='cm')[source]

Get scaled font size for text annotations.

Parameters:
  • panel_size (tuple of float, optional) – (width, height) of the panel. If None, returns base value.

  • panel_units ({'cm', 'inches'}, default 'cm') – Units for panel_size

Returns:

Font size in points, scaled according to scaling_mode

Return type:

float

apply_to_axes(ax, panel_size, panel_units='cm')[source]

Apply scaled styling to a matplotlib axes.

Parameters:
  • ax (matplotlib.axes.Axes) – The axes to style

  • panel_size (tuple of float) – (width, height) of the panel in panel_units

  • panel_units ({'cm', 'inches'}, default 'cm') – Units for panel_size

Return type:

None

Examples

>>> import matplotlib.pyplot as plt
>>> style = StylePreset.publication_default()
>>> fig, ax = plt.subplots(figsize=(3, 3))
>>> style.apply_to_axes(ax, (8, 8), 'cm')

apply_to_subfigure(subfig, panel_size, panel_units='cm', legend_config=None, margins=None)[source]

Apply consistent styling to ALL axes in a subfigure.

This method is designed for multi-panel figures where each panel is a SubFigure that may contain multiple internal axes (e.g., stacked traces). It ensures all axes within the subfigure have consistent styling.

Parameters:
  • subfig (matplotlib.figure.SubFigure) – The subfigure containing axes to style

  • panel_size (tuple of float) – (width, height) of the panel in panel_units

  • panel_units ({'cm', 'inches'}, default 'cm') – Units for panel_size

  • legend_config (dict, optional) – Configuration for legend positioning, with keys:
      - ‘ax_index’: int, which axis to attach the legend to (default: -1, the last axis)
      - ‘loc’: str, ‘below’ or ‘right’ (default: ‘below’)
      - ‘ncol’: int, number of columns (default: 2)
      - ‘handles’: list, explicit legend handles (optional)
      - ‘labels’: list, explicit legend labels (optional)

  • margins (dict, optional) – Subfigure margins as fractions, with keys ‘left’, ‘right’, ‘top’, ‘bottom’, ‘hspace’, ‘wspace’. Default: {‘left’: 0.12, ‘right’: 0.95, ‘top’: 0.95, ‘bottom’: 0.18}

Return type:

None

Examples

>>> layout = PanelLayout(units='cm', dpi=300)
>>> layout.add_panel('A', size=(11.85, 7.0), position=(0, 0))
>>> fig, subfigs = layout.create_figure_with_subfigures()
>>> axs = subfigs['A'].subplots(3, 1, sharex=True)
>>> # ... plot data ...
>>> style = StylePreset.fixed_size()
>>> style.apply_to_subfigure(
...     subfigs['A'],
...     layout.get_panel_size('A'),
...     legend_config={'loc': 'below', 'ncol': 2}
... )

classmethod publication_default(scaling_mode='fixed')[source]

Create a preset with sensible defaults for publication figures.

Provides professional-looking figures with readable fonts and clean lines. By default, uses fixed physical sizing so all panels have identical font and line sizes when measured with a ruler on the printed page.

Parameters:

scaling_mode ({'fixed', 'area'}, default 'fixed') –

  • ‘fixed’: Same physical size on all panels (recommended for most cases)

  • ‘area’: Scale by panel area to preserve visual density

Returns:

Configured preset for publication figures

Return type:

StylePreset

Examples

>>> # Default: fixed physical size across all panels
>>> style = StylePreset.publication_default()
>>>
>>> # Optional: area-based scaling for visual density preservation
>>> style_area = StylePreset.publication_default(scaling_mode='area')

classmethod fixed_size(**kwargs)[source]

Create a preset with FIXED sizes across all panels.

This is an explicit alias for the default behavior. All panels get the same font sizes and line widths regardless of their physical dimensions, ensuring consistent absolute physical size when printed.

Note: This is now the DEFAULT behavior, so StylePreset() and StylePreset.publication_default() already use fixed sizing. This method is provided for explicit clarity.

Parameters:

**kwargs (dict) – Override any style parameters (e.g., base_spine_width=2.0)

Returns:

Configured preset with fixed scaling

Return type:

StylePreset

Examples

>>> # All panels get 10 pt axis labels and 1.5 pt spines, regardless of panel size
>>> style = StylePreset.fixed_size()
>>> # Or customize:
>>> style = StylePreset.fixed_size(base_label_size=12, base_spine_width=2.0)

classmethod from_make_beautiful(spine_width=4, tick_width=4, tick_length=8, tick_pad=15, tick_labelsize=26, label_size=30, title_size=30, legend_fontsize=18, reference_size=(16.0, 12.0), reference_units='inches')

Create a preset that matches the existing make_beautiful() styling.

This lets code that currently calls make_beautiful() migrate to the StylePreset framework while keeping the same visual appearance.

Parameters:
  • spine_width (float, default 4) – Spine width from make_beautiful

  • tick_width (float, default 4) – Tick width from make_beautiful

  • tick_length (float, default 8) – Tick length from make_beautiful

  • tick_pad (float, default 15) – Tick padding from make_beautiful

  • tick_labelsize (int, default 26) – Tick label font size from make_beautiful

  • label_size (int, default 30) – Axis label font size from make_beautiful

  • title_size (int, default 30) – Title font size from make_beautiful

  • legend_fontsize (int, default 18) – Legend font size from make_beautiful

  • reference_size (tuple of float, default (16.0, 12.0)) – Reference panel size (matches make_beautiful default figsize)

  • reference_units ({'cm', 'inches'}, default 'inches') – Units for reference_size

Returns:

Configured preset matching make_beautiful

Return type:

StylePreset

copy(**kwargs)

Create a copy of this preset with optional parameter overrides.

Parameters:

**kwargs – Parameters to override in the copy

Returns:

New preset with modified parameters

Return type:

StylePreset

Examples

>>> style = StylePreset.publication_default()
>>> larger_fonts = style.copy(base_label_size=12, base_title_size=14)
>>> area_scaled = style.copy(scaling_mode='area')
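
The copy-with-overrides pattern can be sketched with the standard library, assuming StylePreset behaves like a dataclass (its keyword-only-defaults __init__ is consistent with that, but the page does not say so). MiniPreset is a hypothetical stand-in with only a few of StylePreset's fields:

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class MiniPreset:
    """Hypothetical stand-in holding a few of StylePreset's fields."""

    name: str = "default"
    base_label_size: float = 10
    base_title_size: float = 10
    scaling_mode: str = "fixed"

    def copy(self, **kwargs):
        # dataclasses.replace builds a new instance; the original is untouched.
        # In this sketch, unknown field names raise TypeError.
        return replace(self, **kwargs)


style = MiniPreset()
larger_fonts = style.copy(base_label_size=12, base_title_size=14)
print(style.base_label_size, larger_fonts.base_label_size)  # 10 12
```

Making the stand-in frozen means a "modified" preset can only be obtained through copy(), so a shared base preset cannot be changed by accident.
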

__init__(name='default', reference_size=(8.0, 8.0), reference_units='cm', base_spine_width=1.5, base_tick_width=1.5, base_tick_length=6, base_tick_pad=8, base_tick_labelsize=8, base_label_size=10, base_title_size=10, base_legend_fontsize=8, base_line_width=1.5, base_marker_size=3.0, base_annotation_fontsize=8.0, legend_frameon=False, lowercase_labels=False, tight_layout=True, scaling_mode='fixed')

Return type:

None

class driada.utils.publication.ExternalPanel

Bases: object

Utilities for adding external plot images to panels.

This class provides static methods for displaying images from external plotting tools (R, MATLAB, etc.) in matplotlib axes.

static add_image_panel(ax, image_path, aspect='equal', hide_axes=True)

Add an external image to a matplotlib axes.

Parameters:
  • ax (matplotlib.axes.Axes) – The axes to add the image to

  • image_path (str) – Path to the image file

  • aspect (str, default 'equal') – Aspect ratio for the image display

  • hide_axes (bool, default True) – Whether to hide axis ticks and labels

Return type:

None

Examples

>>> import matplotlib.pyplot as plt
>>> fig, ax = plt.subplots()
>>> ExternalPanel.add_image_panel(ax, 'external_plot.png')
>>> plt.show()

static create_placeholder(ax, text='External Plot', fontsize=14, color='gray')

Create a placeholder for external content.

Useful during figure development when external plots are not yet ready.

Parameters:
  • ax (matplotlib.axes.Axes) – The axes to add the placeholder to

  • text (str, default 'External Plot') – Placeholder text to display

  • fontsize (int, default 14) – Font size for placeholder text

  • color (str, default 'gray') – Color for placeholder elements

Return type:

None

Examples

>>> import matplotlib.pyplot as plt
>>> fig, ax = plt.subplots()
>>> ExternalPanel.create_placeholder(ax, 'R plot goes here')
>>> plt.show()

class driada.utils.publication.PanelLabeler(fontsize_pt=12, location='top_left', offset=None, fontweight='bold', fontfamily='sans-serif')

Bases: object

Utilities for adding panel labels (A, B, C, …) to multi-panel figures.

Parameters:
  • fontsize_pt (float, default 12) – Font size for labels in points

  • location ({'top_left', 'top_right', 'bottom_left', 'bottom_right'}, default 'top_left') – Location for panel labels

  • offset (tuple of float, optional) – Offset for label position in axes coordinates (x, y). If None, a location-specific default is used ((-0.1, 1.05) for ‘top_left’)

  • fontweight (str, default 'bold') – Font weight for labels

  • fontfamily (str, default 'sans-serif') – Font family for labels

Examples

>>> import matplotlib.pyplot as plt
>>> labeler = PanelLabeler(fontsize_pt=10, location='top_left')
>>> fig, axes = plt.subplots(2, 2)
>>> for ax, label in zip(axes.flat, ['A', 'B', 'C', 'D']):
...     labeler.add_label(ax, label, dpi=300)
>>> plt.show()

__init__(fontsize_pt=12, location='top_left', offset=None, fontweight='bold', fontfamily='sans-serif')

Initialize PanelLabeler.

Parameters:
  • fontsize_pt (float, default 12) – Font size for labels in points.

  • location ({'top_left', 'top_right', 'bottom_left', 'bottom_right'}, default 'top_left') – Location for panel labels.

  • offset (tuple of float, optional) – Offset for label position in axes coordinates (x, y). If None, uses location-specific defaults.

  • fontweight (str, default 'bold') – Font weight for labels.

  • fontfamily (str, default 'sans-serif') – Font family for labels.
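
The offset=None fallback can be sketched as a lookup keyed by location. Only the ‘top_left’ value, (-0.1, 1.05), comes from the class documentation; the other three entries are hypothetical mirror images chosen for illustration, and resolve_offset is not a DRIADA function:

```python
# Axes coordinates: (0, 0) is the lower-left corner of the axes, (1, 1) the upper-right.
# Only 'top_left' is documented; the other entries are hypothetical mirror images.
DEFAULT_OFFSETS = {
    "top_left": (-0.1, 1.05),
    "top_right": (1.1, 1.05),
    "bottom_left": (-0.1, -0.05),
    "bottom_right": (1.1, -0.05),
}


def resolve_offset(location, offset=None):
    """Return the explicit offset if one is given, else the location's default."""
    if location not in DEFAULT_OFFSETS:
        raise ValueError(f"unknown location: {location!r}")
    if offset is None:
        return DEFAULT_OFFSETS[location]
    return tuple(offset)


print(resolve_offset("top_left"))               # (-0.1, 1.05)
print(resolve_offset("top_left", [-0.2, 1.1]))  # (-0.2, 1.1)
```

An (x, y) pair in axes coordinates like this is what matplotlib expects when placing text with ax.text(x, y, label, transform=ax.transAxes); values outside [0, 1] put the label just outside the plotting area.
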

add_label(ax, label, dpi=300, **text_kwargs)

Add a panel label to an axes.

Parameters:
  • ax (matplotlib.axes.Axes) – The axes to label

  • label (str) – Label text (e.g., ‘A’, ‘B’, ‘C’)

  • dpi (int, default 300) – DPI of the figure (used for size calculations)

  • **text_kwargs – Additional keyword arguments passed to ax.text()

Return type:

None

Examples

>>> import matplotlib.pyplot as plt
>>> fig, ax = plt.subplots()
>>> labeler = PanelLabeler()
>>> labeler.add_label(ax, 'A', dpi=300)

add_labels_to_dict(axes_dict, dpi=300, **text_kwargs)

Add labels to all axes in a dictionary.

Convenience method for labeling all panels in a PanelLayout figure.

Parameters:
  • axes_dict (dict of matplotlib.axes.Axes) – Dictionary mapping panel names to axes

  • dpi (int, default 300) – DPI of the figure

  • **text_kwargs – Additional keyword arguments passed to ax.text()

Return type:

None

Examples

>>> layout = PanelLayout(units='cm', dpi=300)
>>> layout.add_panel('A', size=(8, 6))
>>> layout.add_panel('B', size=(8, 6))
>>> layout.set_grid(rows=1, cols=2)
>>> fig, axes = layout.create_figure()
>>> labeler = PanelLabeler()
>>> labeler.add_labels_to_dict(axes, dpi=layout.dpi)

driada.utils.publication.format_panel_label(index, style='upper')

Format panel label from index.

Parameters:
  • index (int) – Zero-based panel index

  • style ({'upper', 'lower', 'number'}, default 'upper') – Label style:
      - ‘upper’: Uppercase letters (A, B, C, …)
      - ‘lower’: Lowercase letters (a, b, c, …)
      - ‘number’: Numbers (1, 2, 3, …)

Returns:

Formatted panel label

Return type:

str

Examples

>>> format_panel_label(0, 'upper')
'A'
>>> format_panel_label(1, 'lower')
'b'
>>> format_panel_label(2, 'number')
'3'
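
The index-to-label mapping in the examples above can be reproduced in a few lines. This is a hypothetical re-creation (make_panel_label is not DRIADA's code); the page does not say what happens past ‘Z’, so this sketch simply rejects indices of 26 or more for the letter styles:

```python
import string


def make_panel_label(index, style="upper"):
    """Hypothetical re-creation of format_panel_label for letter indices 0-25."""
    if index < 0:
        raise ValueError("index must be zero or positive")
    if style == "number":
        return str(index + 1)  # numeric labels are one-based
    alphabets = {"upper": string.ascii_uppercase, "lower": string.ascii_lowercase}
    if style not in alphabets:
        raise ValueError(f"unknown style: {style!r}")
    if index >= 26:
        # Behavior past 'Z' is not documented; this sketch refuses.
        raise ValueError("letter styles support at most 26 panels in this sketch")
    return alphabets[style][index]


print([make_panel_label(i) for i in range(4)])  # ['A', 'B', 'C', 'D']
```
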

driada.utils.publication.to_inches(value, from_units)

Convert from user units to inches (matplotlib’s native unit).

Parameters:
  • value (float or tuple of float) – Value(s) to convert

  • from_units ({'cm', 'inches'}) – Source units

Returns:

Value(s) in inches

Return type:

float or tuple of float

Raises:

ValueError – If units are not ‘cm’ or ‘inches’

driada.utils.publication.from_inches(value, to_units)

Convert from inches to user units.

Parameters:
  • value (float or tuple of float) – Value(s) in inches

  • to_units ({'cm', 'inches'}) – Target units

Returns:

Value(s) in target units

Return type:

float or tuple of float

Raises:

ValueError – If units are not ‘cm’ or ‘inches’
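
to_inches and from_inches are inverses built on the exact definition 1 inch = 2.54 cm. The sketch below mirrors their documented behavior (scalars or tuples, ValueError for other units) under hypothetical names; it is an illustration, not DRIADA's source:

```python
CM_PER_INCH = 2.54  # exact, by definition of the inch


def _check_units(units):
    if units not in ("cm", "inches"):
        raise ValueError(f"units must be 'cm' or 'inches', got {units!r}")


def convert_to_inches(value, from_units):
    """Sketch of to_inches: a scalar or tuple in cm or inches becomes inches."""
    _check_units(from_units)

    def one(v):
        return v / CM_PER_INCH if from_units == "cm" else v

    return tuple(one(v) for v in value) if isinstance(value, tuple) else one(value)


def convert_from_inches(value, to_units):
    """Sketch of from_inches: a scalar or tuple in inches becomes the target units."""
    _check_units(to_units)

    def one(v):
        return v * CM_PER_INCH if to_units == "cm" else v

    return tuple(one(v) for v in value) if isinstance(value, tuple) else one(value)


panel_cm = (11.85, 7.0)  # panel size used in the PanelLayout example
panel_in = convert_to_inches(panel_cm, "cm")
print(tuple(round(x, 3) for x in panel_in))  # (4.665, 2.756)
```

Converting back with convert_from_inches(panel_in, "cm") recovers the original size up to floating-point rounding.
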

GIF Creation

driada.utils.gif.erase_all(path, signature='', ext='.png')

Delete all files in a directory matching signature and extension.

Searches for files in the specified directory that contain the given signature string in their filename and have the specified extension, then deletes them. If the directory doesn’t exist, returns silently.

Parameters:
  • path (str) – Directory path to search for files.

  • signature (str, optional) – String that must be contained in filename (default: ‘’).

  • ext (str, optional) – File extension to match, including dot (default: ‘.png’).

Raises:

OSError – If file deletion fails due to permissions or file being in use.

Notes

This function is typically used to clean up temporary image files before creating new visualizations. It will not raise an error if the directory doesn’t exist.

Examples

>>> # Delete all PNG files in a directory
>>> erase_all('/tmp/images', ext='.png')
>>> # Delete only JPG files containing 'temp' in the name
>>> erase_all('/tmp/images', signature='temp', ext='.jpg')

See also

save_image_series

Save multiple figures to disk.
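
The matching rule described for erase_all (name contains signature, name ends with ext, missing directory is not an error) can be sketched with the standard library. erase_matching is a hypothetical name, and skipping subdirectories is an assumption of this sketch:

```python
import os
import tempfile


def erase_matching(path, signature="", ext=".png"):
    """Sketch: delete files in `path` whose name contains `signature` and ends with `ext`."""
    if not os.path.isdir(path):
        return  # a missing directory is silently ignored
    for name in os.listdir(path):
        full = os.path.join(path, name)
        # Only regular files are removed; os.remove may raise OSError.
        if os.path.isfile(full) and signature in name and name.endswith(ext):
            os.remove(full)


# Demo in a throwaway directory
with tempfile.TemporaryDirectory() as tmp:
    for name in ("frame_temp_01.png", "frame_temp_02.jpg", "final.png"):
        open(os.path.join(tmp, name), "w").close()
    erase_matching(tmp, signature="temp", ext=".png")
    remaining = sorted(os.listdir(tmp))

print(remaining)  # ['final.png', 'frame_temp_02.jpg']
```

Because the default signature is the empty string, which every filename contains, calling it with only a path removes every file with the given extension.
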

driada.utils.gif.save_image_series(path, figures, im_ext='png')

Save a series of matplotlib figures to disk.

Saves each figure in the provided list to the specified directory. Creates the directory if it doesn’t exist. Figures are named using their suptitle if available, otherwise using a numbered sequence. Each figure is closed after saving to free memory.

Parameters:
  • path (str) – Directory path where images will be saved.

  • figures (list) – List of matplotlib figure objects.

  • im_ext (str, optional) – Image extension without dot (default: ‘png’).

Raises:
  • OSError – If directory creation fails or file cannot be saved.

  • AttributeError – If an element in figures list is not a valid matplotlib figure.

Notes

This function displays a progress bar using tqdm while saving figures. All figures are automatically closed after saving to prevent memory leaks. If a figure has a suptitle, it will be used as the filename.

Examples

>>> import matplotlib.pyplot as plt
>>> # Create multiple figures
>>> figs = []
>>> for i in range(5):
...     fig, ax = plt.subplots()
...     _ = ax.plot([1, 2, 3], [i, i+1, i+2])
...     _ = fig.suptitle(f'Plot_{i}')
...     figs.append(fig)
>>> # save_image_series('/tmp/plots', figs, im_ext='png')

See also

create_gif_from_image_series

Create animated GIF from saved images.

erase_all

Clean up image files.
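
The file-naming rule described for save_image_series (suptitle if present, otherwise a number) can be sketched without matplotlib by working on the suptitle strings directly. figure_paths is a hypothetical helper, and numbering by list position is an assumption; the page does not give the exact numbering scheme:

```python
import os


def figure_paths(suptitles, path, im_ext="png"):
    """Hypothetical helper: choose an output path for each figure.

    Uses the suptitle when one is set, otherwise the figure's position in the list.
    """
    paths = []
    for i, title in enumerate(suptitles):
        stem = title if title else str(i)
        paths.append(os.path.join(path, f"{stem}.{im_ext}"))
    return paths


print(figure_paths(["Plot_0", "", "Plot_0"], "plots"))
# on POSIX: ['plots/Plot_0.png', 'plots/1.png', 'plots/Plot_0.png']
```

In this sketch two figures with the same suptitle map to the same path, so the later one would overwrite the earlier. If save_image_series names files the same way, giving each figure a unique suptitle (as the Plot_{i} example does) avoids that.
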

driada.utils.gif.create_gif_from_image_series(path, signature, gifname, erase_prev=True, im_ext='png', duration=0.2)

Create an animated GIF from a series of images.

Searches for images in the specified directory that contain the signature string in their filename, sorts them alphabetically, and combines them into an animated GIF. The GIF is saved in a ‘GIFs’ subdirectory which is created automatically if it doesn’t exist.

Parameters:
  • path (str) – Directory containing the source images.

  • signature (str) – String that must be contained in image filenames to include.

  • gifname (str) – Name for the output GIF file (without extension).

  • erase_prev (bool, optional) – Whether to delete matching images after creating GIF (default: True).

  • im_ext (str, optional) – Image extension to search for (default: ‘png’).

  • duration (float, optional) – Duration of each frame in seconds (default: 0.2).

Returns:

Path to the created GIF file.

Return type:

str

Raises:
  • OSError – If directory operations fail or images cannot be read/written.

  • ValueError – If no matching images are found (from imageio).

Notes

The function creates a ‘GIFs’ subdirectory within the input path to store the output GIF. Images are sorted alphabetically by filename before being added to the GIF, so proper naming (e.g., frame_0001.png, frame_0002.png) ensures correct order. A progress bar shows the image loading process.

Image extensions are handled flexibly: ‘png’ and ‘.png’ are treated the same way.
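
Because frames are ordered by a plain alphabetical sort, unpadded numbers end up out of sequence. The sketch below assumes matching means "signature in the filename and the normalized extension at the end", following the description above; normalize_ext and matching_frames are hypothetical helpers:

```python
def normalize_ext(im_ext):
    """Treat 'png' and '.png' the same way by always using a single leading dot."""
    return "." + im_ext.lstrip(".")


def matching_frames(filenames, signature, im_ext="png"):
    """Hypothetical: filter by signature and extension, then sort alphabetically."""
    ext = normalize_ext(im_ext)
    return sorted(f for f in filenames if signature in f and f.endswith(ext))


unpadded = ["frame_1.png", "frame_2.png", "frame_10.png", "notes.txt"]
padded = [f"frame_{i:04d}.png" for i in (1, 2, 10)]

print(matching_frames(unpadded, "frame"))
# ['frame_1.png', 'frame_10.png', 'frame_2.png']  (frame 10 lands before frame 2)
print(matching_frames(padded, "frame", ".png"))
# ['frame_0001.png', 'frame_0002.png', 'frame_0010.png']
```

Zero-padding the frame index with a format such as {i:04d} when saving is the simplest way to make alphabetical order match frame order.
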

Examples

Create GIF from all PNG images containing ‘frame’ in the name:

gif_path = create_gif_from_image_series(
    '/tmp/images',
    signature='frame',
    gifname='animation',
    duration=0.5
)

Keep source images after creating GIF:

gif_path = create_gif_from_image_series(
    '/tmp/images',
    signature='plot_',
    gifname='results',
    erase_prev=False
)

See also

save_image_series

Save matplotlib figures as image series.

erase_all

Delete files matching specific criteria.