"""
Visualization functions for RSA.
"""
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Optional, Tuple
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import squareform
from ..utils.plot import make_beautiful, create_default_figure
[docs]
def plot_rdm(
rdm: np.ndarray,
labels: Optional[List[str]] = None,
title: Optional[str] = None,
cmap: str = "RdBu_r",
figsize: Tuple[float, float] = (8, 7),
show_values: bool = False,
dendrogram_ratio: float = 0.2,
cbar_label: str = "Dissimilarity",
ax: Optional[plt.Axes] = None,
) -> plt.Figure:
"""
Plot a representational dissimilarity matrix with optional dendrogram.
Parameters
----------
rdm : np.ndarray
Square RDM matrix
labels : list of str, optional
Labels for each condition/item
title : str, optional
Plot title
cmap : str, default 'RdBu_r'
Colormap for the heatmap
figsize : tuple, default (8, 7)
Figure size if creating new figure
show_values : bool, default False
Whether to show numerical values in cells
dendrogram_ratio : float, default 0.2
Proportion of figure height for dendrogram (0 to disable)
cbar_label : str, default 'Dissimilarity'
Label for colorbar
ax : matplotlib.Axes, optional
Existing axes to plot on (disables dendrogram)
Returns
-------
fig : matplotlib.Figure
The figure object"""
n_items = rdm.shape[0]
if labels is None:
labels = [f"Item {i+1}" for i in range(n_items)]
has_dendrogram = False
if ax is None:
# Create figure with optional dendrogram
if dendrogram_ratio > 0:
has_dendrogram = True
# Create figure
fig = plt.figure(figsize=figsize)
# Create grid for dendrogram and heatmap
gs = fig.add_gridspec(
2,
2,
height_ratios=[dendrogram_ratio, 1 - dendrogram_ratio],
width_ratios=[dendrogram_ratio, 1 - dendrogram_ratio],
hspace=0.02,
wspace=0.02,
)
# Dendrogram axes
ax_dendro_top = fig.add_subplot(gs[0, 1])
ax_dendro_left = fig.add_subplot(gs[1, 0])
ax_main = fig.add_subplot(gs[1, 1])
# Hide the unused corner
ax_corner = fig.add_subplot(gs[0, 0])
ax_corner.axis('off')
# Compute linkage from condensed distance matrix
linkage_matrix = linkage(squareform(rdm), method="average")
# Plot dendrograms
dendro_top = dendrogram(
linkage_matrix, ax=ax_dendro_top, orientation="top", no_labels=True
)
dendro_left = dendrogram(
linkage_matrix, ax=ax_dendro_left, orientation="left", no_labels=True
)
# Completely hide dendrogram axes (not just invisible)
ax_dendro_top.axis('off')
ax_dendro_left.axis('off')
# Reorder RDM according to dendrogram
order = dendro_top["leaves"]
rdm_ordered = rdm[order][:, order]
labels_ordered = [labels[i] for i in order]
ax = ax_main
else:
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111)
rdm_ordered = rdm
labels_ordered = labels
else:
fig = ax.figure
rdm_ordered = rdm
labels_ordered = labels
# Plot heatmap
im = ax.imshow(rdm_ordered, cmap=cmap, aspect="auto", interpolation='nearest')
# Set ticks and labels with proper styling
ax.set_xticks(np.arange(n_items))
ax.set_yticks(np.arange(n_items))
ax.set_xticklabels(labels_ordered, rotation=45, ha="right", fontsize=11)
ax.set_yticklabels(labels_ordered, fontsize=11)
# Style tick parameters
ax.tick_params(axis='both', which='major', width=2, length=6, labelsize=11)
# Add values if requested
if show_values and n_items <= 20: # Only show values for small RDMs
for i in range(n_items):
for j in range(n_items):
text_color = "white" if rdm_ordered[i, j] > np.median(rdm_ordered) else "black"
ax.text(
j,
i,
f"{rdm_ordered[i, j]:.2f}",
ha="center",
va="center",
color=text_color,
fontsize=10,
)
# Add colorbar with proper positioning
cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
cbar.set_label(cbar_label, fontsize=12)
cbar.ax.tick_params(labelsize=10, width=2, length=4)
# Add title
if title:
ax.set_title(title, pad=20, fontsize=14, fontweight='bold')
# Add grid
ax.grid(False)
# Style spines
for spine in ax.spines.values():
spine.set_linewidth(2)
# Adjust layout to prevent overlap
if has_dendrogram:
fig.subplots_adjust(left=0.15, right=0.88, bottom=0.12, top=0.92)
else:
fig.tight_layout()
return fig
[docs]
def plot_rdm_comparison(
rdms: List[np.ndarray],
labels: Optional[List[str]] = None,
titles: Optional[List[str]] = None,
figsize: Optional[Tuple[float, float]] = None,
cmap: str = "RdBu_r",
) -> plt.Figure:
"""
Plot multiple RDMs side by side for comparison.
Parameters
----------
rdms : list of np.ndarray
List of RDM matrices to compare
labels : list of str, optional
Labels for conditions/items (same for all RDMs)
titles : list of str, optional
Title for each RDM
figsize : tuple, optional
Figure size (default based on number of RDMs)
cmap : str, default 'RdBu_r'
Colormap for the heatmaps
Returns
-------
fig : matplotlib.Figure
The figure object"""
n_rdms = len(rdms)
# Validate that all RDMs have the same shape
if n_rdms > 0:
first_shape = rdms[0].shape
for i, rdm in enumerate(rdms[1:], 1):
if rdm.shape != first_shape:
raise ValueError(
f"All RDMs must have the same shape. RDM 0 has shape {first_shape}, "
f"but RDM {i} has shape {rdm.shape}"
)
if figsize is None:
figsize = (6 * n_rdms + 1, 5)
if titles is None:
titles = [f"RDM {i+1}" for i in range(n_rdms)]
# Create figure with subplots
fig, axes = plt.subplots(1, n_rdms, figsize=figsize)
if n_rdms == 1:
axes = [axes]
# Find global min/max for consistent color scale
vmin = min(rdm.min() for rdm in rdms)
vmax = max(rdm.max() for rdm in rdms)
for i, (rdm, ax, title) in enumerate(zip(rdms, axes, titles)):
im = ax.imshow(rdm, cmap=cmap, aspect="auto", vmin=vmin, vmax=vmax)
# Set labels with proper styling
if labels is not None:
n_items = rdm.shape[0]
ax.set_xticks(np.arange(n_items))
ax.set_yticks(np.arange(n_items))
ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=12)
if i == 0: # Only show y labels on first plot
ax.set_yticklabels(labels, fontsize=12)
else:
ax.set_yticklabels([])
ax.set_title(title, fontsize=14, fontweight='bold', pad=10)
ax.tick_params(width=2, length=6)
ax.grid(False)
# Style spines
for spine in ax.spines.values():
spine.set_linewidth(2)
# Add single colorbar (tight_layout first, then colorbar to avoid overlap)
fig.tight_layout()
cbar = fig.colorbar(im, ax=axes, fraction=0.04, pad=0.05)
cbar.set_label("Dissimilarity", fontsize=12)
cbar.ax.tick_params(labelsize=10, width=2, length=4)
return fig