Source code for bioviz.plots.forest

"""Forest plot visualization for hazard ratios.

Provides a plotter class for generating forest plots that display hazard ratios
with confidence intervals from survival analysis (Cox PH models).

Example
-------
>>> from bioviz.configs import ForestPlotConfig
>>> from bioviz.plots import ForestPlotter
>>> cfg = ForestPlotConfig(hr_col="hr", ci_lower_col="ci_lower", ci_upper_col="ci_upper")
>>> plotter = ForestPlotter(hr_df, cfg)
>>> fig, ax = plotter.plot()
"""

from __future__ import annotations

from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from ..configs.forest_cfg import ForestPlotConfig

__all__ = ["ForestPlotter"]


def _resolve_fontsize(config_value: int | None, rcparam_key: str, default: float = 10) -> float:
    """Return config value if set, else fall back to rcParams or default.

    Handles matplotlib rcParams that may be strings like 'medium', 'large', etc.
    """
    if config_value is not None:
        return float(config_value)
    rc_value = plt.rcParams.get(rcparam_key, default)
    try:
        return float(rc_value)
    except (ValueError, TypeError):
        # rcParams value is a string like 'medium' - use default
        return float(default)


[docs] class ForestPlotter: """Generate forest plots for hazard ratio visualization. Parameters ---------- data : pd.DataFrame DataFrame containing HR data with columns for HR, CI bounds, labels, p-values. config : ForestPlotConfig Configuration object specifying plot options. Attributes ---------- data : pd.DataFrame config : ForestPlotConfig """ def __init__(self, data: pd.DataFrame, config: ForestPlotConfig) -> None: self.data = data.copy() self.config = config self._validate_data() def _validate_data(self) -> None: """Validate that required columns exist.""" cfg = self.config required = [cfg.hr_col, cfg.ci_lower_col, cfg.ci_upper_col, cfg.label_col] missing = [c for c in required if c not in self.data.columns] if missing: raise ValueError(f"Missing required columns: {missing}")
[docs] def prepare_data(self) -> pd.DataFrame: """Prepare and order data for plotting. Returns ------- pd.DataFrame Cleaned and ordered dataframe ready for plotting. """ cfg = self.config df = self.data.dropna(subset=[cfg.hr_col, cfg.ci_lower_col, cfg.ci_upper_col]) if df.empty: raise ValueError("No valid data rows after removing NaN values") # Apply category ordering if cfg.category_order and cfg.variable_col and cfg.variable_col in df.columns: ordered_dfs = [] ordered_vars = list(cfg.category_order.keys())[::-1] remaining = [v for v in df[cfg.variable_col].unique() if v not in ordered_vars] for var in ordered_vars + remaining: var_df = df[df[cfg.variable_col] == var].copy() if var in cfg.category_order: order_list = cfg.category_order[var] var_df[cfg.label_col] = pd.Categorical( var_df[cfg.label_col], categories=order_list, ordered=True ) var_df = var_df.sort_values(cfg.label_col) var_df[cfg.label_col] = var_df[cfg.label_col].astype(str) ordered_dfs.append(var_df) df = pd.concat(ordered_dfs, ignore_index=True) elif cfg.category_order and cfg.label_col in df.columns: order_list = next(iter(cfg.category_order.values()), []) if order_list: df[cfg.label_col] = pd.Categorical( df[cfg.label_col], categories=order_list, ordered=True ) df = df.sort_values(cfg.label_col) df[cfg.label_col] = df[cfg.label_col].astype(str) # Reverse for matplotlib (y=0 at bottom) if cfg.variable_col and cfg.variable_col in df.columns: df = ( df.groupby(cfg.variable_col, sort=False, group_keys=False) .apply(lambda g: g.iloc[::-1]) .reset_index(drop=True) ) else: df = df.iloc[::-1].reset_index(drop=True) return df
# Alias for backward compat during transition _prepare_data = prepare_data
[docs] def compute_y_positions(self, df: pd.DataFrame) -> np.ndarray: """Compute y-positions with optional section gaps. Parameters ---------- df : pd.DataFrame Prepared dataframe from prepare_data(). Returns ------- np.ndarray Array of y-coordinates for each row. """ cfg = self.config n_rows = len(df) y_positions = np.arange(n_rows, dtype=float) if cfg.variable_col and cfg.variable_col in df.columns and cfg.section_gap != 0.0: current_var = None cumulative = 0.0 for i, row in df.iterrows(): var = row[cfg.variable_col] if var != current_var and current_var is not None: cumulative += cfg.section_gap y_positions[i] += cumulative current_var = var return y_positions
# Alias for backward compat during transition _compute_y_positions = compute_y_positions def _get_colors(self, df: pd.DataFrame) -> tuple[list[str], list[str]]: """Determine CI bar and marker colors based on significance.""" cfg = self.config n = len(df) ci_colors = [] marker_colors = [] sig_color = cfg.color_significant nonsig_color = cfg.color_nonsignificant marker_sig = cfg.marker_color_significant or sig_color marker_nonsig = cfg.marker_color_nonsignificant or nonsig_color if cfg.pvalue_col in df.columns: for _, row in df.iterrows(): pval = row[cfg.pvalue_col] if pd.notna(pval) and pval < cfg.alpha_threshold: ci_colors.append(sig_color) marker_colors.append(marker_sig) else: ci_colors.append(nonsig_color) marker_colors.append(marker_nonsig) else: ci_colors = [nonsig_color] * n marker_colors = [marker_nonsig] * n return ci_colors, marker_colors def _set_xlim(self, ax, df: pd.DataFrame) -> None: """Set x-axis limits.""" cfg = self.config if cfg.xlim is not None: if cfg.center_around_null and cfg.log_scale: log_min = np.log10(cfg.xlim[0]) log_max = np.log10(cfg.xlim[1]) max_dist = max(abs(log_min), abs(log_max)) ax.set_xlim(10 ** (-max_dist), 10**max_dist) else: ax.set_xlim(cfg.xlim) else: all_vals = pd.concat([df[cfg.ci_lower_col], df[cfg.ci_upper_col]]) all_vals = all_vals.replace([np.inf, -np.inf], np.nan) all_vals = all_vals[all_vals > 0].dropna() if len(all_vals) == 0: ax.set_xlim(0.1, 10) else: min_v, max_v = all_vals.min(), all_vals.max() if cfg.log_scale: if cfg.center_around_null: max_dist = max(abs(np.log10(min_v)), abs(np.log10(max_v))) * 1.2 ax.set_xlim(10 ** (-max_dist), 10**max_dist) else: log_range = np.log10(max_v) - np.log10(min_v) ax.set_xlim( 10 ** (np.log10(min_v) - 0.2 * log_range), 10 ** (np.log10(max_v) + 0.2 * log_range), ) else: rng = max_v - min_v ax.set_xlim(max(0, min_v - 0.1 * rng), max_v + 0.1 * rng) def _add_stats_table( self, ax, df: pd.DataFrame, y_positions: np.ndarray, fontsize: float ) -> None: """Add HR/CI/p-value table on the right side.""" cfg = self.config if cfg.pvalue_col not in df.columns: return table_x = cfg.stats_table_x_position col_spacing = max(cfg.stats_table_col_spacing, fontsize * 0.015) y_min, y_max = ax.get_ylim() y_range = y_max - y_min def data_to_axes_y(y_data): return (y_data - y_min) / y_range # Header header_y = data_to_axes_y(len(df) - 1 + 0.8) for i, txt in enumerate(["HR", "95% CI", "p-value"]): ax.text( table_x + i * col_spacing, header_y, txt, transform=ax.transAxes, fontsize=fontsize, fontweight="bold", ha="center", va="center", ) # Data rows for i, (_, row) in enumerate(df.iterrows()): hr_str = f"{row[cfg.hr_col]:.2f}" ci_str = f"({row[cfg.ci_lower_col]:.2f}, {row[cfg.ci_upper_col]:.2f})" pval = row[cfg.pvalue_col] pval_str = ( "<0.001" if pd.notna(pval) and pval < 0.001 else (f"{pval:.4f}" if pd.notna(pval) else "N/A") ) row_y = data_to_axes_y(y_positions[i]) for j, txt in enumerate([hr_str, ci_str, pval_str]): ax.text( table_x + j * col_spacing, row_y, txt, transform=ax.transAxes, fontsize=fontsize, ha="center", va="center", ) def _add_section_separators_and_labels( self, ax, df: pd.DataFrame, y_positions: np.ndarray, ytick_fs: float ) -> None: """Add section separators and labels for multi-variable plots.""" cfg = self.config if not cfg.variable_col or cfg.variable_col not in df.columns: return unique_vars = df[cfg.variable_col].unique() if len(unique_vars) <= 1: return # Build section ranges var_ranges: dict[Any, dict[str, Any]] = {} for i, (_, row) in enumerate(df.iterrows()): var = row[cfg.variable_col] if var not in var_ranges: var_ranges[var] = {"indices": []} var_ranges[var]["indices"].append(i) # Compute min/max y for each section section_bounds = [] for var, info in var_ranges.items(): indices = info["indices"] min_y = y_positions[min(indices, key=lambda x: y_positions[x])] max_y = y_positions[max(indices, key=lambda x: y_positions[x])] var_ranges[var]["min_y"] = min_y var_ranges[var]["max_y"] = max_y var_ranges[var]["label_y"] = max_y section_bounds.append((min_y, var)) section_bounds.sort(key=lambda x: x[0]) # Separator lines if cfg.show_section_separators: for i in range(1, len(section_bounds)): cur_min = section_bounds[i][0] prev_var = section_bounds[i - 1][1] prev_max = var_ranges[prev_var]["max_y"] sep_y = (cur_min + prev_max) / 2 ax.axhline( y=sep_y, color=cfg.section_separator_color, linestyle="-", linewidth=1, alpha=cfg.section_separator_alpha, ) # Section labels for var, info in var_ranges.items(): label = (cfg.section_labels or {}).get(var, str(var)) ax.text( cfg.section_label_x_position, info["label_y"], label, transform=ax.get_yaxis_transform(), fontsize=ytick_fs + 1, fontweight="bold", va="center", ha="right", )
[docs] def plot( self, ax=None, fig=None, output_path: str | Path | None = None, ) -> tuple[Any, Any]: """Generate the forest plot. Parameters ---------- ax : Axes, optional Existing axes; if None, a new figure/axes is created. fig : Figure, optional Existing figure. output_path : str or Path, optional Path to save the figure. Returns ------- fig : Figure ax : Axes """ cfg = self.config df = self._prepare_data() # Font sizes ytick_fs = _resolve_fontsize(cfg.ytick_fontsize, "ytick.labelsize", 10) xtick_fs = _resolve_fontsize(cfg.xtick_fontsize, "xtick.labelsize", 10) xlabel_fs = _resolve_fontsize(cfg.xlabel_fontsize, "axes.labelsize", 11) title_fs = _resolve_fontsize(cfg.title_fontsize, "axes.titlesize", 12) stats_fs = _resolve_fontsize(cfg.stats_fontsize, "font.size", 9) # Create figure if ax is None or fig is None: fig, ax = plt.subplots(figsize=cfg.figsize) y_positions = self._compute_y_positions(df) ci_colors, marker_colors = self._get_colors(df) # Plot error bars and markers for i, (_, row) in enumerate(df.iterrows()): hr = row[cfg.hr_col] ci_lo = row[cfg.ci_lower_col] ci_hi = row[cfg.ci_upper_col] y = y_positions[i] # Error bar ax.plot( [ci_lo, ci_hi], [y, y], color=ci_colors[i], linewidth=cfg.linewidth, solid_capstyle="round", ) # Caps if cfg.show_caps: cap_h = cfg.capsize * 0.01 for x in [ci_lo, ci_hi]: ax.plot( [x, x], [y - cap_h, y + cap_h], color=ci_colors[i], linewidth=cfg.linewidth, ) # Marker ax.scatter( hr, y, s=cfg.marker_size**2, color=marker_colors[i], zorder=3, edgecolors="white", linewidths=0.5, marker=cfg.marker_style, ) # Reference line if cfg.show_reference_line: ax.axvline( x=1, color=cfg.reference_line_color, linestyle=cfg.reference_line_style, linewidth=cfg.reference_line_width, alpha=0.7, zorder=1, ) # Scale and limits if cfg.log_scale: ax.set_xscale("log") self._set_xlim(ax, df) # X-ticks if cfg.xticks: ax.set_xticks(cfg.xticks) fmt = "{:.2g}" if cfg.log_scale else "{:.2f}" ax.set_xticklabels([fmt.format(x) for x in cfg.xticks], fontsize=xtick_fs) else: ax.tick_params(axis="x", labelsize=xtick_fs) # Y-axis ax.set_yticks(y_positions) labels = df[cfg.label_col].astype(str).tolist() if cfg.reference_col and cfg.reference_col in df.columns: labels = [ f"{lbl} (vs {row[cfg.reference_col]})" if pd.notna(row[cfg.reference_col]) else lbl for lbl, (_, row) in zip(labels, df.iterrows(), strict=True) ] ax.set_yticklabels(labels, fontsize=ytick_fs) if not cfg.show_yticks: ax.tick_params(axis="y", length=0) # Y-limits if len(y_positions) > 0: ax.set_ylim(y_positions.min() - cfg.y_margin, y_positions.max() + cfg.y_margin) # Labels/title ax.set_xlabel(cfg.xlabel, fontsize=xlabel_fs, fontweight="bold") ax.set_ylabel("") if cfg.title: ax.set_title(cfg.title, fontsize=title_fs, fontweight="bold", pad=20) # Grid if cfg.show_grid: ax.grid(axis="x", alpha=0.3, linestyle=":", linewidth=0.8) ax.set_axisbelow(True) # Spines ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) if not cfg.show_y_spine: ax.spines["left"].set_visible(False) # Stats table if cfg.show_stats_table: self._add_stats_table(ax, df, y_positions, stats_fs) # Section separators/labels self._add_section_separators_and_labels(ax, df, y_positions, ytick_fs) # Save if output_path: fig.savefig(output_path, bbox_inches="tight", dpi=300) return fig, ax