from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Literal
import matplotlib.pyplot as plt
from pydantic import BaseModel, Field, field_validator
# Configuration model for volcano-plot rendering options.
class VolcanoConfig(BaseModel):
    """Validated bundle of options describing a volcano plot.

    The model only stores and validates settings; it draws nothing itself.
    Each field's ``description`` states how the consuming plotting code is
    expected to interpret the option (that code is not part of this class).

    Only ``x_col`` and ``y_col`` are required. Every other field has a
    default, and ``palette`` is normalised by ``_ensure_palette_keys`` so the
    keys ``"nonsig"``, ``"sig_up"`` and ``"sig_down"`` are always present.
    """

    # ------ Data Columns ------
    # Required: names of the dataframe columns to use for plotting
    x_col: str = Field(..., description="Name of the x-axis column (e.g. log2_or)")
    y_col: str = Field(..., description="Name of the y-axis column (e.g. p_adj)")
    # Label col not required; if provided, used to match labels to points
    # if no label_col is provided, labels are matched by point index
    label_col: str | None = None

    # ------ Selection & Thresholds ------
    values_to_label: list[str] | None = None
    additional_values_to_label: list[str] | None = None
    label_mode: Literal[
        "auto",
        "sig",
        "sig_and_thresh",
        "thresh",
        "sig_or_thresh",
        "all",
    ] = Field(
        "auto",
        description=(
            "Controls how labels are selected when `values_to_label` is not provided. "
            "Options: 'auto' (default: label points considered significant by `y_col_thresh`), 'sig' (y threshold only), "
            "'sig_and_thresh' (require both y threshold and |x| >= `abs_x_thresh`), 'thresh' (x magnitude only), "
            "'sig_or_thresh' (union), and 'all' (label every point)."
        ),
    )
    y_col_thresh: float = Field(
        0.05,
        description=(
            "Numeric cutoff applied to `y_col` to mark significance. "
            "(e.g., p-value threshold). Formerly named `sig_thresh`; "
            "the name was changed to avoid implying the column must be a p-value."
        ),
    )
    abs_x_thresh: float = Field(
        2.0,
        description=(
            "Absolute x-axis magnitude threshold used by some label/"
            "color selection modes (points with |x| >= this value are considered large)."
        ),
    )
    y_thresh: float | None = Field(
        None,
        description=(
            "Optional y-axis threshold position (in data units). "
            "When set, a horizontal threshold line is drawn at this value after any configured transform."
        ),
    )
    # NOTE(review): pydantic v2 appears to validate ``Iterable[...]`` fields
    # lazily as a single-pass iterator — confirm consumers iterate only once
    # (applies to ``x_thresh``, ``xticks`` and ``yticks``).
    x_thresh: Iterable[float] | None = Field(
        None,
        description=(
            "Optional x-axis threshold line positions. Provide an iterable of numeric positions to draw vertical threshold lines."
        ),
    )

    # Threshold styling and per-axis overrides
    thresh_line_color: str = Field(
        "gainsboro",
        description="Default color used for threshold lines (applies when per-axis override is not provided).",
    )
    thresh_line_style: str = Field(
        "--",
        description="Line style used for threshold lines (e.g. '--').",
    )
    thresh_line_width: float = Field(
        1.0,
        description="Line width used for threshold lines (applies when per-axis override is not provided).",
    )
    x_thresh_line_color: str | None = Field(
        None,
        description="Optional color override for x-axis threshold lines.",
    )
    x_thresh_line_style: str | None = Field(
        None,
        description="Optional line-style override for x-axis threshold lines.",
    )
    x_thresh_line_width: float | None = Field(
        None,
        description="Optional line-width override for x-axis threshold lines.",
    )
    y_thresh_line_color: str | None = Field(
        None,
        description="Optional color override for y-axis threshold lines.",
    )
    y_thresh_line_style: str | None = Field(
        None,
        description="Optional line-style override for y-axis threshold lines.",
    )
    y_thresh_line_width: float | None = Field(
        None,
        description="Optional line-width override for y-axis threshold lines.",
    )

    # ------ Coloring & Direction ------
    direction_col: str | None = None
    direction_colors: dict[str, str] | None = None
    # default_factory gives every instance its own dict; see
    # `_ensure_palette_keys` for how user-supplied palettes are completed.
    palette: dict[str, str] = Field(
        default_factory=lambda: {
            "nonsig": "gainsboro",
            "sig_up": "#009E73",
            "sig_down": "#D55E00",
        }
    )
    color_mode: Literal["sig", "thresh", "sig_and_thresh", "sig_or_thresh", "all"] = Field(
        "sig",
        description=(
            "Controls how point colors are assigned relative to thresholds/significance. "
            "Options: 'sig', 'thresh', 'sig_and_thresh', 'sig_or_thresh', 'all'."
        ),
    )

    # ------ Labeling & Annotation ------
    label_offset_mode: str = Field(
        "fraction",
        description=(
            "How to interpret `label_offset`: 'fraction' interprets the value as a fraction of the x-axis span, "
            "'data' treats it as raw data units, and 'axes' interprets it as an axis fraction (0..1) converted to data units."
        ),
    )
    label_offset: float = Field(
        0.03,
        description=(
            "Default label offset used when `label_offset_mode` == 'fraction' (fraction of x-axis span)."
        ),
    )
    force_label_side_by_point_sign: bool = Field(
        False,
        description=(
            "If True, force labels to appear on the outward side of each point: left when x<0, right when x>0. "
            "When False, labels may be placed by other rules or `adjust_text`."
        ),
    )
    force_labels_adjustable: bool = Field(
        False,
        description=(
            "When True, forced outward labels are included in the `adjust_text` pass and may be moved to avoid overlaps."
        ),
    )
    annotation_fontweight_sig: str = Field(
        "bold",
        description="Font weight used for labels on significant points (default: 'bold').",
    )
    annotation_fontweight_nonsig: str = Field(
        "normal",
        description="Font weight used for labels on non-significant points (default: 'normal').",
    )
    annotation_sig_color: str | None = None
    annotation_nonsig_color: str = "#7f7f7f"
    # Optional per-direction annotation colors. If provided these will be used
    # for significant annotations that have a direction/group value (e.g. 'Early'/'Durable')
    # and take precedence over `annotation_sig_color` when the direction maps to
    # 'sig_up' or 'sig_down' semantics.
    annotation_sig_up_color: str | None = None
    annotation_sig_down_color: str | None = None
    # These two fields used to be declared twice in this class (a bare default
    # here and a described `Field` further down). They are now declared once,
    # at their original position, so field order and defaults are unchanged.
    horiz_offset_range: tuple[float, float] = Field(
        (0.02, 0.06),
        description=("Range (lo,hi) for horizontal offset fractions used when placing labels."),
    )
    vert_jitter_range: tuple[float, float] = Field(
        (-0.03, 0.03),
        description=(
            "Range (lo,hi) for vertical jitter applied to labels as fraction of axis span."
        ),
    )
    use_adjust_text: bool = Field(
        True,
        description=(
            "Whether to use the `adjust_text` package (if available) to tidy label positions before drawing connectors."
        ),
    )
    adjust: bool = True
    # Whether to transform the y-column using -log10 (e.g., p-values -> -log10(p))
    log_transform_ycol: bool = Field(
        False,
        description=("When True, the y column will be transformed with -log10 before plotting."),
    )
    # Nudging / label layout knobs used by the plotting code but optional
    nudge_padding_pixels: float = Field(
        6.0,
        description=(
            "Display-space padding used when nudging labels away from nearby markers (pixels)."
        ),
    )

    # ------ Layout & Axes ------
    x_label: str | None = Field(
        None,
        description="Optional x-axis label (string or TeX). If not provided a sensible default is used.",
    )
    y_label: str | None = Field(
        None,
        description="Optional y-axis label (string or TeX). If not provided a sensible default is used.",
    )
    xlim: tuple[float, float] | None = Field(
        None,
        description=(
            "Optional explicit x-axis limits (min, max) in data units. When provided these override automatic expansion."
        ),
    )
    ylim: tuple[float, float] | None = Field(
        None,
        description=(
            "Optional explicit y-axis limits (min, max) in data units. When provided these override automatic expansion."
        ),
    )
    xticks: Iterable[float] | None = Field(
        None,
        description="Optional explicit x-tick locations (iterable of numeric values).",
    )
    yticks: Iterable[float] | None = Field(
        None,
        description="Optional explicit y-tick locations (iterable of numeric values).",
    )
    xtick_step: int | None = None
    fontsize_sig: int = 12
    fontsize_nonsig: int = 11
    tick_label_fontsize: int = 16
    axis_label_fontsize: int = 18
    title: str | None = None
    title_fontsize: int = 20
    title_fontweight: str = Field(
        "normal",
        description="Font weight for title ('normal', 'bold', 'light', etc.).",
    )
    figsize: tuple[int, int] = (5, 5)
    group_label_top: tuple[str, str] | None = None
    group_label_kwargs: dict | None = None

    # ------ Marker & Connectors ------
    marker_size: float = 50.0
    attach_to_marker_edge: bool = True
    pad_by_marker: bool = Field(
        True,
        description=(
            "When True, expand axis limits by the marker display radius so large markers near the edge are not clipped. "
            "Set False to preserve exact axis limits (useful when caller set `xlim`/`ylim` explicitly)."
        ),
    )
    connector_color: str = "gray"
    connector_width: float = 0.8
    connector_color_sig: str | None = Field(
        None,
        description="Optional connector color for significant points (falls back to `connector_color` if None).",
    )
    connector_color_nonsig: str | None = Field(
        None,
        description="Optional connector color for non-significant points (falls back to `connector_color` if None).",
    )
    connector_color_left: str | None = Field(
        None,
        description="Optional connector color override for left-side labels.",
    )
    connector_color_right: str | None = Field(
        None,
        description="Optional connector color override for right-side labels.",
    )
    connector_color_sig_left: str | None = Field(
        None,
        description="Most-specific override: connector color for significant points on the left side.",
    )
    connector_color_sig_right: str | None = Field(
        None,
        description="Most-specific override: connector color for significant points on the right side.",
    )
    connector_color_nonsig_left: str | None = Field(
        None,
        description="Most-specific override: connector color for non-significant points on the left side.",
    )
    connector_color_nonsig_right: str | None = Field(
        None,
        description="Most-specific override: connector color for non-significant points on the right side.",
    )
    connector_color_use_point_color: bool = False

    # ------ Execution / Explicit placements ------
    # `plt.Axes` is not a pydantic-native type; `arbitrary_types_allowed`
    # in `model_config` below is what permits it as a field annotation.
    ax: plt.Axes | None = None
    explicit_label_positions: Any | None = Field(
        None,
        description=(
            "Optional explicit label positions. Accepts a dict label->(x,y), an iterable of (label,(x,y)), "
            "or a pandas DataFrame with columns ('label','x','y') or matching `label_col`/`x_col`/`y_col`."
        ),
    )
    explicit_label_replace: bool = Field(
        True,
        description=(
            "When True, labels provided in `explicit_label_positions` replace automatic labeling for those labels. "
            "When False, explicit positions are added alongside automatic labels."
        ),
    )
    explicit_label_adjustable: bool = Field(
        False,
        description=(
            "When True, explicit labels participate in the `adjust_text` flow and may be moved; otherwise explicit positions are respected."
        ),
    )

    model_config = {"arbitrary_types_allowed": True}

    @field_validator("palette", mode="before")
    @classmethod
    def _ensure_palette_keys(cls, v):
        """Complete a user-supplied palette with the minimal required keys.

        Runs before pydantic's own validation of ``palette``. The input is
        copied, then ``"nonsig"``, ``"sig_up"`` and ``"sig_down"`` are filled
        in when absent. A legacy ``"sig"`` key, if present, is used as the
        value for a missing ``"sig_up"``.

        Args:
            v: Raw value passed for ``palette``; anything ``dict()`` accepts.

        Returns:
            A new dict with the required keys present, or ``v`` unchanged when
            it cannot be converted to a dict (pydantic's ``dict[str, str]``
            validation then rejects it with a regular validation error).
        """
        try:
            # Copy so the caller's mapping is never mutated by the defaults.
            palette = dict(v)
        except (TypeError, ValueError):
            # e.g. ``None`` or a scalar: defer to pydantic's type validation
            # instead of leaking a raw TypeError out of model construction.
            return v

        palette.setdefault("nonsig", "gainsboro")
        if "sig_up" not in palette and "sig" in palette:
            palette["sig_up"] = palette["sig"]
        palette.setdefault("sig_up", "#009E73")
        palette.setdefault("sig_down", "#D55E00")
        return palette