"""
Generate a volcano-style plot with a pydantic `VolcanoConfig`.
"""
from __future__ import annotations
import contextlib
import math
import warnings
from collections.abc import Mapping
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from bioviz.configs.volcano_cfg import VolcanoConfig
try:
from adjustText import adjust_text
except Exception:
def adjust_text(texts, *args, **kwargs): # type: ignore
return None
def _internal_resolve_values(df: pd.DataFrame, cfg: VolcanoConfig) -> list[str]:
# If caller provided exact values to label, honor that
if cfg.values_to_label:
base = list(cfg.values_to_label)
# Warn about any requested labels that are not present in the DataFrame
try:
if cfg.label_col and cfg.label_col in df.columns:
available = set(df[cfg.label_col].astype(str).tolist())
else:
available = set(df.index.astype(str).tolist())
missing = [v for v in base if v not in available]
if missing:
warnings.warn(
f"Requested labels not found in DataFrame: {missing}",
UserWarning,
stacklevel=2,
)
if cfg.additional_values_to_label:
# Warn about missing additions, but only append the valid ones
add = list(cfg.additional_values_to_label)
missing_add = [v for v in add if v not in available]
if missing_add:
warnings.warn(
f"Additional requested labels not found in DataFrame: {missing_add}",
UserWarning,
stacklevel=2,
)
valid_add = [v for v in add if v in available]
base = list(dict.fromkeys(base + valid_add))
except Exception:
pass
return base
# Determine label source (explicit column or index fallback)
if cfg.label_col and cfg.label_col in df.columns:
labels_series = df[cfg.label_col].astype(str)
else:
# fallback: create a Series from the DataFrame index so callers
# can use `.loc` consistently later on
labels_series = pd.Series(df.index.astype(str), index=df.index)
# Warn the user if no explicit label column or explicit labels were
# provided — the plotting code will fall back to using the DataFrame
# index as labels which is often unintentional.
if not (cfg.label_col and cfg.label_col in df.columns) and not cfg.values_to_label:
with contextlib.suppress(Exception):
warnings.warn(
"No `label_col` found and no `values_to_label` provided; "
"labels will be taken from the DataFrame index. "
"If you intended to label using a column, set `cfg.label_col` or "
"provide `values_to_label`.",
UserWarning,
stacklevel=2,
)
# Build a significance mask using the configured `y_col` and `y_col_thresh`.
sig_mask = pd.Series(False, index=df.index)
try:
y_thresh = getattr(cfg, "y_col_thresh", None)
if y_thresh is not None and cfg.y_col and cfg.y_col in df.columns:
sig_mask = df[cfg.y_col].astype(float).fillna(1.0) <= y_thresh
except Exception:
sig_mask = pd.Series(False, index=df.index)
eff_mask = (
df[cfg.x_col].abs() >= cfg.abs_x_thresh
if cfg.x_col in df.columns
else pd.Series(False, index=df.index)
)
# label_mode controls which points are chosen when `values_to_label` is not provided
mode = getattr(cfg, "label_mode", "auto")
if mode == "all":
base = labels_series.tolist()
elif mode == "sig":
base = labels_series.loc[sig_mask].tolist()
elif mode == "thresh":
base = labels_series.loc[eff_mask].tolist()
elif mode == "sig_and_thresh":
base = labels_series.loc[sig_mask & eff_mask].tolist()
elif mode == "sig_or_thresh":
base = labels_series.loc[sig_mask | eff_mask].tolist()
else:
# 'auto' (and any unknown value) defaults to the intersection
# of significance and magnitude — label points that meet both.
mask = sig_mask & eff_mask
base = labels_series.loc[mask].tolist()
if cfg.additional_values_to_label:
available = set(labels_series.tolist())
valid_add = [g for g in cfg.additional_values_to_label if g in available]
base = list(dict.fromkeys(base + valid_add))
return base
def resolve_labels(df: pd.DataFrame, cfg: VolcanoConfig) -> list[str]:
"""Return the final list of labels `plot_volcano` will use.
This helper mirrors the internal selection logic, including:
- honoring `values_to_label` and `additional_values_to_label`,
- applying `label_mode` when `values_to_label` isn't provided,
- excluding explicit placements when `explicit_label_replace` is True
(so callers can see the de-duplicated final set used for auto-labeling).
Useful for debugging or UI workflows where you want to preview labels
before re-rendering.
"""
# Start from the internal resolved list (this handles values_to_label/additional)
base = _internal_resolve_values(df, cfg)
# If explicit_label_positions are present and explicit_label_replace=True,
# those explicit labels are removed from the auto-label set inside
# plot_volcano; reflect that behavior here so the returned list matches
# what will actually be auto-labeled.
explicit_map = {}
if getattr(cfg, "explicit_label_positions", None) is not None:
try:
elp = cfg.explicit_label_positions
if isinstance(elp, dict):
explicit_map = {str(k): v for k, v in elp.items()}
elif hasattr(elp, "columns"):
cols = [c.lower() for c in elp.columns]
if "label" in cols and ("x" in cols and "y" in cols):
for _, r in elp.iterrows():
explicit_map[str(r["label"])] = (float(r["x"]), float(r["y"]))
else:
labcol = cfg.label_col
xcol = cfg.x_col
ycol = cfg.y_col
for _, r in elp.iterrows():
try:
explicit_map[str(r[labcol])] = (
float(r[xcol]),
float(r[ycol]),
)
except Exception:
continue
else:
for it in elp:
try:
explicit_map[str(it[0])] = (float(it[1][0]), float(it[1][1]))
except Exception:
continue
except Exception:
explicit_map = {}
if explicit_map and getattr(cfg, "explicit_label_replace", True):
base = [v for v in base if v not in explicit_map]
return base
def plot_volcano(cfg: VolcanoConfig, df: pd.DataFrame) -> tuple[plt.Figure, plt.Axes]:
"""Plot a volcano using the provided `VolcanoConfig`.
This function is intentionally strict: it requires a `VolcanoConfig` and
the dataframe to plot. It uses the config for everything and performs no
backward-compatibility shims.
"""
df = df.copy()
# Resolve labels
values_to_label_resolved = _internal_resolve_values(df, cfg)
# y values (allow transformation of p-values to -log10 if requested)
# Start with raw values; we may replace with -log10(p) below.
y_vals = (
df[cfg.y_col]
if (cfg.y_col and cfg.y_col in df.columns)
else pd.Series(np.nan, index=df.index)
)
transformed_y = False
# Decide whether to transform the y-column to -log10:
# - `cfg.log_transform_ycol` True -> perform transform
# - False -> do not transform
do_transform = bool(getattr(cfg, "log_transform_ycol", False))
# Perform the -log10 transform only when explicitly requested.
if do_transform and cfg.y_col and cfg.y_col in df.columns:
try:
y_vals = -np.log10(
pd.to_numeric(df[cfg.y_col], errors="coerce").replace([np.inf, -np.inf], np.nan)
)
transformed_y = True
except Exception:
y_vals = df[cfg.y_col]
# Figure / axis
if cfg.ax is None:
fig, ax = plt.subplots(figsize=cfg.figsize)
# Make figure background transparent while keeping axes face white
with contextlib.suppress(Exception):
fig.patch.set_alpha(0.0)
with contextlib.suppress(Exception):
ax.set_facecolor("white")
else:
ax = cfg.ax
fig = ax.figure
# Helper: compute a point on the marker edge (in data coords) in the
# direction toward `target_disp` so connectors attach at the marker edge
# rather than the marker center. This uses `cfg.marker_size` (the scatter
# `s` parameter) to estimate a display-space radius.
def _marker_edge_data_point(xd: float, yd: float, target_disp: tuple[float, float]):
try:
# center in display coords
center_disp = ax.transData.transform((xd, yd))
dx = target_disp[0] - center_disp[0]
dy = target_disp[1] - center_disp[1]
norm = math.hypot(dx, dy)
if norm <= 1e-8:
return xd, yd
ux, uy = dx / norm, dy / norm
# estimate marker radius in display pixels. `cfg.marker_size` is
# passed to scatter as `s` (points^2); approximate radius in
# points as sqrt(s)/2, then convert to pixels: pixels = points * dpi/72.
r_points = math.sqrt(max(cfg.marker_size, 1.0)) / 2.0
r_pixels = r_points * fig.dpi / 72.0
edge_disp = (center_disp[0] + ux * r_pixels, center_disp[1] + uy * r_pixels)
edge_data = ax.transData.inverted().transform(edge_disp)
return float(edge_data[0]), float(edge_data[1])
except Exception:
return xd, yd
def _select_connector_color(is_sig: bool, ox: float):
"""Return connector color using hierarchical precedence:
most-specific (sign+side) -> side -> sign -> nonsig -> generic.
"""
try:
# Most specific: sign + side
if is_sig:
if ox < 0 and getattr(cfg, "connector_color_sig_left", None):
return cfg.connector_color_sig_left
if ox >= 0 and getattr(cfg, "connector_color_sig_right", None):
return cfg.connector_color_sig_right
else:
if ox < 0 and getattr(cfg, "connector_color_nonsig_left", None):
return cfg.connector_color_nonsig_left
if ox >= 0 and getattr(cfg, "connector_color_nonsig_right", None):
return cfg.connector_color_nonsig_right
# Per-side override
side_color = cfg.connector_color_left if ox < 0 else cfg.connector_color_right
if side_color:
return side_color
# Per-significance override
if is_sig and getattr(cfg, "connector_color_sig", None):
return cfg.connector_color_sig
if (not is_sig) and getattr(cfg, "connector_color_nonsig", None):
return cfg.connector_color_nonsig
# Final fallback
return cfg.connector_color
except Exception:
return cfg.connector_color
def _nudge_label_if_overlapping(text_obj, marker_x, marker_y, marker_radius_pixels=None):
try:
fig.canvas.draw()
renderer = fig.canvas.get_renderer()
bbox = text_obj.get_window_extent(renderer=renderer)
if marker_radius_pixels is None:
r_points = math.sqrt(max(cfg.marker_size, 1.0)) / 2.0
marker_radius_pixels = r_points * fig.dpi / 72.0
# Use bbox center to detect overlap, but shift the text anchor
# (respecting its horizontal/vertical alignment) in display
# coordinates so the anchor moves consistently with the visual
# position of the text.
# Consider ALL markers: find the closest marker whose display
# position is within the nudge padding and push the text away
# from that marker to avoid landing on top of other markers.
cols = [c for c in ax.collections if hasattr(c, "get_offsets")]
marker_disp_positions = []
for c in cols:
try:
offs = c.get_offsets()
for off in offs:
marker_disp_positions.append(
tuple(ax.transData.transform((off[0], off[1])))
)
except Exception:
continue
bbox_center = (bbox.x0 + bbox.width / 2.0, bbox.y0 + bbox.height / 2.0)
padding = getattr(cfg, "nudge_padding_pixels", 6.0)
closest = None
closest_dist = float("inf")
for md in marker_disp_positions:
d = math.hypot(bbox_center[0] - md[0], bbox_center[1] - md[1])
if d < closest_dist:
closest_dist = d
closest = md
if closest is not None and closest_dist < (marker_radius_pixels + padding):
# Prefer to nudge horizontally toward the side with more free space
ax_left, ax_right = ax.bbox.x0, ax.bbox.x1
space_left = bbox_center[0] - ax_left
space_right = ax_right - bbox_center[0]
horiz_dir = 1.0 if space_right >= space_left else -1.0
# compute directional vector away from the closest marker
ux = (bbox_center[0] - closest[0]) / (closest_dist + 1e-8)
uy = (bbox_center[1] - closest[1]) / (closest_dist + 1e-8)
# bias horizontal movement to preferred side
ux = horiz_dir
shift_pixels = marker_radius_pixels + padding - closest_dist
# limit vertical movement so labels remain horizontally aligned
uy = uy * 0.25
text_anchor_data = text_obj.get_position()
text_anchor_disp = ax.transData.transform(
(text_anchor_data[0], text_anchor_data[1])
)
new_anchor_disp = (
text_anchor_disp[0] + ux * shift_pixels,
text_anchor_disp[1] + uy * shift_pixels,
)
new_anchor_data = ax.transData.inverted().transform(new_anchor_disp)
text_obj.set_position((new_anchor_data[0], new_anchor_data[1]))
fig.canvas.draw()
except Exception:
return
# significance mask (use appropriate threshold when y was transformed)
sig_mask = pd.Series(False, index=df.index)
y_col_thresh = getattr(cfg, "y_col_thresh", None)
try:
# build y-based mask depending on whether we transformed the y values
if y_col_thresh is not None and cfg.y_col in df.columns:
if transformed_y:
try:
thr = -np.log10(y_col_thresh)
except Exception:
thr = None
if thr is not None:
y_mask = y_vals.fillna(0.0) >= thr
else:
y_mask = pd.Series(False, index=df.index)
else:
y_mask = df[cfg.y_col].fillna(1.0) <= y_col_thresh
# Build y-based significance mask (x-threshold is applied later
# when label_mode or color_mode requests intersection semantics).
sig_mask = y_mask
except Exception:
sig_mask = pd.Series(False, index=df.index)
# magnitude-based mask (points whose absolute x value exceeds the
# configured `abs_x_thresh`). Define here so color/label selection
# logic below can reference it regardless of control flow.
try:
abs_x_thresh = getattr(cfg, "abs_x_thresh", None)
if abs_x_thresh is not None and cfg.x_col in df.columns:
eff_mask = df[cfg.x_col].abs() >= abs_x_thresh
else:
eff_mask = pd.Series(False, index=df.index)
except Exception:
eff_mask = pd.Series(False, index=df.index)
# color selection helpers
def _choose_direction_color(val):
try:
s = str(val).lower()
except Exception:
s = ""
if any(tok in s for tok in ("down", "decrease", "loss", "neg", "-")):
return cfg.palette.get("sig_down")
if any(tok in s for tok in ("up", "increase", "gain", "pos", "+")):
return cfg.palette.get("sig_up")
return cfg.palette.get("sig_up")
# Determine which points are considered "colored" according to the
# requested `cfg.color_mode` and then map colors accordingly. This
# separates selection logic from label selection so callers can choose
# independent behaviors for coloring vs labeling.
color_mode = getattr(cfg, "color_mode", "sig")
# If the user requested 'sig' coloring but no thresholds are available
# (neither y_col_thresh nor a positive abs_x_thresh present in the
# dataframe), interpret that as a request to color all points so the
# plot isn't entirely nonsignificant by default.
try:
has_y_thresh = getattr(cfg, "y_col_thresh", None) is not None and cfg.y_col in df.columns
has_x_thresh = (
getattr(cfg, "abs_x_thresh", None) is not None
and cfg.abs_x_thresh > 0
and cfg.x_col in df.columns
)
if (color_mode == "sig") and (not has_y_thresh) and (not has_x_thresh):
color_mode = "all"
except Exception:
pass
if color_mode == "all":
color_mask = pd.Series(True, index=df.index)
elif color_mode == "sig":
color_mask = sig_mask.copy()
elif color_mode == "thresh":
color_mask = eff_mask.copy()
elif color_mode == "sig_and_thresh":
color_mask = sig_mask & eff_mask
elif color_mode == "sig_or_thresh":
color_mask = sig_mask | eff_mask
else:
color_mask = sig_mask.copy()
colors = []
if cfg.direction_col and cfg.direction_col in df.columns and cfg.direction_colors:
for i in df.index:
if not color_mask.loc[i]:
colors.append(cfg.palette.get("nonsig"))
else:
colors.append(
cfg.direction_colors.get(
df.loc[i, cfg.direction_col], cfg.palette.get("sig_up")
)
)
else:
for i in df.index:
if not color_mask.loc[i]:
colors.append(cfg.palette.get("nonsig"))
continue
if cfg.direction_col and cfg.direction_col in df.columns:
color = _choose_direction_color(df.loc[i, cfg.direction_col])
else:
try:
xv = float(df.loc[i, cfg.x_col])
except Exception:
xv = 0.0
color = cfg.palette.get("sig_up") if xv >= 0 else cfg.palette.get("sig_down")
colors.append(color)
# axis limits: compute sensible defaults but allow caller overrides via cfg.xlim/cfg.ylim
x_data_min, x_data_max = df[cfg.x_col].min(), df[cfg.x_col].max()
y_data_max = y_vals.max()
x_limit = max(4, abs(x_data_min), abs(x_data_max))
y_limit = max(8, y_data_max)
# If caller provided explicit limits, use them. Otherwise use computed defaults.
if getattr(cfg, "xlim", None) is not None:
try:
ax.set_xlim(tuple(cfg.xlim))
except Exception:
ax.set_xlim(-x_limit, x_limit)
else:
ax.set_xlim(-x_limit, x_limit)
if getattr(cfg, "ylim", None) is not None:
try:
ax.set_ylim(tuple(cfg.ylim))
except Exception:
ax.set_ylim(bottom=-0.5, top=y_limit)
else:
ax.set_ylim(bottom=-0.5, top=y_limit)
# draw threshold lines
ax.axvline(x=0.0, color="#000000", linestyle="-", linewidth=0.8, zorder=1)
if cfg.x_thresh:
for xt in cfg.x_thresh:
ax.axvline(
x=xt,
color=(cfg.x_thresh_line_color or cfg.thresh_line_color),
linestyle=(cfg.x_thresh_line_style or cfg.thresh_line_style),
linewidth=(cfg.x_thresh_line_width or cfg.thresh_line_width),
zorder=1,
)
else:
# If caller didn't provide explicit x_thresholds, draw lines
# at ±abs_x_thresh when it's set to a finite positive value.
try:
if cfg.abs_x_thresh is not None and cfg.abs_x_thresh > 0:
ax.axvline(
x=cfg.abs_x_thresh,
color=(cfg.x_thresh_line_color or cfg.thresh_line_color),
linestyle=(cfg.x_thresh_line_style or cfg.thresh_line_style),
linewidth=(cfg.x_thresh_line_width or cfg.thresh_line_width),
zorder=1,
)
ax.axvline(
x=-cfg.abs_x_thresh,
color=(cfg.x_thresh_line_color or cfg.thresh_line_color),
linestyle=(cfg.x_thresh_line_style or cfg.thresh_line_style),
linewidth=(cfg.x_thresh_line_width or cfg.thresh_line_width),
zorder=1,
)
except Exception:
pass
if cfg.y_thresh is not None:
thr_y = cfg.y_thresh
elif transformed_y and getattr(cfg, "y_col_thresh", None) is not None:
try:
thr_y = -np.log10(getattr(cfg, "y_col_thresh", None))
except Exception:
thr_y = None
else:
thr_y = None
if thr_y is not None:
ax.axhline(
y=thr_y,
color=(cfg.y_thresh_line_color or cfg.thresh_line_color),
linestyle=(cfg.y_thresh_line_style or cfg.thresh_line_style),
linewidth=(cfg.y_thresh_line_width or cfg.thresh_line_width),
zorder=1,
)
# scatter (use explicit cfg.marker_size)
sc = ax.scatter(
df[cfg.x_col],
y_vals,
c=colors,
edgecolor="black",
linewidths=0.5,
s=cfg.marker_size,
zorder=3,
)
with contextlib.suppress(Exception):
# avoid clipping markers at the axes boundary
sc.set_clip_on(False)
# build labels aggregated by coordinates
all_texts = []
forced_texts = []
forced_points = []
adjustable_texts = []
adjustable_points = []
adjustable_point_sigs = []
coord_to_labels = {}
for i, row in df.iterrows():
try:
coord = (float(row[cfg.x_col]), float(y_vals.loc[i]))
except Exception:
continue
dir_val = (
row[cfg.direction_col]
if (cfg.direction_col and cfg.direction_col in df.columns)
else None
)
# Resolve label value (use label_col if present, else use index)
try:
if cfg.label_col and cfg.label_col in df.columns:
labval = str(row[cfg.label_col])
else:
labval = str(i)
except Exception:
labval = str(i)
coord_to_labels.setdefault(coord, []).append((i, labval, bool(sig_mask.loc[i]), dir_val))
# Parse explicit label placements if provided. Support dict, iterable of
# (label, (x,y)) tuples, or a DataFrame with label/x/y columns.
explicit_map = {}
if getattr(cfg, "explicit_label_positions", None) is not None:
try:
elp = cfg.explicit_label_positions
# dict-like
if isinstance(elp, dict):
for k, v in elp.items():
try:
explicit_map[str(k)] = (float(v[0]), float(v[1]))
except Exception:
continue
# DataFrame-like
elif hasattr(elp, "columns"):
# prefer explicit 'label','x','y' columns
cols = [c.lower() for c in elp.columns]
if "label" in cols and ("x" in cols and "y" in cols):
for _, r in elp.iterrows():
try:
explicit_map[str(r["label"])] = (
float(r["x"]),
float(r["y"]),
)
except Exception:
continue
else:
# try using label_col and x_col/y_col names
labcol = cfg.label_col
xcol = cfg.x_col
ycol = cfg.y_col
for _, r in elp.iterrows():
try:
explicit_map[str(r[labcol])] = (
float(r[xcol]),
float(r[ycol]),
)
except Exception:
continue
else:
# iterable of (label,(x,y))
for it in elp:
try:
lab = str(it[0])
xy = it[1]
explicit_map[lab] = (float(xy[0]), float(xy[1]))
except Exception:
continue
except Exception:
explicit_map = {}
# Warn if explicit labels reference names not present in the DataFrame
try:
if explicit_map:
if cfg.label_col and cfg.label_col in df.columns:
available_labels = set(df[cfg.label_col].astype(str).tolist())
else:
available_labels = set(df.index.astype(str).tolist())
missing_explicit = [k for k in explicit_map if k not in available_labels]
if missing_explicit:
warnings.warn(
f"Explicit label positions reference labels not in DataFrame: {missing_explicit}",
UserWarning,
stacklevel=2,
)
except Exception:
pass
x_min, x_max = ax.get_xlim()
y_min, y_max = ax.get_ylim()
# Place any explicit labels requested by the caller.
if explicit_map:
# Optionally remove explicit labels from the auto-label set so they
# aren't duplicated. Default behavior is to replace automatic labels.
if getattr(cfg, "explicit_label_replace", True):
with contextlib.suppress(Exception):
values_to_label_resolved = [
v for v in values_to_label_resolved if v not in explicit_map
]
for lab, (lx, ly) in explicit_map.items():
# skip if outside axes
if not (x_min <= lx <= x_max and y_min <= ly <= y_max):
continue
# find if this label corresponds to a data point to inherit sig status
matched_idx = None
try:
# match by label column or index
if cfg.label_col and cfg.label_col in df.columns:
matches = df.index[df[cfg.label_col].astype(str) == lab].tolist()
else:
matches = df.index[df.index.astype(str) == lab].tolist()
matched_idx = matches[0] if matches else None
except Exception:
matched_idx = None
is_sig = False
point_color = None
dir_val = None
if matched_idx is not None:
try:
is_sig = bool(sig_mask.loc[matched_idx])
pos = list(df.index).index(matched_idx)
point_color = colors[pos]
dir_val = (
df.loc[matched_idx, cfg.direction_col]
if (cfg.direction_col and cfg.direction_col in df.columns)
else None
)
except Exception:
is_sig = False
# choose annotation color
if is_sig:
# Prefer direction-specific annotation colors based on x-value sign
dir_ann = None
if matched_idx is not None:
try:
x_val = float(df.loc[matched_idx, cfg.x_col])
if x_val < 0:
dir_ann = getattr(cfg, "annotation_sig_down_color", None)
else:
dir_ann = getattr(cfg, "annotation_sig_up_color", None)
except Exception:
dir_ann = None
ann_color = (
dir_ann or cfg.annotation_sig_color or point_color or cfg.palette.get("sig_up")
)
weight = getattr(cfg, "annotation_fontweight_sig", "bold")
fontsize = cfg.fontsize_sig
else:
ann_color = getattr(cfg, "annotation_nonsig_color", "#7f7f7f")
weight = getattr(cfg, "annotation_fontweight_nonsig", "normal")
fontsize = cfg.fontsize_nonsig
t = ax.text(
lx,
ly,
lab,
fontsize=fontsize,
fontweight=weight,
color=ann_color,
zorder=4,
clip_on=False,
)
_nudge_label_if_overlapping(t, lx, ly)
# Optionally include explicit labels in the adjust_text flow
if getattr(cfg, "explicit_label_adjustable", False):
adjustable_texts.append(t)
try:
if matched_idx is not None:
adjustable_points.append(
(
float(df.loc[matched_idx, cfg.x_col]),
float(y_vals.loc[matched_idx]),
)
)
else:
adjustable_points.append((lx, ly))
except Exception:
adjustable_points.append((lx, ly))
adjustable_point_sigs.append(is_sig)
else:
# Draw connector from marker (if we matched a point) or skip
if matched_idx is not None:
try:
ox = float(df.loc[matched_idx, cfg.x_col])
oy = float(y_vals.loc[matched_idx])
except Exception:
ox, oy = None, None
if ox is not None:
try:
if getattr(cfg, "attach_to_marker_edge", True):
label_disp = ax.transData.transform((lx, ly))
attach_x, attach_y = _marker_edge_data_point(ox, oy, label_disp)
else:
attach_x, attach_y = ox, oy
except Exception:
attach_x, attach_y = ox, oy
if getattr(cfg, "connector_color_use_point_color", False) and point_color:
conn_color = point_color
else:
conn_color = _select_connector_color(is_sig, ox)
ax.plot(
[attach_x, lx],
[attach_y, ly],
color=conn_color,
linewidth=cfg.connector_width,
alpha=0.8,
zorder=3.5,
)
# Build a group -> color map for left/right labels if possible
group_side_color = {}
try:
if cfg.direction_col and cfg.direction_col in df.columns:
# Build means for each group and assign colors strictly by sign
g_kwargs = cfg.group_label_kwargs or {}
color_val = g_kwargs.get("color", {})
color_map = color_val if isinstance(color_val, dict) else {}
means = df.groupby(cfg.direction_col)[cfg.x_col].mean()
for grp in means.index:
lab = str(grp)
# explicit override first
if lab in color_map:
group_side_color[lab] = color_map[lab]
continue
if cfg.direction_colors and lab in (cfg.direction_colors or {}):
group_side_color[lab] = cfg.direction_colors.get(lab)
continue
# assign by mean sign
try:
if means.loc[grp] < 0:
group_side_color[lab] = cfg.palette.get("sig_down")
else:
group_side_color[lab] = cfg.palette.get("sig_up")
except Exception:
group_side_color[lab] = cfg.palette.get("nonsig")
except Exception:
group_side_color = {}
for coord, items in coord_to_labels.items():
x, y = coord
if not (x_min <= x <= x_max and y_min <= y <= y_max):
continue
items = [it for it in items if it[1] in values_to_label_resolved]
sig_items = [it for it in items if it[2]]
nonsig_items = [it for it in items if not it[2]]
stacked = sig_items + nonsig_items
labels = [it[1] for it in stacked]
if not labels:
continue
text_str = "\n".join(labels)
# Determine group-based color if available
group_val = stacked[0][3] if stacked and len(stacked) and len(stacked[0]) > 3 else None
ann_color = None
if group_val is not None and str(group_val) in group_side_color:
ann_color = group_side_color[str(group_val)]
# If no group color found, fall back to the actual marker color for the representative point
point_color = None
try:
rep_idx = stacked[0][0]
# map index to position in colors list
pos = list(df.index).index(rep_idx)
point_color = colors[pos]
except Exception:
point_color = None
# Choose placement mode: forced outward by point sign, or adjustable
forced_mode = getattr(cfg, "force_label_side_by_point_sign", False)
if forced_mode:
# deterministic outward placement (no adjust_text for these)
# compute offset in data units according to `label_offset_mode`
mode = getattr(cfg, "label_offset_mode", "fraction")
raw_offset = getattr(cfg, "label_offset", 0.05)
if mode == "fraction":
x0, x1 = ax.get_xlim()
span = float(x1 - x0) if x1 != x0 else 1.0
offset_data = raw_offset * span
elif mode == "axes":
# convert axis fraction to display then to data units using a small dx
try:
disp0 = ax.transAxes.transform((0.0, 0.0))
disp1 = ax.transAxes.transform((raw_offset, 0.0))
dx_disp = disp1[0] - disp0[0]
data_dx = (
ax.transData.inverted().transform((dx_disp, 0))[0]
- ax.transData.inverted().transform((0, 0))[0]
)
offset_data = data_dx
except Exception:
offset_data = raw_offset
else:
# data mode
offset_data = raw_offset
if x < 0:
tx = x - offset_data
ha = "right"
else:
tx = x + offset_data
ha = "left"
ty = y
is_sig = bool(stacked and stacked[0][2])
if is_sig:
# Use direction-specific annotation color based on x-value sign
dir_ann_local = None
if x < 0:
dir_ann_local = getattr(cfg, "annotation_sig_down_color", None)
else:
dir_ann_local = getattr(cfg, "annotation_sig_up_color", None)
color_final = (
dir_ann_local
or cfg.annotation_sig_color
or ann_color
or point_color
or cfg.palette.get("sig_up")
)
weight = getattr(cfg, "annotation_fontweight_sig", "bold")
fontsize = cfg.fontsize_sig
else:
# Always use the configured nonsignificant annotation color
# if provided; otherwise use a medium gray.
color_final = getattr(cfg, "annotation_nonsig_color", "#7f7f7f")
weight = getattr(cfg, "annotation_fontweight_nonsig", "normal")
fontsize = cfg.fontsize_nonsig
# Compute label y so the connector from label->point is ~45 degrees
try:
# point display coords
p_disp = ax.transData.transform((x, y))
# display x of the label (same y assumed initially)
label_x_disp = ax.transData.transform((tx, y))[0]
dx_disp = label_x_disp - p_disp[0]
# aim for dy_disp ~= abs(dx_disp) to get ~45°; place label above the point
dy_disp = abs(dx_disp)
label_y_disp = p_disp[1] + dy_disp
# convert back to data coords for the y position
ty = ax.transData.inverted().transform((label_x_disp, label_y_disp))[1]
except Exception:
ty = y
t = ax.text(
tx,
ty,
text_str,
fontsize=fontsize,
fontweight=weight,
color=color_final,
ha=ha,
clip_on=False,
zorder=4,
)
# Nudge label if it overlaps its marker
_nudge_label_if_overlapping(t, x, y)
force_adjust = getattr(cfg, "force_labels_adjustable", False)
if force_adjust:
# include forced labels in the adjustable/adjust_text flow
adjustable_texts.append(t)
adjustable_points.append((x, y))
adjustable_point_sigs.append(is_sig)
else:
# straight connector from label to point
try:
# Attach connector to the marker edge in the direction of
# the label (so the connector points toward the label).
try:
if getattr(cfg, "attach_to_marker_edge", True):
label_disp = ax.transData.transform((tx, ty))
attach_x, attach_y = _marker_edge_data_point(x, y, label_disp)
else:
attach_x, attach_y = x, y
except Exception:
attach_x, attach_y = x, y
conn_color = _select_connector_color(is_sig, x)
# draw a simple straight connector line from marker edge to label
ax.plot(
[attach_x, tx],
[attach_y, ty],
color=conn_color,
linewidth=cfg.connector_width,
alpha=0.8,
zorder=3.5,
)
except Exception:
pass
forced_texts.append(t)
forced_points.append((x, y))
all_texts.append(t)
else:
# adjustable placement (subject to adjust_text)
if stacked and stacked[0][2]:
color_final = (
cfg.annotation_sig_color
or ann_color
or point_color
or cfg.palette.get("sig_up")
)
# place significant labels slightly offset so connectors can be drawn
lo, hi = getattr(cfg, "horiz_offset_range", (0.02, 0.06))
samp = np.random.uniform(lo, hi)
if getattr(cfg, "label_offset_mode", "fraction") == "fraction":
x0, x1 = ax.get_xlim()
span = float(x1 - x0) if x1 != x0 else 1.0
horiz_offset = -abs(samp * span) if x < 0 else abs(samp * span)
elif getattr(cfg, "label_offset_mode", "fraction") == "axes":
try:
disp0 = ax.transAxes.transform((0.0, 0.0))
disp1 = ax.transAxes.transform((samp, 0.0))
dx_disp = disp1[0] - disp0[0]
data_dx = (
ax.transData.inverted().transform((dx_disp, 0))[0]
- ax.transData.inverted().transform((0, 0))[0]
)
horiz_offset = -abs(data_dx) if x < 0 else abs(data_dx)
except Exception:
horiz_offset = -abs(samp) if x < 0 else abs(samp)
else:
horiz_offset = -abs(samp) if x < 0 else abs(samp)
vlo, vhi = getattr(cfg, "vert_jitter_range", (-0.03, 0.03))
vj = np.random.uniform(vlo, vhi)
if getattr(cfg, "label_offset_mode", "fraction") == "fraction":
x0, x1 = ax.get_xlim()
span = float(x1 - x0) if x1 != x0 else 1.0
vert_jitter = vj * span
else:
vert_jitter = vj
t = ax.text(
x + horiz_offset,
y + vert_jitter,
text_str,
fontsize=cfg.fontsize_sig,
fontweight=getattr(cfg, "annotation_fontweight_sig", "bold"),
color=color_final,
clip_on=False,
zorder=4,
)
else:
# Force nonsig annotation text color from config
color_final = getattr(cfg, "annotation_nonsig_color", "#7f7f7f")
ha = "right" if x < 0 else "left"
# random horizontal offset and vertical jitter (interpreted per mode)
lo, hi = getattr(cfg, "horiz_offset_range", (0.02, 0.06))
samp = np.random.uniform(lo, hi)
if getattr(cfg, "label_offset_mode", "fraction") == "fraction":
x0, x1 = ax.get_xlim()
span = float(x1 - x0) if x1 != x0 else 1.0
horiz_offset = -abs(samp * span) if x < 0 else abs(samp * span)
elif getattr(cfg, "label_offset_mode", "fraction") == "axes":
# convert axes fraction to data dx
try:
disp0 = ax.transAxes.transform((0.0, 0.0))
disp1 = ax.transAxes.transform((samp, 0.0))
dx_disp = disp1[0] - disp0[0]
data_dx = (
ax.transData.inverted().transform((dx_disp, 0))[0]
- ax.transData.inverted().transform((0, 0))[0]
)
horiz_offset = -abs(data_dx) if x < 0 else abs(data_dx)
except Exception:
horiz_offset = -abs(samp) if x < 0 else abs(samp)
else:
horiz_offset = -abs(samp) if x < 0 else abs(samp)
vlo, vhi = getattr(cfg, "vert_jitter_range", (-0.03, 0.03))
vj = np.random.uniform(vlo, vhi)
if getattr(cfg, "label_offset_mode", "fraction") == "fraction":
x0, x1 = ax.get_xlim()
span = float(x1 - x0) if x1 != x0 else 1.0
vert_jitter = vj * span
else:
vert_jitter = vj
t = ax.text(
x + horiz_offset,
y + vert_jitter,
text_str,
fontsize=cfg.fontsize_nonsig,
color=color_final,
fontweight=getattr(cfg, "annotation_fontweight_nonsig", "normal"),
ha=ha,
clip_on=False,
zorder=4,
)
_nudge_label_if_overlapping(t, x, y)
adjustable_texts.append(t)
adjustable_points.append((x, y))
adjustable_point_sigs.append(stacked and stacked[0][2])
all_texts.append(t)
# Axis labels: show transformed math-style labels when appropriate
# Axis labels (overrides allowed)
if cfg.x_label:
ax.set_xlabel(cfg.x_label)
else:
lx = cfg.x_col.lower()
if "log2" in lx or "log_2" in lx:
# Keep 'OR' non-italicized inside math mode
ax.set_xlabel(r"$\log_{2}(\mathrm{OR})$")
else:
ax.set_xlabel(cfg.x_col)
if cfg.y_label:
ax.set_ylabel(cfg.y_label)
else:
if transformed_y:
# Render the original column name literally inside math text to avoid
# interpreting underscores as subscripts (e.g. p_adj)
safe_col = cfg.y_col.replace("_", r"\_")
ax.set_ylabel(rf"$-\log_{{10}}(\text{{{safe_col}}})$")
else:
ax.set_ylabel(cfg.y_col)
# Title and font sizes
if cfg.title:
ax.set_title(cfg.title, fontsize=cfg.title_fontsize, fontweight=cfg.title_fontweight)
ax.xaxis.label.set_size(cfg.axis_label_fontsize)
ax.yaxis.label.set_size(cfg.axis_label_fontsize)
for tick in ax.xaxis.get_ticklabels() + ax.yaxis.get_ticklabels():
tick.set_fontsize(cfg.tick_label_fontsize)
# Group labels at top: infer from direction_col if not explicitly provided
if cfg.group_label_top is None and cfg.direction_col and cfg.direction_col in df.columns:
try:
means = df.groupby(cfg.direction_col)[cfg.x_col].mean().dropna()
if len(means) >= 2:
sorted_idx = means.sort_values().index.tolist()
cfg_group = (str(sorted_idx[0]), str(sorted_idx[-1]))
elif len(means) == 1:
cfg_group = (str(means.index[0]), "")
else:
cfg_group = None
except Exception:
unique_vals = list(pd.Series(df[cfg.direction_col].astype(str)).unique())
if len(unique_vals) >= 2:
cfg_group = (unique_vals[0], unique_vals[1])
elif len(unique_vals) == 1:
cfg_group = (unique_vals[0], "")
else:
cfg_group = None
else:
cfg_group = cfg.group_label_top
if cfg_group:
try:
left_label, right_label = cfg_group
g_kwargs = cfg.group_label_kwargs or {}
color_val = g_kwargs.get("color", {})
color_map = color_val if isinstance(color_val, dict) else {}
fontsize_g = g_kwargs.get("fontsize", int(cfg.axis_label_fontsize * 0.9))
left_rot = g_kwargs.get("rotation", 0)
right_rot = g_kwargs.get("rotation_right", 0)
# Helper to find color from direction_colors with fuzzy matching
def _find_direction_color(label: str, direction_colors: dict) -> str | None:
if not direction_colors:
return None
# Exact match first
if label in direction_colors:
return direction_colors[label]
# Fuzzy: check if any direction key is contained in label or vice versa
for key, color in direction_colors.items():
if key in label or label in key:
return color
return None
# Use direction_colors if available, otherwise fall back to palette
left_color = (
color_map.get(left_label)
or _find_direction_color(left_label, cfg.direction_colors)
or cfg.palette.get("sig_down", "#D55E00")
)
right_color = (
color_map.get(right_label)
or _find_direction_color(right_label, cfg.direction_colors)
or cfg.palette.get("sig_up", "#009E73")
)
ax.text(
0.02,
1.02,
left_label,
transform=ax.transAxes,
ha="left",
va="bottom",
fontsize=fontsize_g,
fontweight="bold",
color=left_color,
rotation=left_rot,
)
ax.text(
0.98,
1.02,
right_label,
transform=ax.transAxes,
ha="right",
va="bottom",
fontsize=fontsize_g,
fontweight="bold",
color=right_color,
rotation=right_rot,
)
except Exception:
pass
# local staggering for adjustable texts
for i in range(1, len(adjustable_texts)):
x_prev, y_prev = adjustable_texts[i - 1].get_position()
x_curr, y_curr = adjustable_texts[i].get_position()
if abs(x_curr - x_prev) < 0.2 and abs(y_curr - y_prev) < 0.2:
adjustable_texts[i].set_position((x_curr, y_prev + 0.4))
# apply adjust_text only to adjustable labels
if getattr(cfg, "use_adjust_text", True) and cfg.adjust and adjustable_texts:
with contextlib.suppress(Exception):
adjust_text(
adjustable_texts,
x=[p[0] for p in adjustable_points],
y=[p[1] for p in adjustable_points],
ax=ax,
expand=(2.5, 2.5),
force_text=(1.2, 1.5),
force_points=(0.01, 0.01),
autoalign="xy",
arrowprops=None,
lim=30000,
ensure_inside_axes=True,
)
# Draw connector lines from adjustable points to their text bboxes.
# Forced texts already received straight connectors at placement time.
try:
fig.canvas.draw()
renderer = fig.canvas.get_renderer()
except Exception:
renderer = None
for txt, orig, was_sig in zip(
adjustable_texts, adjustable_points, adjustable_point_sigs, strict=True
):
try:
tx, ty = txt.get_position()
ox, oy = orig
if math.hypot(tx - ox, ty - oy) <= 1e-8:
continue
if renderer is None:
# Attach to marker edge when renderer isn't available; aim
# toward the text position (display coords of tx/ty)
try:
if getattr(cfg, "attach_to_marker_edge", True):
label_disp = ax.transData.transform((tx, ty))
attach_x, attach_y = _marker_edge_data_point(ox, oy, label_disp)
else:
attach_x, attach_y = ox, oy
except Exception:
attach_x, attach_y = ox, oy
try:
conn_color = _select_connector_color(was_sig, ox)
ax.plot(
[attach_x, tx],
[attach_y, ty],
color=conn_color,
linewidth=cfg.connector_width,
alpha=0.8,
zorder=3.5,
)
except Exception:
pass
continue
bbox = txt.get_window_extent(renderer=renderer)
# Convert data point to display coords
point_disp = ax.transData.transform((ox, oy))
# Determine which horizontal edge of the text bbox is closer
# to the point (left or right) and connect to that edge.
left_edge_x = bbox.x0
right_edge_x = bbox.x1
# If point is left of text, attach to left edge; if right, to right edge;
# if inside horizontally, attach to nearest edge.
if point_disp[0] <= left_edge_x:
attach_x = left_edge_x
elif point_disp[0] >= right_edge_x:
attach_x = right_edge_x
else:
# inside horizontally -> choose nearest edge
attach_x = (
left_edge_x
if (point_disp[0] - left_edge_x) < (right_edge_x - point_disp[0])
else right_edge_x
)
attach_y = bbox.y0 + bbox.height / 2.0
attach_data = ax.transData.inverted().transform((attach_x, attach_y))
# compute marker-edge attach point in data coords
try:
if getattr(cfg, "attach_to_marker_edge", True):
# Aim the marker-edge attach point toward the text bbox attach
# display coordinate (attach_x, attach_y) computed above.
label_disp = (attach_x, attach_y)
attach_marker_x, attach_marker_y = _marker_edge_data_point(ox, oy, label_disp)
else:
attach_marker_x, attach_marker_y = ox, oy
except Exception:
attach_marker_x, attach_marker_y = ox, oy
try:
conn_color = _select_connector_color(was_sig, ox)
ax.plot(
[attach_marker_x, attach_data[0]],
[attach_marker_y, attach_data[1]],
color=conn_color,
linewidth=cfg.connector_width,
alpha=0.8,
zorder=3.5,
)
except Exception:
pass
except Exception:
pass
xs = [t.get_position()[0] for t in all_texts]
ys = [t.get_position()[1] for t in all_texts]
if xs and ys:
max_x = max(abs(min(xs)), abs(max(xs)), x_limit)
ax.set_xlim(-max_x - 0.5, max_x + 0.5)
max_y = max(max(ys), y_limit)
ax.set_ylim(bottom=-0.5, top=max_y + 0.5)
# Expand axis limits slightly by the marker radius (converted from
# display pixels to data units) so large markers near the plot edge are
# not visually clipped when saving.
try:
# Only expand limits when the caller did not explicitly provide them.
do_pad_x = getattr(cfg, "xlim", None) is None
do_pad_y = getattr(cfg, "ylim", None) is None
# Respect the caller's preference via cfg.pad_by_marker
if (do_pad_x or do_pad_y) and getattr(cfg, "pad_by_marker", True):
r_points = math.sqrt(max(cfg.marker_size, 1.0)) / 2.0
r_pixels = r_points * fig.dpi / 72.0
# Convert pixel deltas to data-space deltas
zero_data = ax.transData.inverted().transform((0.0, 0.0))
dx_data = ax.transData.inverted().transform((r_pixels, 0.0))[0] - zero_data[0]
dy_data = ax.transData.inverted().transform((0.0, r_pixels))[1] - zero_data[1]
x0, x1 = ax.get_xlim()
y0, y1 = ax.get_ylim()
pad_x = abs(dx_data) + 0.05
pad_y = abs(dy_data) + 0.05
if do_pad_x:
ax.set_xlim(x0 - pad_x, x1 + pad_x)
if do_pad_y:
ax.set_ylim(y0 - pad_y, y1 + pad_y)
except Exception:
pass
# Respect explicit ticks from config if provided, otherwise keep existing logic
if getattr(cfg, "xticks", None) is not None:
with contextlib.suppress(Exception):
ax.set_xticks(list(cfg.xticks))
elif cfg.xtick_step is not None:
left = int(math.floor(ax.get_xlim()[0] / cfg.xtick_step) * cfg.xtick_step)
right = int(math.ceil(ax.get_xlim()[1] / cfg.xtick_step) * cfg.xtick_step)
ax.set_xticks(list(range(left, right + 1, int(cfg.xtick_step))))
if getattr(cfg, "yticks", None) is not None:
with contextlib.suppress(Exception):
ax.set_yticks(list(cfg.yticks))
plt.tight_layout()
return fig, ax
[docs]
class VolcanoPlotter:
"""Stateful, interactive wrapper around the functional API.
Mirrors the interaction pattern of `OncoPlotter`: the instance
exposes `.df` and `.config` attributes, and the constructor accepts
either `(df, config)` or `(config, df)` for backwards compatibility.
Rendering delegates to `plot_volcano` so the pure function remains
the canonical implementation.
"""
def __init__(self, df: pd.DataFrame, config: VolcanoConfig | dict):
"""Construct with `(df, config)` matching `OncoPlotter`.
`config` may be a `VolcanoConfig` or a dict understood by it.
"""
if isinstance(config, dict):
config = VolcanoConfig(**config)
self.df: pd.DataFrame = df.copy()
self.last_df: pd.DataFrame = self.df
self.config: VolcanoConfig = config
self.cfg: VolcanoConfig = config # backward alias
self.fig: plt.Figure | None = None
self.ax: plt.Axes | None = None
# history of explicit annotations added via .annotate()
self.annotation_history: list[dict] = []
# Data / rendering -------------------------------------------------
[docs]
def set_data(self, df: pd.DataFrame) -> VolcanoPlotter:
self.df = df.copy()
self.last_df = self.df
return self
[docs]
def plot(self, df: pd.DataFrame | None = None) -> tuple[plt.Figure, plt.Axes]:
"""Render the volcano. If `df` is provided, set it as the current data.
Returns the `(fig, ax)` produced by `plot_volcano` and stores them
on the instance.
"""
if df is not None:
self.set_data(df)
if self.df is None or self.config is None:
raise RuntimeError(
"Both dataframe and config are required; set them before calling .plot()"
)
# Delegate to the canonical function so behavior stays centralized
self.fig, self.ax = plot_volcano(self.config, self.df)
return self.fig, self.ax
[docs]
def save(self, path: str, **save_kwargs) -> None:
from pathlib import Path
if self.fig is None:
raise RuntimeError("No figure available; call .plot(df) first")
Path(path).parent.mkdir(parents=True, exist_ok=True)
self.fig.savefig(path, **save_kwargs)
# Convenience interactive operations --------------------------------
[docs]
def update_config(self, **kwargs) -> VolcanoPlotter:
"""Update configuration in-place and return self for chaining."""
for k, v in kwargs.items():
try:
setattr(self.cfg, k, v)
except Exception:
continue
return self
[docs]
def annotate(
self,
explicit_positions: Mapping[str, tuple[float, float]],
replace: bool = True,
) -> VolcanoPlotter:
"""Add explicit label placements and re-render.
`explicit_positions` should be a mapping label -> (x, y).
If `replace` is True the explicit placements replace auto labels
(cfg.explicit_label_replace=True). The placements are recorded in
`annotation_history` so callers can inspect what was added.
"""
try:
# normalize to dict[str, (x,y)]
new_map = {str(k): (float(v[0]), float(v[1])) for k, v in explicit_positions.items()}
except Exception:
raise ValueError("explicit_positions must be a mapping label->(x,y)") from None
# record
self.annotation_history.append({"explicit": new_map, "replace": replace})
# apply to config and re-render
try:
if self.config is None:
self.config = VolcanoConfig(
explicit_label_positions=new_map,
explicit_label_replace=bool(replace),
)
else:
self.config.explicit_label_positions = new_map
self.config.explicit_label_replace = bool(replace)
self.cfg = self.config
except Exception:
pass
# Replot with updated config
if self.df is not None:
self.plot(self.df)
return self
[docs]
def label_more(self, n: int = 10) -> VolcanoPlotter:
"""Convenience to expand `cfg.values_to_label` using the internal
resolver -- useful for interactive 'label more' flows.
"""
if self.df is None or self.config is None:
raise RuntimeError("No dataframe available; call .set_data(df) or .plot(df) first")
resolved = resolve_labels(self.df, self.config)
if not resolved:
return self
already = (
list(self.config.values_to_label)
if getattr(self.config, "values_to_label", None)
else []
)
to_add = [v for v in resolved if v not in already][:n]
new_vals = list(dict.fromkeys(already + to_add))
try:
self.config.values_to_label = new_vals
self.cfg = self.config
except Exception:
pass
# re-render
self.plot(self.df)
return self
# Serialization / utilities ----------------------------------------
[docs]
def to_dict(self) -> dict:
try:
return {
"cfg": self.config.model_dump() if self.config is not None else {},
"annotations": list(self.annotation_history),
}
except Exception:
return {"cfg": {}, "annotations": list(self.annotation_history)}
[docs]
@classmethod
def from_dict(cls, data: Mapping) -> VolcanoPlotter:
c = data.get("cfg", {})
vp = cls(c if isinstance(c, VolcanoConfig) else c)
# restore annotation history if present
ah = data.get("annotations", None)
if ah is not None:
try:
vp.annotation_history = list(ah)
except Exception:
vp.annotation_history = []
return vp
[docs]
def close(self) -> None:
try:
if self.fig is not None:
plt.close(self.fig)
finally:
self.fig = None
self.ax = None
[docs]
def resolve_labels(self) -> list[str]:
if self.last_df is None:
raise RuntimeError("No dataframe plotted yet; call .plot(df) first")
return resolve_labels(self.last_df, self.cfg)
def __enter__(self) -> VolcanoPlotter:
return self
def __exit__(self, exc_type, exc, tb) -> None:
self.close()