# Source code for bioviz.plots.table

"""
Table utilities (bioviz)

Uses neutral `DefaultStyle`.
"""

import matplotlib.pyplot as plt
import pandas as pd

from bioviz.configs import StyledTableConfig
from bioviz.utils.plotting import resolve_font_family

__all__ = ["TablePlotter"]


def _calculate_column_widths(
    df: pd.DataFrame,
    table_width: float,
    min_fraction: float,
) -> list[float]:
    """
    Calculate proportional column widths based on max content length per column.

    Each column's share of ``table_width`` is proportional to the longest
    string found in it (header or any cell, after ``str`` conversion).
    Columns that would fall below ``min_fraction * table_width`` are raised to
    that minimum and the surplus is taken from the wider columns.

    If the minimums alone would fill or overflow the table
    (``min_fraction * n_cols >= 1``), the constraint cannot be satisfied and
    the table width is split equally instead.

    Parameters
    ----------
    df : pd.DataFrame
        The data (columns include headers).
    table_width : float
        Total table width to distribute.
    min_fraction : float
        Minimum width per column as fraction of table_width.

    Returns
    -------
    list[float]
        Column widths summing to table_width (empty if ``df`` has no columns).
    """
    n_cols = len(df.columns)
    if n_cols == 0:
        return []

    min_width = min_fraction * table_width
    if min_width * n_cols >= table_width:
        # No room left to distribute proportionally: the redistribution below
        # would hand out a non-positive remainder and could collapse the
        # widest column to zero (or return widths that overflow the table).
        return [table_width / n_cols] * n_cols

    # Find max string length per column (header + all cell values)
    has_rows = len(df) > 0
    max_lengths = []
    for idx, col in enumerate(df.columns):
        header_len = len(str(col))
        # Positional access: ``df[col]`` would return a DataFrame (which has
        # no ``.str`` accessor) when column labels are duplicated.
        cell_max = df.iloc[:, idx].astype(str).str.len().max() if has_rows else 0
        max_lengths.append(max(header_len, cell_max, 1))  # at least 1 to avoid div by zero

    total_len = sum(max_lengths)

    # Proportional widths, then lift anything below the minimum
    widths = [max((length / total_len) * table_width, min_width) for length in max_lengths]
    excess = sum(widths) - table_width

    if excess > 0:
        # Lifting narrow columns overshot the total. Re-split the space that
        # remains above the per-column minimum in proportion to how far each
        # column sits above that minimum.
        scalable = [w - min_width for w in widths]
        scalable_total = sum(scalable)
        if scalable_total > 0:
            distributable = table_width - min_width * n_cols
            widths = [min_width + (s / scalable_total) * distributable for s in scalable]

    return widths


def _normalize_column_widths(
    widths: list[float],
    table_width: float,
) -> list[float]:
    """
    Normalize explicit column widths to sum to table_width.

    If widths already sum close to table_width (within 1%), they are returned
    unchanged (as a new list, so the caller's sequence is never aliased).
    Otherwise they are scaled proportionally. Widths that sum to zero fall
    back to an equal split.

    Parameters
    ----------
    widths : list[float]
        User-provided column width fractions.
    table_width : float
        Target total width.

    Returns
    -------
    list[float]
        Normalized widths summing to table_width (to within the 1% tolerance
        when the input is accepted as-is).
    """
    if not widths:
        return []
    total = sum(widths)
    if total == 0:
        # All zeros - fall back to equal widths
        return [table_width / len(widths)] * len(widths)
    # Tolerance check written multiplicatively so that table_width == 0 does
    # not raise ZeroDivisionError; it simply falls through to the scaling.
    if abs(total - table_width) < 0.01 * abs(table_width):
        return list(widths)
    # Scale to table_width
    return [(w / total) * table_width for w in widths]


def _resolve_fontsize(config_value: float | int | None, rcparam_key: str) -> float:
    """
    Resolve fontsize: use config value if set, otherwise fall back to rcParams.

    Parameters
    ----------
    config_value : float | int | None
        Value from config. If None, use rcParams.
    rcparam_key : str
        Key in matplotlib.rcParams to fall back to.

    Returns
    -------
    float
        The resolved fontsize.
    """
    if config_value is not None:
        return float(config_value)
    rc_val = plt.rcParams.get(rcparam_key, 12)
    try:
        return float(rc_val)
    except (TypeError, ValueError):
        # rcVal might be a named size like 'large' — convert via FontProperties
        try:
            from matplotlib.font_manager import FontProperties

            fp = FontProperties(size=rc_val)
            return float(fp.get_size_in_points())
        except Exception:
            return 12.0


def generate_styled_table(
    df: pd.DataFrame,
    config: StyledTableConfig,
    ax: plt.Axes | None = None,
) -> plt.Figure | None:
    """
    Generate a styled matplotlib table figure from a DataFrame using `StyledTableConfig`.

    Args:
       df: pandas DataFrame containing the table data (rows x columns).
       config: `StyledTableConfig` (pydantic) controlling visual aspects such as
          title, font sizes, header/body colors, row heights, table width, and
          automatic shrinking behavior when many rows are present.
       ax: Optional matplotlib `Axes` to draw the table into; if omitted a new
          figure and axes will be created and returned.

    Returns:
       The matplotlib `Figure` owning the axes the table was drawn into
       (or None if input DataFrame is empty).

    Notes:
       The owning figure is resized and its subplot margins are reset to fit
       the table, including when the caller supplies `ax`.
    """

    if df.empty:
        print("DataFrame is empty.")
        return None

    # Resolve fontsizes from config or rcParams
    title_font_size = _resolve_fontsize(config.title_font_size, "axes.titlesize")
    header_font_size = _resolve_fontsize(config.header_font_size, "axes.labelsize")
    cell_font_size = _resolve_fontsize(config.cell_font_size, "font.size")

    created_fig = None
    if ax is None:
        # Use a consistent base figure size; avoid implicit scaling surprises
        created_fig, ax = plt.subplots()
        try:
            # Best-effort cosmetic setup: white but fully transparent background
            created_fig.patch.set_facecolor("white")
            created_fig.patch.set_alpha(0.0)
        except Exception:
            pass
    fig = ax.figure

    ax.axis("off")

    # Header rows use their own height when configured, else the body row height
    header_height = (
        config.header_row_height if config.header_row_height is not None else config.row_height
    )

    # Reduce margins so saved output is tight around the table
    table_left = 0.0
    table_bottom = 0.0

    # Calculate column widths:
    # Priority: explicit column_widths > auto_column_widths > equal (None)
    col_widths = None
    if config.column_widths is not None:
        # Manual: use explicit fractions, scale to table_width if needed
        col_widths = _normalize_column_widths(config.column_widths, config.table_width)
    elif config.auto_column_widths:
        # Auto: proportional to content length (header or cell, whichever longer)
        col_widths = _calculate_column_widths(
            df,
            config.table_width,
            config.min_column_width_fraction,
        )
    # else: None = equal widths (matplotlib default)

    # The bbox height given here is provisional; it is replaced further down
    # once per-row heights (newlines, multiplier, shrinking) are known.
    table = ax.table(
        cellText=df.values.tolist(),
        colLabels=df.columns.tolist(),
        colWidths=col_widths,
        cellLoc="center",
        edges="closed",
        bbox=[
            table_left,
            table_bottom,
            config.table_width,
            header_height + (config.row_height * df.shape[0]),
        ],
    )

    if config.absolute_font_size:
        table.auto_set_font_size(False)

    # Row 0 is the header row (colLabels); data rows start at 1.
    # For each row keep the tallest height required by any of its cells.
    row_heights = {}
    for (row, _), cell in table.get_celld().items():
        text = cell.get_text().get_text()
        # Optionally ignore embedded newlines to keep uniform row heights
        num_lines = text.count("\n") + 1 if config.respect_newlines else 1
        is_header = row == 0
        base_height = header_height if is_header else config.row_height
        if is_header:
            needed_height = base_height
        else:
            needed_height = base_height * num_lines * config.row_height_multiplier
        row_heights[row] = max(row_heights.get(row, base_height), needed_height)

    total_height = sum(row_heights.values())

    # When there are many rows, cap the overall table height to avoid oversized single-line cells
    scale_factor = 1.0
    if (
        config.auto_shrink_total_height
        and df.shape[0] >= config.shrink_row_threshold
        and total_height >= config.max_total_height
    ):
        scale_factor = config.max_total_height / total_height

    effective_total_height = total_height * scale_factor
    # Replace the provisional bbox with the final one (relies on the table's
    # private ``_bbox`` attribute).
    table._bbox = [table_left, table_bottom, config.table_width, effective_total_height]

    header_family = config.header_font_family or resolve_font_family()
    body_family = config.body_font_family or resolve_font_family()
    title_family = header_family or body_family or resolve_font_family()

    for (row, _), cell in table.get_celld().items():
        text_obj = cell.get_text()
        is_header = row == 0
        cell.set_edgecolor(config.edge_color)
        font_size = header_font_size if is_header else cell_font_size
        if len(text_obj.get_text()) > config.max_chars:
            # Shrink long text with an 8pt floor. The floor is capped at the
            # configured size so that text already smaller than 8pt is never
            # *enlarged* by the shrink step.
            font_size = max(min(8, font_size), font_size - config.shrink_by)
        text_obj.set_fontsize(font_size)
        family = header_family if is_header else body_family
        if family:
            text_obj.set_fontname(family)
        text_obj.set_fontweight(config.header_font_weight if is_header else config.body_font_weight)
        text_obj.set_color(config.header_text_color if is_header else "black")
        text_obj.set_ha("center")
        text_obj.set_va("center")
        # Keep row height fractions normalized; bbox enforces total height
        cell.set_height(row_heights[row] / total_height)
        if is_header:
            cell.set_facecolor(config.header_bg_color)
        else:
            # Alternate body colors by row parity (first data row is row 1)
            cell.set_facecolor(config.row_colors[row % 2])

    if config.title:
        # Center the title over the table itself: the table spans
        # [table_left, table_left + table_width] in axes coordinates, which
        # only coincides with the axes center when table_width == 1.
        ax.text(
            table_left + config.table_width / 2,
            table_bottom + effective_total_height + 0.05,
            config.title,
            ha="center",
            va="bottom",
            fontsize=title_font_size,
            fontweight="bold",
            fontfamily=title_family,
            transform=ax.transAxes,
        )

    # Tighten layout and adjust figure size to content. Use sensible minimums so very
    # small configs still produce a readable figure.
    padding_w, padding_h = 0.4, 0.4
    content_width = config.table_width
    content_height = effective_total_height
    min_width, min_height = 2.0, 1.5
    fig.set_size_inches(
        max(content_width + padding_w, min_width),
        max(content_height + padding_h, min_height),
    )
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
    if created_fig is not None:
        created_fig.tight_layout()

    return created_fig or fig


class TablePlotter:
    """Stateful wrapper for styled tables.

    Construct with `(df, config)` where `config` is `StyledTableConfig` or a
    dict acceptable to it. Delegates rendering to `generate_styled_table`.

    After `plot()`, the resulting figure and axes are kept on `self.fig` and
    `self.ax` so that `save()` and `close()` can act on them.
    """

    def __init__(self, df: pd.DataFrame, config: StyledTableConfig | dict):
        """Store a copy of `df` and the config (a dict is converted to `StyledTableConfig`)."""
        if isinstance(config, dict):
            config = StyledTableConfig(**config)
        self.df = df.copy()
        self.config = config
        self.fig: plt.Figure | None = None
        self.ax: plt.Axes | None = None

    def set_data(self, df: pd.DataFrame) -> "TablePlotter":
        """Replace the stored data with a copy of `df`; returns `self` for chaining."""
        self.df = df.copy()
        return self

    def update_config(self, **kwargs) -> "TablePlotter":
        """Set config attributes from keyword arguments; returns `self` for chaining.

        Updates are best-effort: an attribute that cannot be set is skipped,
        but a warning is emitted so that typos and invalid values do not go
        unnoticed.
        """
        import warnings

        for k, v in kwargs.items():
            try:
                setattr(self.config, k, v)
            except Exception as exc:
                warnings.warn(
                    f"Could not set config option {k!r}: {exc}",
                    stacklevel=2,
                )
        return self

    def plot(self, ax: plt.Axes | None = None) -> tuple[plt.Figure | None, plt.Axes | None]:
        """Render the styled table and store `fig, ax` on the instance."""
        self.fig = generate_styled_table(self.df, self.config, ax=ax)
        if self.fig is None:
            # Nothing was drawn (empty DataFrame)
            self.ax = None
        elif ax is not None:
            # The table was drawn into the caller's axes; report that one
            # rather than whichever axes happens to be first in the figure.
            self.ax = ax
        else:
            fig_axes = self.fig.axes
            self.ax = fig_axes[0] if fig_axes else None
        return self.fig, self.ax

    def save(self, path: str, **save_kwargs) -> None:
        """Save the current figure to `path`, creating parent directories as needed.

        Extra keyword arguments are forwarded to `Figure.savefig`.

        Raises:
            RuntimeError: if `plot()` has not produced a figure yet.
        """
        from pathlib import Path

        if self.fig is None:
            raise RuntimeError("No figure available; call .plot() first")
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        self.fig.savefig(path, **save_kwargs)

    def close(self) -> None:
        """Close the stored figure (if any) and clear the `fig`/`ax` references."""
        try:
            if self.fig is not None:
                plt.close(self.fig)
        finally:
            # Always drop the references, even if closing raised
            self.fig = None
            self.ax = None