Source code for goatpy.analysis_gui

from __future__ import annotations

import warnings
import threading
import uuid
from typing import Optional, Callable

import numpy as np
import pandas as pd

# ── Qt ───────────────────────────────────────────────────────────────────────
from qtpy.QtWidgets import (
    QWidget, QVBoxLayout, QHBoxLayout, QLabel, QComboBox,
    QPushButton, QGroupBox, QSizePolicy, QScrollArea, QTabWidget,
    QCheckBox, QSpinBox, QDoubleSpinBox, QApplication, QProgressBar,
    QDialog, QDialogButtonBox, QFileDialog, QListWidget, QListWidgetItem,
    QPlainTextEdit, QTableWidget, QTableWidgetItem, QColorDialog,
)
from qtpy.QtCore import Qt, Signal, QTimer, QThread, QObject
from qtpy.QtGui import QCursor, QColor

# ── matplotlib ───────────────────────────────────────────────────────────────
import matplotlib
matplotlib.use("Qt5Agg")
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from matplotlib.backend_bases import MouseButton
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.ticker as mticker
from matplotlib.lines import Line2D

# ── napari ───────────────────────────────────────────────────────────────────
import napari
from napari.utils.notifications import show_info

# ── spatialdata ──────────────────────────────────────────────────────────────
from spatialdata import SpatialData


# ════════════════════════════════════════════════════════════════════════════
# Colour palette
# ════════════════════════════════════════════════════════════════════════════

# Colour tokens (hex RGB strings) shared by the Qt stylesheet and the
# matplotlib canvases of the GUI.
PALETTE = {
    "bg": "#1e1e2e",
    "surface": "#2a2a3e",
    "border": "#3d3d5c",
    "accent": "#7c6af7",
    "accent2": "#f5a623",
    "text": "#cdd6f4",
    "text_dim": "#6c7086",
    "spectrum": "#7c6af7",
    "raw_spec": "#4a5580",  # dimmer colour for the dense raw spectrum
    "peak_marker": "#f38ba8",
    "highlight": "#fab387",
    "success": "#a6e3a1",
}

# NOTE(review): module-level registry; its users are not visible in this part
# of the file — presumably keeps GUI objects referenced. Confirm before removing.
_GOATPY_REFS: dict = {}

# Matplotlib colormap names offered for continuous (numeric) colouring.
CONTINUOUS_CMAPS = [
    "inferno",
    "magma",
    "plasma",
    "viridis",
    "hot",
    "RdBu_r",
    "coolwarm",
    "turbo",
    "gray",
]

# Matplotlib colormap names offered for categorical colouring.
CATEGORICAL_CMAPS = [
    "tab10",
    "tab20",
    "Set1",
    "Set2",
    "Set3",
    "Paired",
    "hsv",
    "Pastel1",
]

# Qt stylesheet built from PALETTE. The string content is kept exactly as found
# (single-line form); Qt style sheets are whitespace-insensitive.
BASE_STYLE = f""" QWidget {{ background-color: {PALETTE['bg']}; color: {PALETTE['text']}; font-family: 'Segoe UI', 'SF Pro Display', sans-serif; font-size: 11px; }} QGroupBox {{ background-color: {PALETTE['surface']}; border: 1px solid {PALETTE['border']}; border-radius: 6px; margin-top: 8px; padding-top: 8px; font-weight: bold; }} QGroupBox::title {{ subcontrol-origin: margin; left: 8px; padding: 0 4px; color: {PALETTE['accent']}; }} QComboBox {{ background-color: {PALETTE['surface']}; border: 1px solid {PALETTE['border']}; border-radius: 4px; padding: 4px 8px; color: {PALETTE['text']}; }} QComboBox::drop-down {{ border: none; }} QComboBox QAbstractItemView {{ background-color: {PALETTE['surface']}; color: {PALETTE['text']}; selection-background-color: {PALETTE['accent']}; }} QPushButton {{ background-color: {PALETTE['accent']}; color: white; border: none; border-radius: 4px; padding: 6px 12px; font-weight: bold; }} QPushButton:hover {{ background-color: #9b8df8; }} QPushButton:pressed {{ background-color: #5d4ed6; }} QPushButton:disabled {{ background-color: {PALETTE['border']}; color: {PALETTE['text_dim']}; }} QLabel {{ color: {PALETTE['text']}; }} QTabWidget::pane {{ border: 1px solid {PALETTE['border']}; background-color: {PALETTE['surface']}; border-radius: 4px; }} QTabBar::tab {{ background-color: {PALETTE['bg']}; color: {PALETTE['text_dim']}; padding: 6px 14px; border: 1px solid {PALETTE['border']}; border-bottom: none; border-top-left-radius: 4px; border-top-right-radius: 4px; }} QTabBar::tab:selected {{ background-color: {PALETTE['surface']}; color: {PALETTE['text']}; }} QCheckBox {{ color: {PALETTE['text']}; spacing: 6px; }} QScrollArea {{ border: none; background-color: transparent; }} QSpinBox, QDoubleSpinBox {{ background-color: {PALETTE['surface']}; border: 1px solid {PALETTE['border']}; border-radius: 4px; padding: 3px 6px; color: {PALETTE['text']}; }} QProgressBar {{ background-color: {PALETTE['surface']}; border: 1px solid 
{PALETTE['border']}; border-radius: 3px; text-align: center; color: {PALETTE['text']}; }} QProgressBar::chunk {{ background-color: {PALETTE['accent']}; border-radius: 3px; }} """
# ════════════════════════════════════════════════════════════════════════════
# SpatialData → Napari loader (FIXED — uses napari-spatialdata properly)
# ════════════════════════════════════════════════════════════════════════════

def _resolve_element_cs(sdata, element_type: str, name: str,
                        preferred_cs: str = "aligned") -> str:
    """Pick the best coordinate system for an element (prefer aligned, else global).

    Parameters
    ----------
    sdata
        Object exposing the element containers as attributes
        (``getattr(sdata, element_type)[name]``).
    element_type
        Name of the container attribute, e.g. ``"images"`` or ``"shapes"``.
    name
        Key of the element inside that container.
    preferred_cs
        Coordinate system returned whenever the element declares it.

    Returns
    -------
    str
        ``preferred_cs`` if available, else ``"global"``, else the first
        declared coordinate system, else ``preferred_cs`` as a last resort.
    """
    from spatialdata.transformations import get_transformation

    element = getattr(sdata, element_type)[name]
    try:
        transforms = get_transformation(element, get_all=True)
        cs_names = list(transforms.keys()) if isinstance(transforms, dict) else []
    except Exception:
        # Best effort: an element whose transformations cannot be read is
        # treated as declaring no coordinate system.
        cs_names = []

    if preferred_cs in cs_names:
        return preferred_cs
    if "global" in cs_names:
        return "global"
    if cs_names:
        return cs_names[0]
    return preferred_cs


def _add_spatialdata_layers(viewer, sdata, target_cs="aligned"):
    """
    Load every SpatialData element into the viewer via napari-spatialdata.

    Each element is added in its own best coordinate system so shapes like
    'pixels' and 'Annotations' are all available even when they differ.

    Parameters
    ----------
    viewer
        NOTE(review): not used inside this function — the ``Interactive``
        object is created from ``sdata`` alone. Confirm whether it is meant
        to be attached to this viewer.
    sdata
        SpatialData object whose images, labels, shapes and points are added.
    target_cs
        Coordinate system preferred per element and switched to at the end.

    Returns
    -------
    The ``napari_spatialdata.Interactive`` instance that was created.
    """
    from napari_spatialdata import Interactive

    interactive = Interactive(sdata, headless=True)

    for element_type in ("images", "labels", "shapes", "points"):
        container = getattr(sdata, element_type, {})
        for name in container:
            cs = _resolve_element_cs(sdata, element_type, name, target_cs)
            try:
                interactive.add_element(
                    element=name,
                    element_coordinate_system=cs,
                    view_element_system=False,
                )
                print(f"[goatpy GUI] Added '{name}' ({element_type}) in '{cs}'")
            except Exception as e:
                # One element failing to load must not prevent the others.
                print(f"[goatpy GUI] Could not add '{name}' ({element_type}): {e}")

    try:
        interactive.switch_coordinate_system(target_cs)
    except Exception as e:
        print(f"[goatpy GUI] Could not switch coordinate system: {e}")

    return interactive


# ════════════════════════════════════════════════════════════════════════════
# m/z resolution helpers
# ════════════════════════════════════════════════════════════════════════════

def _strip_mz_prefix(name) -> str:
    """Return ``str(name)`` stripped of whitespace and a leading 'mz-' (any case)."""
    text = str(name).strip()
    if text.lower().startswith("mz-"):
        return text[3:]
    return text


def _resolve_mz_array(adata) -> np.ndarray:
    """
    Return float64 m/z array for every adata column, handling:

    1. adata.var["mz_original"] (set by annotate_glycans)
    2. "mz-933.4" prefixed var_names
    3. plain numeric var_names
    4. column index fallback

    ``mz_original`` is only used when every entry converts to a finite
    number; otherwise the var_names are parsed one by one. Any var_name that
    does not parse to a finite float (including "nan"/"inf") receives its
    column index as a placeholder m/z, so the result never contains NaN or
    infinity.
    """
    n = adata.n_vars

    if "mz_original" in adata.var.columns:
        try:
            arr = pd.to_numeric(adata.var["mz_original"], errors="coerce").to_numpy(
                dtype=np.float64, na_value=np.nan
            )
        except Exception:
            arr = None
        if arr is not None and np.all(np.isfinite(arr)):
            return arr

    mzs = np.full(n, np.nan, dtype=np.float64)
    for i, vn in enumerate(adata.var_names):
        try:
            mzs[i] = float(_strip_mz_prefix(vn))
        except ValueError:
            pass  # left as NaN, replaced by the index fallback below

    unresolved = ~np.isfinite(mzs)
    if unresolved.any():
        mzs[unresolved] = np.flatnonzero(unresolved).astype(np.float64)
    return mzs


def _resolve_var_display_labels(adata) -> list[str]:
    """Return one display label per adata column.

    Columns named like "mz-933.4" are shown as the resolved m/z with four
    decimals; every other column keeps its (stripped) var_name.
    """
    mzs = _resolve_mz_array(adata)
    labels = []
    for mz, vn in zip(mzs, adata.var_names):
        s = str(vn).strip()
        labels.append(f"{mz:.4f}" if s.lower().startswith("mz-") else s)
    return labels


def _looks_numeric(s: str) -> bool:
    """Return True if ``s`` (optionally 'mz-' prefixed) parses as a float."""
    try:
        float(_strip_mz_prefix(s))
    except ValueError:
        return False
    return True


# ════════════════════════════════════════════════════════════════════════════
# Napari layer helper — render glycan ion image ON the H&E
# ════════════════════════════════════════════════════════════════════════════
# Layer name reserved for the glycan ion map.
# NOTE(review): no use of this constant is visible in this part of the file.
GLYCAN_LAYER_NAME = "glycan_ion_map"

# Key of the single napari layer property that carries the values used to
# colour a Shapes layer.
GOATPY_VIZ_KEY = "_goatpy_viz"
def _find_shapes_layer(viewer, shapes_name: str):
    """Match a napari Shapes layer by element name (handles napari-spatialdata suffixes).

    Candidates are ranked so that a loose match can never shadow a better
    one: an exact name wins, then ``"<name> ["`` / ``"<name>:"`` prefixes,
    and only then a plain substring match. Returns ``None`` if nothing
    matches.
    """
    from napari.layers import Shapes

    candidates = [lyr for lyr in viewer.layers if isinstance(lyr, Shapes)]

    for lyr in candidates:
        if lyr.name == shapes_name:
            return lyr

    prefixes = (shapes_name + " [", shapes_name + ":")
    for lyr in candidates:
        if lyr.name.startswith(prefixes):
            return lyr

    for lyr in candidates:
        if shapes_name in lyr.name:
            return lyr
    return None


def _is_obs_categorical(adata, col: str) -> bool:
    """Return True if ``adata.obs[col]`` should be coloured as categories.

    Delegates to :func:`_is_series_categorical`; the obs column has one row
    per observation, so both apply the same thresholds.
    """
    return _is_series_categorical(adata.obs[col])


def _is_series_categorical(series) -> bool:
    """Heuristically decide whether a column holds categories.

    Rules, in order: ``category`` / ``bool`` dtype → categorical; integer
    dtype with at most 64 distinct values → categorical; other numeric
    dtypes need at most ``min(16, max(1, len // 20))`` distinct values;
    non-numeric columns need at most ``min(50, max(1, len // 20))``.
    """
    if series.dtype.name == "category" or series.dtype == bool:
        return True
    n_unique = pd.Series(series).nunique(dropna=True)
    if pd.api.types.is_integer_dtype(series) and n_unique <= 64:
        return True
    per_twenty = max(1, len(series) // 20)
    if pd.api.types.is_numeric_dtype(series):
        return n_unique <= min(16, per_twenty)
    return n_unique <= min(50, per_twenty)


def _shape_to_dataframe(shapes):
    """Return ``shapes`` as a dataframe-like object, or ``None`` if impossible.

    Objects that already expose ``columns`` are returned unchanged; otherwise
    ``to_dataframe()`` and then ``compute()`` are tried, in that order.
    """
    if shapes is None:
        return None
    if hasattr(shapes, "columns"):
        return shapes
    if hasattr(shapes, "to_dataframe"):
        return shapes.to_dataframe()
    if hasattr(shapes, "compute"):
        return shapes.compute()
    return None


def _normalize_shapes_rgba(colours):
    """Convert per-shape colours to a float32 array with an alpha column.

    Input that is not a 2-D array with at least three columns yields an
    all-zero ``(n, 4)`` array. Values whose maximum exceeds 1.1 are treated
    as 0-255 and rescaled to 0-1. Three-column input gets an opaque alpha
    column appended. An empty ``(0, 3)`` input returns an empty ``(0, 4)``
    array instead of raising.
    """
    arr = np.asarray(colours, dtype=np.float32)
    if arr.ndim != 2 or arr.shape[1] < 3:
        n_rows = arr.shape[0] if arr.ndim else 0
        return np.zeros((n_rows, 4), dtype=np.float32)
    # ``max`` of an empty array raises, so only rescale when there is data.
    if arr.size and arr.max() > 1.1:
        arr = arr / 255.0
    if arr.shape[1] == 3:
        alpha = np.ones((arr.shape[0], 1), dtype=arr.dtype)
        arr = np.concatenate([arr, alpha], axis=1)
    return arr


def _rgb_to_hex(rgb) -> str:
    """Convert [r,g,b] (0-255 or 0-1) to '#rrggbb'.

    A triple whose maximum is <= 1.0 is interpreted as 0-1 floats; values are
    clipped to 0-255 and truncated to integers.
    """
    arr = np.asarray(rgb, dtype=np.float64).ravel()[:3]
    if arr.max() <= 1.0:
        arr = arr * 255.0
    arr = np.clip(arr, 0, 255).astype(int)
    return "#{:02x}{:02x}{:02x}".format(arr[0], arr[1], arr[2])


def _hex_to_rgb(hex_str: str) -> list[int]:
    """Convert '#rrggbb' (or 'rrggbb') to [r, g, b] ints (0-255).

    Surrounding whitespace is ignored. Anything that is not six hex digits
    yields mid-grey ``[128, 128, 128]``.
    """
    s = str(hex_str).strip().lstrip("#")
    if len(s) != 6:
        return [128, 128, 128]
    try:
        return [int(s[0:2], 16), int(s[2:4], 16), int(s[4:6], 16)]
    except ValueError:
        return [128, 128, 128]


def _safe_is_series_categorical(series) -> bool:
    """Like _is_series_categorical but tolerant of unhashable / malformed columns."""
    try:
        # Reject columns whose values aren't hashable (e.g. lists/arrays of RGB)
        sample = series.iloc[0] if len(series) else None
        if isinstance(sample, (list, tuple, np.ndarray)):
            return False
        return _is_series_categorical(series)
    except Exception:
        # Any failure while inspecting the column means "not categorical".
        return False


def _apply_direct_shapes_colors(layer, colours):
    """Paint each shape of ``layer`` with an explicit colour and hide its edges."""
    rgba = _normalize_shapes_rgba(colours)
    layer.face_color = rgba
    layer.face_color_mode = "direct"
    layer.edge_color = "transparent"
    layer.edge_width = 0.0
    layer.refresh_colors(update_color_mapping=True)
    layer.refresh()


def _categorical_cycle_colors(colormap: str, n: int) -> np.ndarray:
    """Build a colour cycle of ``n`` rows from a matplotlib colormap.

    Colormaps that expose a ``colors`` list contribute their entries in
    order, repeated when ``n`` exceeds the list length (rows keep the width
    the colormap stores, RGB or RGBA). Other colormaps are sampled at ``n``
    evenly spaced positions and give RGBA rows. ``n <= 0`` returns an empty
    ``(0, 4)`` array.
    """
    if n <= 0:
        return np.zeros((0, 4), dtype=float)

    cmap = plt.get_cmap(colormap)
    listed = getattr(cmap, "colors", None)
    # ``len`` instead of truthiness: ``colors`` may be a numpy array.
    if listed is not None and len(listed) > 0:
        base = np.array(listed)
        if len(base) < n:
            repeats = int(np.ceil(n / len(base)))
            return np.tile(base, (repeats, 1))[:n]
        return base[:n]

    return np.array([cmap(i / max(n - 1, 1)) for i in range(n)])


def _apply_shapes_colormap(layer, colormap: str, categorical: bool, n_categories: int = 0):
    """Apply colormap / color-cycle to an already-mapped shapes layer and refresh."""
    if categorical:
        n = max(n_categories, 1)
        layer.face_color_cycle = _categorical_cycle_colors(colormap, n)
        layer.face_color_mode = "cycle"
    else:
        layer.face_colormap = colormap
        layer.face_color_mode = "colormap"
    layer.refresh_colors(update_color_mapping=True)
    layer.refresh()


def _draw_spatial_legend(
    canvas: MplCanvas,
    title: str,
    *,
    categorical: bool,
    colormap: str,
    categories: Optional[list] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
):
    """Draw a matplotlib legend (categorical) or colorbar (continuous) in the sidebar.

    The figure of ``canvas`` is cleared first. With ``categorical`` and a
    non-empty ``categories`` list a patch legend is drawn; otherwise, when
    both ``vmin`` and ``vmax`` are given, a horizontal colorbar; otherwise a
    "No legend" placeholder.
    """
    canvas.fig.clear()
    ax = canvas.fig.add_subplot(111)
    canvas._style_ax(ax)
    ax.set_axis_off()

    if categorical and categories:
        colors = _categorical_cycle_colors(colormap, len(categories))
        handles = [
            mpatches.Patch(
                facecolor=tuple(np.asarray(colors[i]).ravel()),
                edgecolor=PALETTE["border"],
                label=str(cat),
            )
            for i, cat in enumerate(categories)
        ]
        ax.legend(
            handles=handles,
            title=title,
            loc="center",
            fontsize=7,
            title_fontsize=8,
            framealpha=0.35,
            facecolor=PALETTE["surface"],
            edgecolor=PALETTE["border"],
            labelcolor=PALETTE["text"],
        )
    elif vmin is not None and vmax is not None:
        sm = plt.cm.ScalarMappable(
            cmap=colormap,
            norm=plt.Normalize(vmin=vmin, vmax=vmax),
        )
        sm.set_array([])
        cb = canvas.fig.colorbar(
            sm,
            ax=ax,
            fraction=0.85,
            pad=0.02,
            orientation="horizontal",
        )
        cb.ax.xaxis.set_tick_params(color=PALETTE["text_dim"], labelsize=7)
        cb.outline.set_edgecolor(PALETTE["border"])
        plt.setp(cb.ax.xaxis.get_ticklabels(), color=PALETTE["text"])
        cb.set_label(title, color=PALETTE["text"], fontsize=8)
    else:
        ax.text(0.5, 0.5, "No legend", ha="center", va="center",
                transform=ax.transAxes, color=PALETTE["text_dim"], fontsize=8)

    canvas.fig.tight_layout(pad=0.3)
    canvas.draw()


def _render_values_on_shapes(
    viewer,
    values,
    colormap: str,
    categorical: bool,
    label: str,
    shapes_name: str = "pixels",
) -> Optional[dict]:
    """
    Colour a shapes layer by per-instance values.

    Returns render metadata for legend updates, or None on failure.

    Failure means no matching Shapes layer exists or the number of values
    differs from the number of shapes; both are reported via ``show_info``.
    Continuous values are clipped to their 1st-99th percentile range. When
    that range is degenerate (nearly all values identical) it is widened to
    the full finite range, and finally to ``vmin + 1``, so the contrast
    limits always have a non-zero width.
    """
    layer = _find_shapes_layer(viewer, shapes_name)
    if layer is None:
        from napari.layers import Shapes

        available = [lyr.name for lyr in viewer.layers if isinstance(lyr, Shapes)]
        show_info(
            f"No Shapes layer matching '{shapes_name}' found. "
            f"Available Shapes layers: {available}"
        )
        return None

    values = np.asarray(values).ravel()
    if len(layer.data) != len(values):
        show_info(
            f"Shape count ({len(layer.data)}) ≠ value count ({len(values)}). "
            "Check obs ordering / region linkage."
        )
        return None

    if categorical:
        props = np.asarray(values, dtype=object).astype(str)
        categories = list(np.unique(props))
        n_categories = len(categories)
        vmin = vmax = None
    else:
        props = pd.to_numeric(values, errors="coerce").astype(np.float32)
        categories = None
        n_categories = 0
        valid = props[np.isfinite(props)]
        if len(valid):
            vmin = float(np.percentile(valid, 1))
            vmax = float(np.percentile(valid, 99))
            if vmax <= vmin:
                # Sparse data (e.g. mostly zeros): percentiles collapse.
                vmin, vmax = float(valid.min()), float(valid.max())
            if vmax <= vmin:
                vmax = vmin + 1.0
        else:
            vmin, vmax = 0.0, 1.0

    # Single internal property key avoids stale napari color-cycle / colormap caches.
    layer.properties = {GOATPY_VIZ_KEY: props}
    layer.opacity = 0.75

    # Set colormapping parameters first, before assigning the face_color property
    if categorical:
        layer.face_color_mode = "cycle"
        layer.face_color_cycle = _categorical_cycle_colors(colormap, n_categories)
        layer.face_contrast_limits = None
    else:
        layer.face_color_mode = "colormap"
        layer.face_colormap = colormap
        layer.face_contrast_limits = (vmin, vmax)

    # Now assign the property name for coloring
    layer.face_color = GOATPY_VIZ_KEY

    # Ensure edge color doesn't obscure the face coloring
    layer.edge_color = "transparent"
    layer.edge_width = 0.0

    layer.refresh_colors(update_color_mapping=True)
    layer.refresh()
    show_info(f"Layer updated: {label}")

    return {
        "categorical": categorical,
        "colormap": colormap,
        "categories": categories,
        "vmin": vmin,
        "vmax": vmax,
        "shapes_name": shapes_name,
        "label": label,
    }


# ════════════════════════════════════════════════════════════════════════════
# Ion-image helper — compute per-pixel intensities for an arbitrary m/z and
# render them on the existing "pixels" Shapes layer (guaranteed alignment
# with the H&E, since it reuses the same registered shapes/transform).
# ════════════════════════════════════════════════════════════════════════════

def _compute_unregistered_ion_values(sdata, imzml_path: str, target_mz: float,
                                     tol_da: float, table_name: str = "maldi_adata"):
    """
    For an arbitrary m/z (+/- tol_da), compute a per-pixel intensity value
    ordered to match adata.obs / the "pixels" shapes layer.

    Mirrors the coordinate convention used in glyco_spatialdata():
        full_x, full_y = p.coordinates[:, :2] - 1   (0-based)
        x = full_x - full_x.min(); y = full_y - full_y.min()
    Each spectrum index i then corresponds to obs row with the same (x, y).

    Returns a float32 array with one value per obs row: the summed intensity
    of the spectrum at that (x, y) within ``[target_mz - tol_da,
    target_mz + tol_da]``, or 0.0 when no spectrum exists at that position.
    The imzML parser is closed before returning.
    """
    from pyimzml.ImzMLParser import ImzMLParser

    lo, hi = target_mz - tol_da, target_mz + tol_da

    # The parser keeps the binary data file open; the context manager
    # releases it even if reading a spectrum fails.
    with ImzMLParser(imzml_path) as p:
        coords = np.array(p.coordinates)[:, :2].astype(np.int64) - 1  # 0-based (full_x, full_y)
        x = coords[:, 0] - coords[:, 0].min()
        y = coords[:, 1] - coords[:, 1].min()

        # Per-spectrum summed intensity within [target_mz - tol, target_mz + tol]
        n = len(p.coordinates)
        per_spectrum = np.zeros(n, dtype=np.float64)
        for i in range(n):
            mzs, ints = p.getspectrum(i)
            mzs = np.asarray(mzs)
            mask = (mzs >= lo) & (mzs <= hi)
            if mask.any():
                per_spectrum[i] = np.asarray(ints, dtype=np.float64)[mask].sum()

    # Map (x, y) -> intensity; for duplicated coordinates the last spectrum wins.
    xy_to_val: dict[tuple[int, int], float] = dict(
        zip(zip(x.tolist(), y.tolist()), per_spectrum.tolist())
    )

    adata = sdata.tables[table_name]
    obs_x = adata.obs["x"].to_numpy(dtype=np.int64)
    obs_y = adata.obs["y"].to_numpy(dtype=np.int64)
    values = np.array(
        [xy_to_val.get((int(xi), int(yi)), 0.0) for xi, yi in zip(obs_x, obs_y)],
        dtype=np.float32,
    )
    return values


def _display_unregistered_ion_image_on_shapes(
    viewer,
    sdata,
    imzml_path: str,
    target_mz: float,
    tol_da: float,
    label: str,
    table_name: str = "maldi_adata",
    shapes_name: str = "pixels",
    colormap: str = "inferno",
):
    """
    Compute per-pixel intensities for an arbitrary m/z (+/- tol) and render
    them on the existing 'pixels' Shapes layer — same mechanism used for
    curated glycan ion maps, so alignment with the H&E is guaranteed.

    Returns the render metadata from ``_render_values_on_shapes``, or
    ``None`` if the intensities could not be computed or rendered.
    """
    try:
        values = _compute_unregistered_ion_values(
            sdata, imzml_path, target_mz, tol_da, table_name=table_name
        )
    except Exception as e:
        show_info(f"Could not compute ion image: {e}")
        return None

    state = _render_values_on_shapes(
        viewer, values, colormap, False,
        f"{label} (\u00b1{tol_da:.4f})", shapes_name,
    )
    if state is not None:
        show_info(f"Ion image displayed for {label} (m/z {target_mz:.4f} \u00b1 {tol_da:.4f})")
    return state


def _render_glycan_on_viewer(
    viewer,
    sdata,
    peak_mz,
    label,
    table_name="maldi_adata",
    shapes_name="pixels",
    colormap: str = "inferno",
):
    """Colour the shapes layer by the table column closest to ``peak_mz``.

    The nearest column (by resolved m/z) is used if it lies within 0.5 of
    ``peak_mz``; otherwise a "not found" notice is shown and ``None`` is
    returned. Sparse ``X`` columns are densified before rendering.
    """
    adata = sdata.tables[table_name]
    var_mzs = _resolve_mz_array(adata)
    if var_mzs.size == 0:
        # ``argmin`` of an empty array raises; report it like any other miss.
        show_info(f"Peak {peak_mz:.2f} not found")
        return None

    idx = int(np.argmin(np.abs(var_mzs - peak_mz)))
    if abs(var_mzs[idx] - peak_mz) > 0.5:
        show_info(f"Peak {peak_mz:.2f} not found")
        return None

    column = adata.X[:, idx]
    if hasattr(column, "toarray"):
        # scipy sparse slices do not convert through ``np.asarray``.
        column = column.toarray()
    values = np.asarray(column).astype(np.float32).ravel()
    return _render_values_on_shapes(
        viewer, values, colormap, False,
        f"{label} ({peak_mz:.2f})", shapes_name,
    )


class _SpectrumLoader(QObject):
    """
    Loads the mean spectrum from the raw imzML file on a worker thread,
    then emits finished(mz_array, intensity_array).

    On any error, ``finished`` is emitted with two empty arrays.
    """

    finished = Signal(object, object)  # np.ndarray, np.ndarray
    progress = Signal(int)  # 0-100

    def __init__(self, imzml_path: str, n_sample: int = 500):
        super().__init__()
        self.imzml_path = imzml_path
        self.n_sample = n_sample  # max spectra to average (for speed)

    def run(self):
        """Average a subsample of spectra onto a uniform 0.05 Da grid.

        Every ``n_total // n_sample``-th spectrum is read. The first pass
        finds the overall m/z range (0-50 % progress), the second bins the
        intensities (50-100 %). Each bin holds the mean of the data points
        that fell into it; bins without points are 0.
        """
        try:
            from pyimzml.ImzMLParser import ImzMLParser

            # Context manager closes the parser's data file when done.
            with ImzMLParser(self.imzml_path) as p:
                n_total = len(p.coordinates)
                step = max(1, n_total // max(1, self.n_sample))
                indices = list(range(0, n_total, step))
                tick = max(1, len(indices) // 20)

                # First pass: discover global m/z range. Only the extremes
                # are kept, not the arrays, to bound memory use.
                lo = hi = None
                for i, idx in enumerate(indices):
                    mzs, _ = p.getspectrum(idx)
                    if len(mzs):
                        # Uses the first/last entry as min/max — assumes each
                        # spectrum's m/z array is ascending.
                        lo = mzs[0] if lo is None else min(lo, mzs[0])
                        hi = mzs[-1] if hi is None else max(hi, mzs[-1])
                    if i % tick == 0:
                        self.progress.emit(int(i / len(indices) * 50))
                if lo is None:
                    raise ValueError("no non-empty spectra found")

                # Build uniform 0.05 Da grid
                bin_edges = np.arange(lo, hi + 0.05, 0.05)
                acc = np.zeros(len(bin_edges) - 1, dtype=np.float64)
                counts = np.zeros(len(bin_edges) - 1, dtype=np.int32)

                # Second pass: accumulate intensities per bin.
                for i, idx in enumerate(indices):
                    mzs, ints = p.getspectrum(idx)
                    if len(mzs) == 0:
                        continue
                    bin_idx = np.searchsorted(bin_edges, mzs, side="right") - 1
                    valid = (bin_idx >= 0) & (bin_idx < len(acc))
                    np.add.at(acc, bin_idx[valid],
                              np.asarray(ints, dtype=np.float64)[valid])
                    np.add.at(counts, bin_idx[valid], 1)
                    if i % tick == 0:
                        self.progress.emit(50 + int(i / len(indices) * 50))

            with np.errstate(divide="ignore", invalid="ignore"):
                mean_spec = np.where(counts > 0, acc / counts, 0.0)
            bin_centres = (bin_edges[:-1] + bin_edges[1:]) / 2.0

            self.progress.emit(100)
            self.finished.emit(bin_centres, mean_spec)
        except Exception as e:
            # Worker boundary: report and signal failure with empty arrays so
            # the GUI can fall back instead of hanging.
            print(f"[SpectrumLoader] Error: {e}")
            self.finished.emit(np.array([]), np.array([]))


# ════════════════════════════════════════════════════════════════════════════
# MplCanvas
# ════════════════════════════════════════════════════════════════════════════
class MplCanvas(FigureCanvas):
    """Matplotlib figure canvas styled with the GUI palette.

    The figure is available as ``self.fig``. The canvas expands in both
    directions and is never shorter than 180 px.
    """

    def __init__(self, parent=None, width=8, height=3, dpi=90):
        """Create the figure (``width`` x ``height`` inches at ``dpi``) and attach it to ``parent``."""
        # The figure must exist before the FigureCanvas initialiser runs,
        # because it is passed to it.
        self.fig = Figure(figsize=(width, height), dpi=dpi)
        self.fig.patch.set_facecolor(PALETTE["surface"])
        super().__init__(self.fig)
        self.setParent(parent)
        self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        self.setMinimumHeight(180)

    def _style_ax(self, ax):
        """Apply the palette colours to ``ax`` (background, spines, ticks, labels, title) and return it."""
        ax.set_facecolor(PALETTE["bg"])
        for spine in ax.spines.values():
            spine.set_color(PALETTE["border"])
        ax.tick_params(colors=PALETTE["text_dim"], labelsize=8)
        ax.xaxis.label.set_color(PALETTE["text"])
        ax.yaxis.label.set_color(PALETTE["text"])
        ax.title.set_color(PALETTE["text"])
        return ax
# ════════════════════════════════════════════════════════════════════════════ # 1. SPECTRUM WIDGET (bottom dock) # • Loads full raw spectrum from imzML in background # • Scroll = pan, Ctrl+scroll = zoom # • Click near a red peak line → selects that glycan # ════════════════════════════════════════════════════════════════════════════
[docs] class SpectrumWidget(QWidget): """Interactive spectrum panel — full imzML background + clickable peaks.""" # Emitted when user clicks a peak; consumed by AnalysisSidebar + viewer
[docs] peak_clicked = Signal(float, str) # mz, display_label
# Emitted when user requests "Display spatially" for an unregistered peak
[docs] unregistered_peak_display = Signal(float, float) # mz, tolerance (Da)
def __init__( self, sdata: SpatialData, peaks: list[float], glycan_df: Optional[pd.DataFrame] = None, table_name: str = "maldi_adata", applied_tolerance: float = 0.1, parent=None, ): super().__init__(parent)
[docs] self.sdata = sdata
[docs] self.peaks = sorted(peaks)
[docs] self.glycan_df = glycan_df
[docs] self.table_name = table_name
# ── Tolerance originally applied when peaks were extracted into the # raw SpatialData object (glyco_spatialdata's `tol` argument) ──
[docs] self.applied_tolerance: float = applied_tolerance
try: uns_tol = self.sdata.tables[self.table_name].uns.get("maldi_tolerance") if uns_tol is not None: self.applied_tolerance = float(uns_tol) except Exception: pass
[docs] self.highlighted_mz: Optional[float] = None
[docs] self.highlighted_label: str = ""
self._tol = 0.15 # ── Unregistered-peak picking mode ────────────────────────────── self._unreg_mode: bool = False self._unreg_mz: Optional[float] = None self._unreg_tol: float = 0.25 # ── m/z list (curated peaks + any added via "Add peak to list") ──
[docs] self.peak_list: list[float] = list(self.peaks)
# Raw full spectrum (from imzML) self._raw_mz: Optional[np.ndarray] = None self._raw_int: Optional[np.ndarray] = None # sdata mean spectrum (fallback / overlay reference) self._sdata_mz: Optional[np.ndarray] = None self._sdata_int: Optional[np.ndarray] = None # View window self._view_lo: float = 500.0 self._view_hi: float = 3000.0 # Peak label lookup mz → display string self._peak_labels: dict[float, str] = {} self._build_peak_labels() self._build_ui() QTimer.singleShot(100, self._load_sdata_spectrum) QTimer.singleShot(200, self._start_imzml_load) # ── Peak label lookup ───────────────────────────────────────────────── def _build_peak_labels(self): """Map each curated m/z to its best glycan name.""" try: adata = self.sdata.tables[self.table_name] mz_arr = _resolve_mz_array(adata) disp = _resolve_var_display_labels(adata) for mz, lbl in zip(mz_arr, disp): self._peak_labels[float(mz)] = lbl except Exception: pass if self.glycan_df is not None: for _, row in self.glycan_df.iterrows(): mz = float(row["mz"]) lbl = str(row["label"]) if lbl and lbl not in ("nan", ""): self._peak_labels[mz] = lbl def _label_for_peak(self, mz: float) -> str: best = min(self._peak_labels.keys(), key=lambda m: abs(m - mz), default=None) if best is not None and abs(best - mz) < 0.5: return self._peak_labels[best] return f"{mz:.4f}" # ── UI ──────────────────────────────────────────────────────────────── def _build_ui(self): layout = QVBoxLayout(self) layout.setContentsMargins(4, 4, 4, 4) layout.setSpacing(2) # ── Controls bar ───────────────────────────────────────────────── ctrl = QHBoxLayout() lbl = QLabel("Spectrum source:") lbl.setFixedWidth(110) ctrl.addWidget(lbl) self.source_combo = QComboBox() self.source_combo.addItems(["Raw imzML (full)", "sdata mean"]) self.source_combo.setFixedWidth(160) self.source_combo.currentTextChanged.connect(self._redraw) ctrl.addWidget(self.source_combo) ctrl.addSpacing(16) ctrl.addWidget(QLabel("Show peaks:")) self.show_peaks_cb = QCheckBox() 
self.show_peaks_cb.setChecked(True) self.show_peaks_cb.stateChanged.connect(self._redraw) ctrl.addWidget(self.show_peaks_cb) ctrl.addSpacing(16) ctrl.addWidget(QLabel("Zoom to:")) self.mz_lo = QDoubleSpinBox() self.mz_lo.setRange(50, 10000) self.mz_lo.setValue(500) self.mz_lo.setSingleStep(50) self.mz_lo.setFixedWidth(78) ctrl.addWidget(self.mz_lo) ctrl.addWidget(QLabel("–")) self.mz_hi = QDoubleSpinBox() self.mz_hi.setRange(50, 10000) self.mz_hi.setValue(3000) self.mz_hi.setSingleStep(50) self.mz_hi.setFixedWidth(78) ctrl.addWidget(self.mz_hi) zoom_btn = QPushButton("Go") zoom_btn.setFixedWidth(40) zoom_btn.clicked.connect(self._apply_zoom) ctrl.addWidget(zoom_btn) reset_btn = QPushButton("Reset") reset_btn.setFixedWidth(55) reset_btn.clicked.connect(self._reset_zoom) ctrl.addWidget(reset_btn) ctrl.addSpacing(16) self.show_applied_tol_cb = QCheckBox("Show applied tolerance") self.show_applied_tol_cb.setChecked(False) self.show_applied_tol_cb.stateChanged.connect(self._redraw) ctrl.addWidget(self.show_applied_tol_cb) self.applied_tol_spin = QDoubleSpinBox() self.applied_tol_spin.setRange(0.001, 50.0) self.applied_tol_spin.setDecimals(3) self.applied_tol_spin.setSingleStep(0.01) self.applied_tol_spin.setValue(self.applied_tolerance) self.applied_tol_spin.setFixedWidth(78) self.applied_tol_spin.valueChanged.connect(self._on_applied_tol_changed) ctrl.addWidget(self.applied_tol_spin) ctrl.addStretch() self.status_lbl = QLabel("Loading…") self.status_lbl.setStyleSheet(f"color: {PALETTE['text_dim']};") ctrl.addWidget(self.status_lbl) layout.addLayout(ctrl) # ── Unregistered peak controls ──────────────────────────────────── unreg = QHBoxLayout() self.unreg_btn = QPushButton("Check unregistered peak") self.unreg_btn.setCheckable(True) self.unreg_btn.toggled.connect(self._on_unreg_toggled) unreg.addWidget(self.unreg_btn) unreg.addSpacing(12) self.unreg_tol_lbl = QLabel("± Tolerance (Da):") unreg.addWidget(self.unreg_tol_lbl) self.unreg_tol_spin = QDoubleSpinBox() 
self.unreg_tol_spin.setRange(0.001, 50.0) self.unreg_tol_spin.setDecimals(3) self.unreg_tol_spin.setSingleStep(0.01) self.unreg_tol_spin.setValue(self._unreg_tol) self.unreg_tol_spin.setFixedWidth(90) self.unreg_tol_spin.valueChanged.connect(self._on_unreg_tol_changed) unreg.addWidget(self.unreg_tol_spin) unreg.addSpacing(12) self.unreg_selected_lbl = QLabel("No peak selected") self.unreg_selected_lbl.setStyleSheet(f"color: {PALETTE['text_dim']};") unreg.addWidget(self.unreg_selected_lbl) unreg.addStretch() self.display_spatially_btn = QPushButton("Display spatially") self.display_spatially_btn.setEnabled(False) self.display_spatially_btn.clicked.connect(self._on_display_spatially) unreg.addWidget(self.display_spatially_btn) self.add_peak_btn = QPushButton("Add peak to list") self.add_peak_btn.setEnabled(False) self.add_peak_btn.clicked.connect(self._on_add_peak_to_list) unreg.addWidget(self.add_peak_btn) self.export_peaks_btn = QPushButton("Export list") self.export_peaks_btn.clicked.connect(self._on_export_peak_list) unreg.addWidget(self.export_peaks_btn) layout.addLayout(unreg) # Hide unregistered-peak controls until the mode is enabled self.unreg_tol_lbl.setVisible(False) self.unreg_tol_spin.setVisible(False) self.unreg_selected_lbl.setVisible(False) self.display_spatially_btn.setVisible(False) self.add_peak_btn.setVisible(False) # ── Progress bar (shown while imzML loads) ──────────────────────── self.progress_bar = QProgressBar() self.progress_bar.setFixedHeight(4) self.progress_bar.setTextVisible(False) self.progress_bar.hide() layout.addWidget(self.progress_bar) # ── Hint label ──────────────────────────────────────────────────── hint = QLabel( "Scroll to pan · Ctrl+scroll to zoom · Click a red peak line to select glycan · " "'Check unregistered peak' lets you click anywhere on the spectrum" ) hint.setStyleSheet(f"color: {PALETTE['text_dim']}; font-size: 9px;") layout.addWidget(hint) # ── Canvas ──────────────────────────────────────────────────────── 
self.canvas = MplCanvas(self, width=10, height=2.6, dpi=90) layout.addWidget(self.canvas) # ── Interactions ────────────────────────────────────────────────── self.canvas.mpl_connect("scroll_event", self._on_scroll) self.canvas.mpl_connect("button_press_event", self._on_click) # ── Data loading ────────────────────────────────────────────────────── def _load_sdata_spectrum(self): try: adata = self.sdata.tables[self.table_name] mz_arr = _resolve_mz_array(adata) X = np.asarray(adata.X, dtype=np.float32) intensities = X.mean(axis=0).astype(np.float64) self._sdata_mz = mz_arr self._sdata_int = intensities self._view_lo = float(mz_arr.min()) self._view_hi = float(mz_arr.max()) self.mz_lo.setValue(self._view_lo) self.mz_hi.setValue(self._view_hi) self.status_lbl.setText( f"{adata.n_obs:,} pixels · {adata.n_vars:,} peaks | " "Loading full spectrum…" ) except Exception as e: self.status_lbl.setText(f"sdata load error: {e}") self._redraw() def _start_imzml_load(self): try: path = self.sdata.tables[self.table_name].uns.get("maldi_path") except Exception: path = None if not path: self.status_lbl.setText( self.status_lbl.text().replace("Loading full spectrum…", "No imzML path found") ) return self.progress_bar.show() self.progress_bar.setValue(0) self._loader = _SpectrumLoader(path, n_sample=800) self._thread = QThread() self._loader.moveToThread(self._thread) self._thread.started.connect(self._loader.run) self._loader.finished.connect(self._on_imzml_loaded) self._loader.progress.connect(self.progress_bar.setValue) self._loader.finished.connect(self._thread.quit) self._thread.start() def _on_imzml_loaded(self, mz_arr, int_arr): self.progress_bar.hide() if len(mz_arr) == 0: self.status_lbl.setText("imzML load failed — using sdata mean") return self._raw_mz = mz_arr self._raw_int = int_arr # Switch to raw source automatically self.source_combo.setCurrentText("Raw imzML (full)") # Update view window to full range self._view_lo = float(mz_arr.min()) self._view_hi = 
float(mz_arr.max()) self.mz_lo.setValue(self._view_lo) self.mz_hi.setValue(self._view_hi) try: adata = self.sdata.tables[self.table_name] self.status_lbl.setText( f"{adata.n_obs:,} pixels · {adata.n_vars:,} peaks | " f"Full spectrum: {len(mz_arr):,} bins" ) except Exception: pass self._redraw() # ── Unregistered peak mode ────────────────────────────────────────────── def _on_unreg_toggled(self, checked: bool): self._unreg_mode = checked self.unreg_tol_lbl.setVisible(checked) self.unreg_tol_spin.setVisible(checked) self.unreg_selected_lbl.setVisible(checked) self.display_spatially_btn.setVisible(checked) self.add_peak_btn.setVisible(checked) if checked: self.unreg_btn.setText("Exit unregistered-peak mode") self.unreg_selected_lbl.setText("Click a peak on the spectrum to select it") self.display_spatially_btn.setEnabled(False) self.add_peak_btn.setEnabled(False) else: self.unreg_btn.setText("Check unregistered peak") self._unreg_mz = None self.display_spatially_btn.setEnabled(False) self.add_peak_btn.setEnabled(False) self._redraw() def _on_applied_tol_changed(self, value: float): self.applied_tolerance = float(value) self._redraw() def _on_unreg_tol_changed(self, value: float): self._unreg_tol = float(value) if self._unreg_mz is not None: self.unreg_selected_lbl.setText( f"Selected m/z: {self._unreg_mz:.4f}{self._unreg_tol:.3f} Da)" ) self._redraw() def _on_display_spatially(self): if self._unreg_mz is None: return self.unregistered_peak_display.emit(self._unreg_mz, self._unreg_tol) def _on_add_peak_to_list(self): if self._unreg_mz is None: return mz = float(self._unreg_mz) # Avoid near-duplicate entries (within 1e-4 Da) if any(abs(mz - existing) < 1e-4 for existing in self.peak_list): show_info(f"m/z {mz:.4f} is already in the list.") return self.peak_list.append(mz) self.peak_list.sort() show_info(f"Added m/z {mz:.4f} to list (n={len(self.peak_list)}).") def _on_export_peak_list(self): from qtpy.QtWidgets import QFileDialog if not self.peak_list: show_info("Peak 
list is empty — nothing to export.") return path, _ = QFileDialog.getSaveFileName( self, "Export m/z list", "peak_list.csv", "CSV files (*.csv)" ) if not path: return try: df = pd.DataFrame({"m/z": sorted(self.peak_list)}) df.to_csv(path, index=False) show_info(f"Exported {len(self.peak_list)} m/z values to {path}") except Exception as e: show_info(f"Export failed: {e}") # ── Interactions ────────────────────────────────────────────────────── def _on_scroll(self, event): """ Plain scroll → pan (shift view left/right). Ctrl+scroll → zoom in/out around cursor position. """ if self._current_mz() is None: return span = self._view_hi - self._view_lo if span <= 0: return ctrl_held = (event.key == "control") or ( QApplication.keyboardModifiers() & Qt.ControlModifier ) if ctrl_held: # Zoom: shrink/expand around cursor factor = 0.85 if event.button == "up" else 1.0 / 0.85 cursor_mz = event.xdata if event.xdata is not None else (self._view_lo + self._view_hi) / 2 new_lo = cursor_mz - (cursor_mz - self._view_lo) * factor new_hi = cursor_mz + (self._view_hi - cursor_mz) * factor else: # Pan: move 15% of span per scroll tick shift = span * 0.15 * (-1 if event.button == "up" else 1) new_lo = self._view_lo + shift new_hi = self._view_hi + shift # Clamp to data extent mz = self._current_mz() data_lo, data_hi = float(mz.min()), float(mz.max()) width = new_hi - new_lo new_lo = max(data_lo, new_lo) new_hi = min(data_hi, new_hi) if new_hi - new_lo < 5: new_lo = new_hi - 5 self._view_lo = new_lo self._view_hi = new_hi self.mz_lo.setValue(new_lo) self.mz_hi.setValue(new_hi) self._redraw() def _on_click(self, event): """Left-click: find nearest curated peak within a tolerance and select it, or — in unregistered-peak mode — select the clicked m/z directly.""" if event.button != MouseButton.LEFT: return if event.xdata is None: return click_mz = float(event.xdata) if self._unreg_mode: mz = self._current_mz() if mz is not None: data_lo, data_hi = float(mz.min()), float(mz.max()) click_mz = 
max(data_lo, min(data_hi, click_mz)) self._unreg_mz = click_mz self.unreg_selected_lbl.setText( f"Selected m/z: {click_mz:.4f}{self._unreg_tol:.3f} Da)" ) self.display_spatially_btn.setEnabled(True) self.add_peak_btn.setEnabled(True) self._redraw() return if not self.show_peaks_cb.isChecked(): return span = self._view_hi - self._view_lo snap_tol = span * 0.015 # 1.5% of visible window # Find nearest curated peak within snap_tol visible_peaks = [p for p in self.peaks if self._view_lo <= p <= self._view_hi] if not visible_peaks: return nearest = min(visible_peaks, key=lambda p: abs(p - click_mz)) if abs(nearest - click_mz) > snap_tol: return label = self._label_for_peak(nearest) self.highlight_glycan(nearest, label) self.peak_clicked.emit(nearest, label) # ── View helpers ────────────────────────────────────────────────────── def _current_mz(self) -> Optional[np.ndarray]: src = self.source_combo.currentText() if src == "Raw imzML (full)" and self._raw_mz is not None: return self._raw_mz return self._sdata_mz def _apply_zoom(self): self._view_lo = self.mz_lo.value() self._view_hi = self.mz_hi.value() self._redraw() def _reset_zoom(self): mz = self._current_mz() if mz is not None: self._view_lo = float(mz.min()) self._view_hi = float(mz.max()) self.mz_lo.setValue(self._view_lo) self.mz_hi.setValue(self._view_hi) self._redraw() # ── Drawing ────────────────────────────────────────────────────────── def _redraw(self): src = self.source_combo.currentText() if src == "Raw imzML (full)" and self._raw_mz is not None: mz, intensity = self._raw_mz, self._raw_int elif self._sdata_mz is not None: mz, intensity = self._sdata_mz, self._sdata_int else: return lo, hi = self._view_lo, self._view_hi mask = (mz >= lo) & (mz <= hi) mz_v = mz[mask] int_v = intensity[mask] if len(mz_v) == 0: return self.canvas.fig.clear() ax = self.canvas.fig.add_subplot(111) self.canvas._style_ax(ax) # ── Background spectrum ─────────────────────────────────────────── spec_col = PALETTE["raw_spec"] if 
src == "Raw imzML (full)" else PALETTE["spectrum"] ax.plot(mz_v, int_v, color=spec_col, linewidth=0.6, alpha=0.8, zorder=2) ax.fill_between(mz_v, int_v, color=spec_col, alpha=0.08, zorder=1) max_int = float(int_v.max()) if len(int_v) else 1.0 # ── Curated peaks (red dashed lines) ───────────────────────────── if self.show_peaks_cb.isChecked(): for pk in self.peaks: if lo <= pk <= hi: ax.axvline( pk, color=PALETTE["peak_marker"], linewidth=1.0, linestyle="--", alpha=0.7, zorder=3, picker=5, ) # ── Applied tolerance (red outline: baseline -> peak -> baseline) ── if self.show_applied_tol_cb.isChecked(): tol = self.applied_tolerance for pk in self.peaks: if (pk - tol) <= hi and (pk + tol) >= lo: # Height of the spectrum at this peak's m/z (interpolated) pk_height = float(np.interp(pk, mz_v, int_v)) if len(mz_v) else 0.0 x_left, x_right = pk - tol, pk + tol # Smooth rise/fall using a half-cosine profile so the line # leaves/returns to baseline tangentially. n = 30 t_up = np.linspace(0, 1, n) x_up = x_left + (pk - x_left) * t_up y_up = pk_height * (1 - np.cos(np.pi * t_up)) / 2 t_down = np.linspace(0, 1, n) x_down = pk + (x_right - pk) * t_down y_down = pk_height * (1 + np.cos(np.pi * t_down)) / 2 curve_x = np.concatenate([x_up, x_down[1:]]) curve_y = np.concatenate([y_up, y_down[1:]]) ax.plot(curve_x, curve_y, color="red", linewidth=1.2, alpha=0.85, zorder=6, clip_on=True) # ── Highlighted / selected glycan ───────────────────────────────── if self.highlighted_mz is not None: hmz = self.highlighted_mz if lo <= hmz <= hi: ax.axvspan(hmz - self._tol, hmz + self._tol, color=PALETTE["highlight"], alpha=0.22, zorder=4) ax.axvline(hmz, color=PALETTE["highlight"], linewidth=1.8, linestyle="-", alpha=0.95, zorder=5) ax.text( hmz, max_int * 1.01, self.highlighted_label, color=PALETTE["highlight"], fontsize=7.5, ha="center", va="bottom", rotation=90, clip_on=True, ) # ── Unregistered peak selection (cyan band) ──────────────────────── if self._unreg_mode and self._unreg_mz is not 
None: umz = self._unreg_mz if lo <= umz <= hi: ax.axvspan(umz - self._unreg_tol, umz + self._unreg_tol, color="#89dceb", alpha=0.25, zorder=4) ax.axvline(umz, color="#89dceb", linewidth=1.8, linestyle="-", alpha=0.95, zorder=5) ax.text( umz, max_int * 1.01, f"{umz:.4f}", color="#89dceb", fontsize=7.5, ha="center", va="bottom", rotation=90, clip_on=True, ) # ── Legend ──────────────────────────────────────────────────────── handles = [ Line2D([0], [0], color=spec_col, linewidth=1.5, label="Full spectrum" if src == "Raw imzML (full)" else "sdata mean"), ] if self.show_peaks_cb.isChecked(): handles.append( Line2D([0], [0], color=PALETTE["peak_marker"], linewidth=1, linestyle="--", label=f"Curated peaks (n={len(self.peaks)})") ) if self.highlighted_mz is not None: handles.append( mpatches.Patch(color=PALETTE["highlight"], alpha=0.5, label=f"Selected: {self.highlighted_label}") ) if self.show_applied_tol_cb.isChecked(): handles.append( Line2D([0], [0], color="red", linewidth=1.2, label=f"Applied tolerance (\u00b1{self.applied_tolerance:.3f} Da)") ) if self._unreg_mode and self._unreg_mz is not None: handles.append( mpatches.Patch(color="#89dceb", alpha=0.5, label=f"Unregistered: {self._unreg_mz:.4f} ± {self._unreg_tol:.3f}") ) ax.legend(handles=handles, fontsize=7.5, framealpha=0.3, loc="upper right", facecolor=PALETTE["surface"], labelcolor=PALETTE["text"]) ax.set_xlabel("m/z (Da)", fontsize=9) ax.set_ylabel("Intensity", fontsize=9) ax.set_xlim(lo, hi) ax.set_ylim(bottom=0) ax.xaxis.set_minor_locator(mticker.AutoMinorLocator(5)) ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) self.canvas.fig.tight_layout(pad=0.4) self.canvas.draw() # ── Public API ────────────────────────────────────────────────────────
def highlight_glycan(self, mz: float, label: str, tol: float = 0.15):
    """Highlight a glycan in the spectrum and scroll it into view.

    Parameters
    ----------
    mz:
        m/z value to highlight and centre the view on.
    label:
        Text stored as ``highlighted_label``.
    tol:
        Half-width stored in ``_tol``.

    The view keeps its current span (at least 50 Da, at most the data
    extent).  Near the ends of the data the window is shifted rather than
    truncated, so the span is preserved and the window is never inverted,
    even when ``mz`` lies outside the loaded spectrum.
    """
    self.highlighted_mz = mz
    self.highlighted_label = label
    self._tol = tol

    mz_full = self._current_mz()
    if mz_full is not None and len(mz_full) > 0:
        data_lo, data_hi = float(mz_full.min()), float(mz_full.max())
        span = min(max(self._view_hi - self._view_lo, 50), data_hi - data_lo)
        # Centre on the peak, then slide the window back inside the data.
        new_lo = min(max(mz - span / 2, data_lo), data_hi - span)
        new_hi = new_lo + span
        self._view_lo = new_lo
        self._view_hi = new_hi
        self.mz_lo.setValue(new_lo)
        self.mz_hi.setValue(new_hi)
    self._redraw()
def clear_highlight(self):
    """Drop the current glycan highlight (m/z and label) and repaint."""
    self.highlighted_mz, self.highlighted_label = None, ""
    self._redraw()
# ════════════════════════════════════════════════════════════════════════════ # Glycan selection dialog (heatmap) # ════════════════════════════════════════════════════════════════════════════
class GlycanSelectionDialog(QDialog):
    """Popup to choose glycans for heatmap plotting via select, type, or upload.

    ``glycan_names`` and ``peaks`` are used as parallel lists: every index
    stored in ``selected_indices`` refers to the same position in both.
    After the dialog is accepted, ``selected_indices`` holds the
    de-duplicated matches in the order the user supplied them.
    """

    def __init__(
        self,
        glycan_names: list[str],
        peaks: list[float],
        label_to_mz: dict[str, float],
        parent=None,
    ):
        super().__init__(parent)
        self.glycan_names = glycan_names
        self.peaks = peaks
        self.label_to_mz = label_to_mz
        self.selected_indices: list[int] = []
        self._upload_path: Optional[str] = None

        self.setWindowTitle("Select Glycans for Heatmap")
        self.setMinimumWidth(360)
        self.setStyleSheet(BASE_STYLE)

        hint_style = f"color: {PALETTE['text_dim']}; font-size: 9px;"
        layout = QVBoxLayout(self)
        self.tabs = QTabWidget()
        layout.addWidget(self.tabs)

        # ── Tab 1: checklist ────────────────────────────────────────────
        select_w = QWidget()
        select_layout = QVBoxLayout(select_w)
        select_hint = QLabel("Check one or more glycans:")
        select_hint.setStyleSheet(hint_style)
        select_layout.addWidget(select_hint)
        self.glycan_list = QListWidget()
        for name in glycan_names:
            item = QListWidgetItem(name)
            item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
            item.setCheckState(Qt.Unchecked)
            self.glycan_list.addItem(item)
        select_layout.addWidget(self.glycan_list)
        self.tabs.addTab(select_w, "Select")

        # ── Tab 2: comma-separated text ─────────────────────────────────
        type_w = QWidget()
        type_layout = QVBoxLayout(type_w)
        type_hint = QLabel("Enter glycan names or m/z values, separated by commas:")
        type_hint.setStyleSheet(hint_style)
        type_hint.setWordWrap(True)
        type_layout.addWidget(type_hint)
        self.type_edit = QPlainTextEdit()
        self.type_edit.setPlaceholderText("e.g. HexNAc, 1685.65, Fuc-HexNAc (933.40)")
        self.type_edit.setMinimumHeight(100)
        type_layout.addWidget(self.type_edit)
        self.tabs.addTab(type_w, "Type")

        # ── Tab 3: file upload ──────────────────────────────────────────
        upload_w = QWidget()
        upload_layout = QVBoxLayout(upload_w)
        upload_hint = QLabel(
            "Upload a .txt or .csv file with one glycan per line, or comma-separated."
        )
        upload_hint.setStyleSheet(hint_style)
        upload_hint.setWordWrap(True)
        upload_layout.addWidget(upload_hint)
        upload_row = QHBoxLayout()
        self.upload_lbl = QLabel("No file selected")
        self.upload_lbl.setStyleSheet(f"color: {PALETTE['text_dim']};")
        self.upload_lbl.setWordWrap(True)
        upload_row.addWidget(self.upload_lbl, stretch=1)
        browse_btn = QPushButton("Browse…")
        browse_btn.clicked.connect(self._browse_file)
        upload_row.addWidget(browse_btn)
        upload_layout.addLayout(upload_row)
        self.tabs.addTab(upload_w, "Upload")

        self.summary_lbl = QLabel("")
        self.summary_lbl.setStyleSheet(hint_style)
        layout.addWidget(self.summary_lbl)

        buttons = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
        buttons.accepted.connect(self._on_accept)
        buttons.rejected.connect(self.reject)
        layout.addWidget(buttons)
        ok_btn = buttons.button(QDialogButtonBox.Ok)
        if ok_btn is not None:
            ok_btn.setText("Enter")
            ok_btn.setDefault(True)
            ok_btn.setAutoDefault(True)

    def _browse_file(self):
        """Let the user pick the upload file and show its path in the tab."""
        path, _ = QFileDialog.getOpenFileName(
            self,
            "Select glycan list",
            "",
            "Text files (*.txt *.csv);;All files (*)",
        )
        if path:
            self._upload_path = path
            self.upload_lbl.setText(path)
            self.upload_lbl.setStyleSheet(f"color: {PALETTE['text']};")

    def _tokens_from_active_tab(self) -> list[str]:
        """Return the raw, stripped, non-empty tokens from the visible tab."""
        tab = self.tabs.currentIndex()
        if tab == 0:
            items = (self.glycan_list.item(i) for i in range(self.glycan_list.count()))
            return [it.text() for it in items if it.checkState() == Qt.Checked]
        if tab == 1:
            raw = self.type_edit.toPlainText()
            return [t.strip() for t in raw.replace("\n", ",").split(",") if t.strip()]
        if tab == 2 and self._upload_path:
            try:
                with open(self._upload_path, encoding="utf-8") as fh:
                    content = fh.read()
            except (OSError, UnicodeError) as e:
                show_info(f"Could not read file: {e}")
                return []
            return [
                t.strip()
                for line in content.splitlines()
                for t in line.split(",")
                if t.strip()
            ]
        return []

    def _resolve_tokens(self, tokens: list[str]) -> list[int]:
        """Map tokens to peak indices, dropping misses and duplicates."""
        indices: list[int] = []
        seen: set[int] = set()
        for token in tokens:
            idx = self._match_token(token)
            if idx is not None and idx not in seen:
                indices.append(idx)
                seen.add(idx)
        return indices

    def _nearest_peak_index(self, mz: float, max_dist: float = 0.5) -> Optional[int]:
        """Return the index of the peak closest to *mz*, or None if none is
        strictly within *max_dist* Da."""
        best_i, best_d = None, float("inf")
        for i, pk in enumerate(self.peaks):
            d = abs(pk - mz)
            if d < best_d:
                best_d, best_i = d, i
        if best_i is not None and best_d < max_dist:
            return best_i
        return None

    def _match_token(self, token: str) -> Optional[int]:
        """Resolve one user token to an index into ``glycan_names``/``peaks``.

        Matching goes from most to least specific, so an exact name or an
        m/z value is never shadowed by an earlier entry that merely contains
        the token as a substring:

        1. full display name (case-insensitive);
        2. label part of a ``"label (m/z)"`` display name;
        3. key of ``label_to_mz`` → nearest peak within 0.5 Da;
        4. numeric m/z (optional ``mz-`` prefix) → nearest peak within 0.5 Da;
        5. first display name containing the token.
        """
        stripped = token.strip()
        needle = stripped.lower()
        if not needle:
            return None
        names_l = [name.lower() for name in self.glycan_names]

        for i, name in enumerate(names_l):
            if name == needle:
                return i
        for i, name in enumerate(names_l):
            if name.rsplit(" (", 1)[0] == needle:
                return i

        label_mz = self.label_to_mz.get(stripped)
        if label_mz is None:
            label_mz = next(
                (mz for lbl, mz in self.label_to_mz.items() if str(lbl).lower() == needle),
                None,
            )
        if label_mz is not None:
            idx = self._nearest_peak_index(label_mz)
            if idx is not None:
                return idx

        try:
            mz_val = float(needle.replace("mz-", "").strip())
        except ValueError:
            mz_val = None
        if mz_val is not None:
            idx = self._nearest_peak_index(mz_val)
            if idx is not None:
                return idx

        # Last resort: partial name typed by the user.
        for i, name in enumerate(names_l):
            if needle in name:
                return i
        return None

    def _on_accept(self):
        """Validate the selection; accept the dialog only if something matched."""
        tokens = self._tokens_from_active_tab()
        if not tokens:
            show_info("No glycans selected. Check items, type a list, or upload a file.")
            return
        self.selected_indices = self._resolve_tokens(tokens)
        if not self.selected_indices:
            show_info("Could not match any glycans. Check names or m/z values.")
            return
        self.summary_lbl.setText(f"{len(self.selected_indices)} glycan(s) selected.")
        self.accept()

    def keyPressEvent(self, event):
        """Treat Return/Enter as "accept"; defer every other key to QDialog."""
        if event.key() in (Qt.Key_Return, Qt.Key_Enter):
            self._on_accept()
        else:
            super().keyPressEvent(event)
# ════════════════════════════════════════════════════════════════════════════ # 2. ANALYSIS SIDEBAR (right dock) # ════════════════════════════════════════════════════════════════════════════
class AnalysisSidebar(QWidget):
    """
    Right-dock panel.

    Glycan tab: selector + Violin/Box/Histogram (no spatial scatter —
    that is now rendered on the napari viewer directly).
    Metadata tab: colour a shapes layer by a metadata column.
    Further tabs (UMAP, Heatmap, Annotations, Stats) are built by methods
    defined later in the class.
    """

    glycan_selected = Signal(float, str)  # mz, display_label

    def __init__(
        self,
        sdata: SpatialData,
        peaks: list[float],
        viewer: napari.Viewer,
        glycan_df: Optional[pd.DataFrame] = None,
        table_name: str = "maldi_adata",
        parent=None,
    ):
        super().__init__(parent)
        self.sdata = sdata
        self.peaks = sorted(peaks)
        self.viewer = viewer
        self.glycan_df = glycan_df
        self.table_name = table_name
        self.hmap_custom_indices: Optional[list[int]] = None
        # State dicts of the most recent renders; read by the legend and
        # colormap handlers below.
        self._last_glycan_render: Optional[dict] = None
        self._last_meta_render: Optional[dict] = None
        self._build_glycan_lookup()
        self._build_ui()

    # ── Data helpers ──────────────────────────────────────────────────────

    def _on_cmap_changed(self, cmap_name: str):
        """Re-apply colormap after a glycan render, and refresh the sidebar legend."""
        if not cmap_name or self._last_glycan_render is None:
            return
        layer = _find_shapes_layer(self.viewer, self._last_glycan_render["shapes_name"])
        if layer is None:
            return
        try:
            _apply_shapes_colormap(layer, cmap_name, categorical=False)
            self._last_glycan_render["colormap"] = cmap_name
            self._draw_glycan_legend()
        except Exception as e:
            show_info(f"Could not apply colormap: {e}")

    def _draw_glycan_legend(self):
        """Draw the continuous legend for the last glycan render, if any."""
        state = self._last_glycan_render
        if state is None:
            return
        _draw_spatial_legend(
            self.glycan_legend_canvas,
            state["label"],
            categorical=False,
            colormap=state["colormap"],
            vmin=state["vmin"],
            vmax=state["vmax"],
        )

    def _draw_meta_legend(self):
        """Draw the legend for the last metadata render, if any."""
        state = self._last_meta_render
        if state is None:
            return
        _draw_spatial_legend(
            self.meta_legend_canvas,
            state["label"],
            categorical=state["categorical"],
            colormap=state["colormap"],
            categories=state.get("categories"),
            vmin=state.get("vmin"),
            vmax=state.get("vmax"),
        )

    def _build_glycan_lookup(self):
        """Build ``mz_to_label`` / ``label_to_mz`` and one display name per peak.

        A peak is named ``"<label> (<m/z>)"`` after the nearest ``glycan_df``
        row with a usable label within 0.5 Da; failing that, after the nearest
        ``mz_to_label`` entry within 0.5 Da whose label is not numeric;
        otherwise it is shown as the bare m/z.
        """
        self.mz_to_label: dict[float, str] = {}
        self.label_to_mz: dict[str, float] = {}
        self.glycan_names: list[str] = []

        try:
            adata = self.sdata.tables[self.table_name]
            mz_arr = _resolve_mz_array(adata)
            disp = _resolve_var_display_labels(adata)
            for mz, lbl in zip(mz_arr, disp):
                self.mz_to_label[mz] = lbl
                self.label_to_mz[lbl] = mz
        except Exception:
            # Best effort: the table labels are optional for naming peaks.
            pass

        # Walk glycan_df once; the per-peak search below is then vectorised
        # instead of re-running iterrows() for every peak.
        ref_mz: list[float] = []
        ref_lbl: list[str] = []
        if self.glycan_df is not None:
            for _, row in self.glycan_df.iterrows():
                mz = float(row["mz"])
                lbl = str(row["label"])
                if lbl and lbl not in ("nan", ""):
                    self.mz_to_label[mz] = lbl
                    self.label_to_mz[lbl] = mz
                    if not np.isnan(mz):  # NaN can never be "within 0.5 Da"
                        ref_mz.append(mz)
                        ref_lbl.append(lbl)
        ref_arr = np.asarray(ref_mz, dtype=float)

        names = []
        for pk in self.peaks:
            best_label = f"{pk:.4f}"
            matched = False
            if ref_arr.size:
                dists = np.abs(ref_arr - pk)
                j = int(np.argmin(dists))  # first row wins on ties
                if dists[j] < 0.5:
                    best_label = f"{ref_lbl[j]} ({pk:.2f})"
                    matched = True
            if not matched:
                nearest = min(self.mz_to_label, key=lambda m: abs(m - pk), default=None)
                if nearest is not None and abs(nearest - pk) < 0.5:
                    lbl = self.mz_to_label[nearest]
                    if not _looks_numeric(lbl):
                        best_label = f"{lbl} ({pk:.2f})"
            names.append(best_label)
        self.glycan_names = names

    def _adata(self):
        """Return the table named ``table_name`` from ``sdata``."""
        return self.sdata.tables[self.table_name]

    def _get_peak_index(self, peak_mz: float) -> Optional[int]:
        """Return the variable index nearest to *peak_mz* (within 0.5 Da), else None."""
        adata = self._adata()
        try:
            var_mzs = _resolve_mz_array(adata)
            idx = int(np.argmin(np.abs(var_mzs - peak_mz)))
            if abs(var_mzs[idx] - peak_mz) < 0.5:
                return idx
        except Exception:
            pass
        return None

    # ── UI ────────────────────────────────────────────────────────────────

    def _build_ui(self):
        """Create the title label and the tab widget with all analysis tabs."""
        layout = QVBoxLayout(self)
        layout.setContentsMargins(6, 6, 6, 6)
        layout.setSpacing(6)

        title = QLabel("goatpy Analysis")
        title.setStyleSheet(
            f"font-size: 14px; font-weight: bold; color: {PALETTE['accent']}; padding: 4px 0;"
        )
        layout.addWidget(title)

        self.tabs = QTabWidget()
        layout.addWidget(self.tabs)
        self.tabs.addTab(self._build_glycan_tab(), "Glycan")
        self.tabs.addTab(self._build_metadata_tab(), "Metadata")
        self.tabs.addTab(self._build_umap_tab(), "UMAP")
        self.tabs.addTab(self._build_heatmap_tab(), "Heatmap")
        self.tabs.addTab(self._build_annotations_tab(), "Annotations")
        self.tabs.addTab(self._build_stats_tab(), "Stats")

    # ── Glycan tab ────────────────────────────────────────────────────────

    def _build_glycan_tab(self) -> QWidget:
        """Build the Glycan tab: selector, colormap, legend, distribution plot."""
        w = QWidget()
        layout = QVBoxLayout(w)
        layout.setContentsMargins(4, 4, 4, 4)
        layout.setSpacing(6)

        # Selector
        sel_grp = QGroupBox("Select Glycan / m/z")
        sel_layout = QVBoxLayout(sel_grp)
        self.glycan_search = QComboBox()
        self.glycan_search.setEditable(True)
        self.glycan_search.setInsertPolicy(QComboBox.NoInsert)
        self.glycan_search.addItems(self.glycan_names)
        self.glycan_search.setCurrentIndex(0)
        sel_layout.addWidget(self.glycan_search)

        show_btn = QPushButton("Show on H&E Viewer")
        show_btn.clicked.connect(self._on_glycan_show)

        # Colormap row sits above the "show" button.
        cmap_row = QHBoxLayout()
        cmap_row.addWidget(QLabel("Ion map colormap:"))
        self.ion_cmap_combo = QComboBox()
        self.ion_cmap_combo.addItems(CONTINUOUS_CMAPS)
        self.ion_cmap_combo.setCurrentText("inferno")
        self.ion_cmap_combo.currentTextChanged.connect(self._on_cmap_changed)
        cmap_row.addWidget(self.ion_cmap_combo)
        sel_layout.addLayout(cmap_row)
        sel_layout.addWidget(show_btn)

        note = QLabel("Ion map is added as a napari layer over the H&E image.")
        note.setStyleSheet(f"color: {PALETTE['text_dim']}; font-size: 9px;")
        note.setWordWrap(True)
        sel_layout.addWidget(note)
        layout.addWidget(sel_grp)

        legend_grp = QGroupBox("Spatial legend")
        legend_layout = QVBoxLayout(legend_grp)
        self.glycan_legend_canvas = MplCanvas(w, width=4.5, height=1.6, dpi=90)
        self.glycan_legend_canvas.setMinimumHeight(70)
        legend_layout.addWidget(self.glycan_legend_canvas)
        layout.addWidget(legend_grp)

        # Distribution plot
        plot_grp = QGroupBox("Distribution Plot")
        plot_layout = QVBoxLayout(plot_grp)
        row1 = QHBoxLayout()
        row1.addWidget(QLabel("Type:"))
        self.glycan_plot_type = QComboBox()
        self.glycan_plot_type.addItems(["Violin by cluster", "Box by cluster", "Histogram"])
        row1.addWidget(self.glycan_plot_type)
        plot_layout.addLayout(row1)

        row2 = QHBoxLayout()
        row2.addWidget(QLabel("Group:"))
        self.cluster_col_combo = QComboBox()
        self._populate_obs_categoricals(self.cluster_col_combo)
        row2.addWidget(self.cluster_col_combo)
        plot_layout.addLayout(row2)

        dist_btn = QPushButton("Plot Distribution")
        dist_btn.clicked.connect(self._on_dist_plot)
        plot_layout.addWidget(dist_btn)
        layout.addWidget(plot_grp)

        self.glycan_canvas = MplCanvas(w, width=4.5, height=3.8, dpi=90)
        layout.addWidget(self.glycan_canvas)
        return w

    def _populate_obs_categoricals(self, combo: QComboBox):
        """Fill *combo* with categorical obs columns (plus a few known names)."""
        known = ("GPCA_clusters", "leiden", "batch", "annotation")
        try:
            obs = self._adata().obs
            cats = [c for c in obs.columns
                    if obs[c].dtype.name == "category" or c in known]
        except Exception:
            cats = []
        # Clear on every path so a failed lookup never stacks "(none)" entries.
        combo.clear()
        combo.addItems(cats if cats else ["(none)"])

    def _populate_obs_columns(self, combo: QComboBox):
        """Fill *combo* with every obs column, or "(none)" if unavailable."""
        try:
            cols = list(self._adata().obs.columns)
        except Exception:
            cols = []
        combo.clear()
        combo.addItems(cols if cols else ["(none)"])

    def _populate_shapes_elements(self, combo: QComboBox):
        """Fill *combo* with the shapes element names, defaulting to "pixels"."""
        try:
            names = list(self.sdata.shapes.keys())
        except Exception:
            names = []
        combo.clear()
        combo.addItems(names if names else ["pixels"])

    # ── Metadata tab ──────────────────────────────────────────────────────

    def _build_metadata_tab(self) -> QWidget:
        """Build the Metadata tab: shapes/column/colormap pickers and legend."""
        w = QWidget()
        layout = QVBoxLayout(w)
        layout.setContentsMargins(4, 4, 4, 4)
        layout.setSpacing(6)

        sel_grp = QGroupBox("Plot metadata on shapes")
        sel_layout = QVBoxLayout(sel_grp)

        shapes_row = QHBoxLayout()
        shapes_row.addWidget(QLabel("Shapes layer:"))
        self.meta_shapes_combo = QComboBox()
        self._populate_shapes_elements(self.meta_shapes_combo)
        # The lambda looks up meta_col_combo lazily; it is created just below.
        self.meta_shapes_combo.currentTextChanged.connect(
            lambda name: self._populate_meta_columns(self.meta_col_combo, name)
        )
        shapes_row.addWidget(self.meta_shapes_combo)
        sel_layout.addLayout(shapes_row)

        col_row = QHBoxLayout()
        col_row.addWidget(QLabel("metadata column:"))
        self.meta_col_combo = QComboBox()
        self._populate_meta_columns(self.meta_col_combo, self.meta_shapes_combo.currentText())
        self.meta_col_combo.currentTextChanged.connect(self._on_meta_col_changed)
        col_row.addWidget(self.meta_col_combo)
        sel_layout.addLayout(col_row)

        self.meta_type_lbl = QLabel("Type: —")
        self.meta_type_lbl.setStyleSheet(f"color: {PALETTE['text_dim']}; font-size: 9px;")
        sel_layout.addWidget(self.meta_type_lbl)

        cmap_row = QHBoxLayout()
        cmap_row.addWidget(QLabel("Colormap:"))
        self.meta_cmap_combo = QComboBox()
        self.meta_cmap_combo.addItems(CONTINUOUS_CMAPS)
        self.meta_cmap_combo.currentTextChanged.connect(self._on_meta_cmap_changed)
        cmap_row.addWidget(self.meta_cmap_combo)
        sel_layout.addLayout(cmap_row)

        show_btn = QPushButton("Show on H&E Viewer")
        show_btn.clicked.connect(self._on_metadata_show)
        sel_layout.addWidget(show_btn)

        note = QLabel(
            "Colour the selected shapes layer by metadata values. "
            "Categorical values use discrete colormaps; numeric values use continuous ones."
        )
        note.setStyleSheet(f"color: {PALETTE['text_dim']}; font-size: 9px;")
        note.setWordWrap(True)
        sel_layout.addWidget(note)
        layout.addWidget(sel_grp)

        legend_grp = QGroupBox("Spatial legend")
        legend_layout = QVBoxLayout(legend_grp)
        self.meta_legend_canvas = MplCanvas(w, width=4.5, height=1.6, dpi=90)
        self.meta_legend_canvas.setMinimumHeight(70)
        legend_layout.addWidget(self.meta_legend_canvas)
        layout.addWidget(legend_grp)
        layout.addStretch()

        self._on_meta_col_changed(self.meta_col_combo.currentText())
        return w

    def _on_meta_col_changed(self, col: str):
        """Update the type label and colormap choices for the selected column."""
        if col == "(none)":
            self.meta_type_lbl.setText("Type: —")
            return
        shapes_name = self.meta_shapes_combo.currentText()
        series = self._get_metadata_series(shapes_name, col)
        if series is None:
            self.meta_type_lbl.setText("Type: —")
            return
        is_cat = _safe_is_series_categorical(series)
        self.meta_type_lbl.setText(
            f"Type: {'categorical' if is_cat else 'continuous'}"
        )
        # Swap the colormap list silently, keeping the current choice if it
        # exists in the new list.
        cmap = self.meta_cmap_combo.currentText()
        self.meta_cmap_combo.blockSignals(True)
        self.meta_cmap_combo.clear()
        self.meta_cmap_combo.addItems(CATEGORICAL_CMAPS if is_cat else CONTINUOUS_CMAPS)
        idx = self.meta_cmap_combo.findText(cmap)
        self.meta_cmap_combo.setCurrentIndex(idx if idx >= 0 else 0)
        self.meta_cmap_combo.blockSignals(False)

    def _get_metadata_series(self, shapes_name: str, col: str):
        """Return column *col* for *shapes_name* ("pixels" reads obs), or None."""
        if shapes_name == "pixels":
            adata = self._adata()
            return adata.obs[col] if col in adata.obs.columns else None
        if shapes_name not in self.sdata.shapes:
            return None
        data = _shape_to_dataframe(self.sdata.shapes[shapes_name])
        if data is None or col not in data.columns:
            return None
        series = data[col]
        # 'colour' columns may be stored as hex strings (good) or list/array RGB
        # (legacy). Convert legacy list-form to hex so downstream categorical
        # detection and colormap logic works correctly.
        if col == "colour":
            sample = series.iloc[0] if len(series) else None
            if isinstance(sample, (list, tuple, np.ndarray)):
                series = series.apply(_rgb_to_hex)
        return series

    def _get_shape_metadata_columns(self, shapes_name: str) -> list[str]:
        """List the metadata columns available for *shapes_name*."""
        if shapes_name == "pixels":
            return list(self._adata().obs.columns)
        if shapes_name not in self.sdata.shapes:
            return []
        data = _shape_to_dataframe(self.sdata.shapes[shapes_name])
        if data is None:
            return []
        return [c for c in data.columns if c != "geometry"]

    def _populate_meta_columns(self, combo: QComboBox, shapes_name: str):
        """Refill *combo* with the columns of *shapes_name* without emitting signals."""
        combo.blockSignals(True)
        combo.clear()
        columns = self._get_shape_metadata_columns(shapes_name)
        if columns:
            combo.addItems(columns)
        else:
            combo.addItem("(none)")
        combo.blockSignals(False)
        # Skipped during construction, before the colormap combo exists.
        if hasattr(self, "meta_cmap_combo"):
            self._on_meta_col_changed(combo.currentText())

    def _on_meta_cmap_changed(self, cmap_name: str):
        """Re-apply colormap after a metadata render, and refresh the sidebar legend."""
        if not cmap_name or self._last_meta_render is None:
            return
        layer = _find_shapes_layer(self.viewer, self._last_meta_render["shapes_name"])
        if layer is None:
            return
        try:
            n_cat = len(self._last_meta_render.get("categories") or [])
            _apply_shapes_colormap(
                layer,
                cmap_name,
                categorical=self._last_meta_render["categorical"],
                n_categories=n_cat,
            )
            self._last_meta_render["colormap"] = cmap_name
            self._draw_meta_legend()
        except Exception as e:
            show_info(f"Could not apply colormap: {e}")

    def _on_metadata_show(self):
        """Render the selected metadata column on the selected shapes layer."""
        col = self.meta_col_combo.currentText()
        shapes_name = self.meta_shapes_combo.currentText()
        if col == "(none)":
            show_info("Select a valid metadata column.")
            return
        series = self._get_metadata_series(shapes_name, col)
        if series is None:
            show_info(f"Column '{col}' not found for shapes layer '{shapes_name}'.")
            return
        values = series.values
        is_cat = _safe_is_series_categorical(series)
        cmap = self.meta_cmap_combo.currentText()
        state = _render_values_on_shapes(
            self.viewer, values, cmap, is_cat, col, shapes_name,
        )
        if state is not None:
            self._last_meta_render = state
            self._draw_meta_legend()

    def _on_glycan_show(self):
        """Render the selected glycan on the viewer and emit ``glycan_selected``."""
        idx = self.glycan_search.currentIndex()
        if idx < 0 or idx >= len(self.peaks):
            return
        pk = self.peaks[idx]
        label = self.glycan_names[idx]
        self._render_on_viewer(pk, label)
        self.glycan_selected.emit(pk, label)

    def _on_dist_plot(self):
        """Draw the distribution plot for the selected glycan."""
        idx = self.glycan_search.currentIndex()
        if idx < 0 or idx >= len(self.peaks):
            return
        self._draw_distribution(self.peaks[idx], self.glycan_names[idx])

    def _render_on_viewer(self, peak_mz: float, label: str):
        """Render one glycan ion map on the viewer and refresh the legend."""
        state = _render_glycan_on_viewer(
            self.viewer, self.sdata, peak_mz, label, self.table_name,
            colormap=self.ion_cmap_combo.currentText(),
        )
        if state is not None:
            self._last_glycan_render = state
            self._draw_glycan_legend()

    def _draw_distribution(self, peak_mz: float, label: str):
        """Plot the intensity distribution of one peak on ``glycan_canvas``."""
        col_idx = self._get_peak_index(peak_mz)
        if col_idx is None:
            show_info(f"Peak {peak_mz:.2f} not found.")
            return

        adata = self._adata()
        # Slice the single column first instead of densifying the whole
        # matrix; sparse slices expose toarray(), dense ones do not.
        column = adata.X[:, col_idx]
        if hasattr(column, "toarray"):
            column = column.toarray()
        values = np.asarray(column, dtype=np.float32).ravel()

        short = label.split("(")[0].strip()[:28]
        cluster_col = self.cluster_col_combo.currentText()
        plot_type = self.glycan_plot_type.currentText()

        self.glycan_canvas.fig.clear()
        ax = self.glycan_canvas.fig.add_subplot(111)
        self.glycan_canvas._style_ax(ax)

        if plot_type == "Histogram":
            ax.hist(values[values > 0], bins=50,
                    color=PALETTE["accent"], edgecolor="none", alpha=0.8)
            ax.set_xlabel("Intensity", fontsize=8)
            ax.set_ylabel("# Pixels", fontsize=8)
            ax.set_title(short, fontsize=9)
        else:
            self._plot_by_cluster(
                ax, adata, values, cluster_col,
                kind="violin" if "Violin" in plot_type else "box",
                label=short,
            )

        self.glycan_canvas.fig.tight_layout(pad=0.5)
        self.glycan_canvas.draw()

    def _plot_by_cluster(self, ax, adata, values, cluster_col, kind, label):
        """Draw a violin or box plot of *values* grouped by obs column *cluster_col*."""
        if cluster_col == "(none)" or cluster_col not in adata.obs.columns:
            ax.text(0.5, 0.5, "No cluster column\nfound",
                    ha="center", va="center", transform=ax.transAxes,
                    color=PALETTE["text_dim"])
            return

        clusters = adata.obs[cluster_col].astype(str).values
        unique = sorted(set(clusters))
        cmap = plt.get_cmap("tab20", len(unique))
        data_by_cluster = [values[clusters == c] for c in unique]
        positions = range(len(unique))

        if kind == "violin":
            parts = ax.violinplot(data_by_cluster, positions=positions,
                                  showmedians=True, showextrema=False)
            for i, pc in enumerate(parts["bodies"]):
                pc.set_facecolor(cmap(i))
                pc.set_alpha(0.7)
            parts["cmedians"].set_color(PALETTE["text"])
        else:
            bp = ax.boxplot(data_by_cluster, positions=positions,
                            patch_artist=True, showfliers=False,
                            medianprops={"color": PALETTE["text"], "linewidth": 1.5},
                            whiskerprops={"color": PALETTE["text_dim"]},
                            capprops={"color": PALETTE["text_dim"]})
            for i, patch in enumerate(bp["boxes"]):
                patch.set_facecolor(cmap(i))
                patch.set_alpha(0.75)

        ax.set_xticks(positions)
        ax.set_xticklabels(unique, rotation=45, ha="right", fontsize=7)
        ax.set_ylabel("Intensity", fontsize=8)
        ax.set_title(f"{label} by {cluster_col}", fontsize=8.5)

    # ── Unregistered peak: display ion image spatially ─────────────────────
def display_unregistered_ion_image(self, mz: float, tol: float):
    """Render the ion image for an arbitrary m/z window on the viewer.

    Reads the raw imzML path from the table's ``uns['maldi_path']`` and hands
    it, together with ``mz`` and ``tol``, to
    ``_display_unregistered_ion_image_on_shapes``.  If a render state comes
    back it replaces the last glycan render and the legend is redrawn.
    Shows an info message and returns when no raw path is available.
    """
    raw_path = None
    try:
        raw_path = self.sdata.tables[self.table_name].uns.get("maldi_path")
    except Exception:
        raw_path = None

    if not raw_path:
        show_info("No raw imzML path found (uns['maldi_path']) — cannot compute ion image.")
        return

    render_state = _display_unregistered_ion_image_on_shapes(
        self.viewer,
        self.sdata,
        raw_path,
        mz,
        tol,
        label=f"m/z {mz:.4f}",
        table_name=self.table_name,
        colormap=self.ion_cmap_combo.currentText(),
    )
    if render_state is None:
        return
    self._last_glycan_render = render_state
    self._draw_glycan_legend()
# ── Select glycan from spectrum click ─────────────────────────────────
def select_peak_from_spectrum(self, mz: float, label: str):
    """Called when the user clicks a peak in the spectrum widget.

    Moves the glycan combo box to the curated peak *nearest* to ``mz`` when
    that peak lies within 0.5 Da (the combo is left untouched otherwise),
    then renders ``mz`` on the viewer under ``label``.
    """
    if self.peaks:
        # Nearest, not first-within-tolerance: two curated peaks can both
        # lie within 0.5 Da of the clicked m/z.
        nearest = min(range(len(self.peaks)), key=lambda i: abs(self.peaks[i] - mz))
        if abs(self.peaks[nearest] - mz) < 0.5:
            self.glycan_search.setCurrentIndex(nearest)
    self._render_on_viewer(mz, label)
# ── UMAP tab ──────────────────────────────────────────────────────────
def _build_umap_tab(self) -> QWidget:
    """Create the UMAP tab: colour-by controls, embedding selector and canvas."""
    w = QWidget()
    layout = QVBoxLayout(w)
    layout.setContentsMargins(4, 4, 4, 4)

    col_grp = QGroupBox("Colour By")
    col_layout = QVBoxLayout(col_grp)
    self.umap_color_type = QComboBox()
    self.umap_color_type.addItems(["Metadata column", "Glycan intensity"])
    self.umap_color_type.currentTextChanged.connect(self._toggle_umap_controls)
    col_layout.addWidget(self.umap_color_type)

    self.umap_meta_col = QComboBox()
    self._populate_obs_columns(self.umap_meta_col)
    self.umap_meta_col.currentTextChanged.connect(self._on_umap_meta_col_changed)
    col_layout.addWidget(self.umap_meta_col)

    self.umap_glycan_combo = QComboBox()
    self.umap_glycan_combo.addItems(self.glycan_names)
    self.umap_glycan_combo.setVisible(False)
    col_layout.addWidget(self.umap_glycan_combo)

    cmap_row = QHBoxLayout()
    cmap_row.addWidget(QLabel("Colormap:"))
    self.umap_cmap_combo = QComboBox()
    self.umap_cmap_combo.addItems(CONTINUOUS_CMAPS)
    self.umap_cmap_combo.setCurrentText("inferno")
    cmap_row.addWidget(self.umap_cmap_combo)
    col_layout.addLayout(cmap_row)
    layout.addWidget(col_grp)

    emb_grp = QGroupBox("Embedding")
    emb_layout = QHBoxLayout(emb_grp)
    self.umap_embedding_combo = QComboBox()
    self._populate_obsm(self.umap_embedding_combo)
    emb_layout.addWidget(self.umap_embedding_combo)
    layout.addWidget(emb_grp)

    plot_btn = QPushButton("Plot UMAP")
    plot_btn.clicked.connect(self._draw_umap)
    layout.addWidget(plot_btn)

    self.umap_canvas = MplCanvas(w, width=4.5, height=4.5, dpi=90)
    layout.addWidget(self.umap_canvas)

    # Initialise colormap list for the starting metadata selection.
    self._on_umap_meta_col_changed(self.umap_meta_col.currentText())
    return w

def _toggle_umap_controls(self, text):
    """Show the controls matching the chosen colour source and reset the colormap list."""
    is_glycan = text == "Glycan intensity"
    self.umap_glycan_combo.setVisible(is_glycan)
    self.umap_meta_col.setVisible(not is_glycan)
    self.umap_cmap_combo.setVisible(True)
    if not is_glycan:
        self._on_umap_meta_col_changed(self.umap_meta_col.currentText())
    else:
        # Glycan intensities are always continuous.
        self.umap_cmap_combo.blockSignals(True)
        self.umap_cmap_combo.clear()
        self.umap_cmap_combo.addItems(CONTINUOUS_CMAPS)
        self.umap_cmap_combo.setCurrentText("inferno")
        self.umap_cmap_combo.blockSignals(False)

def _on_umap_meta_col_changed(self, col: str):
    """Offer categorical or continuous colormaps depending on the selected obs column."""
    if col == "(none)":
        return
    adata = self._adata()
    if col not in adata.obs.columns:
        return
    is_cat = _is_obs_categorical(adata, col)
    choices = CATEGORICAL_CMAPS if is_cat else CONTINUOUS_CMAPS
    current = self.umap_cmap_combo.currentText()
    self.umap_cmap_combo.blockSignals(True)
    self.umap_cmap_combo.clear()
    self.umap_cmap_combo.addItems(choices)
    # Keep the user's previous choice when it is still valid for this column type.
    if current in choices:
        self.umap_cmap_combo.setCurrentText(current)
    self.umap_cmap_combo.blockSignals(False)

def _populate_obsm(self, combo: QComboBox):
    """Fill ``combo`` with the ``obsm`` keys, or ``"(none)"`` when unavailable."""
    try:
        keys = list(self._adata().obsm.keys())
        combo.addItems(keys if keys else ["(none)"])
    except Exception:
        # Best effort: an unreadable table must not prevent the tab from building.
        combo.addItems(["(none)"])

def _draw_umap(self):
    """Scatter-plot the selected embedding, coloured by metadata or glycan intensity."""
    adata = self._adata()
    emb_key = self.umap_embedding_combo.currentText()
    if emb_key == "(none)" or emb_key not in adata.obsm:
        show_info("No embedding found. Run graphpca_spatialdata() first.")
        return

    emb = adata.obsm[emb_key]
    if emb.shape[1] == 2:
        coords = emb
    else:
        try:
            from umap import UMAP
            coords = UMAP(n_components=2, random_state=42).fit_transform(emb)
        except ImportError:
            # umap-learn not installed: fall back to the first two components.
            coords = np.asarray(emb)[:, :2]
    # obsm entries may be DataFrames; positional indexing below needs an array.
    coords = np.asarray(coords)

    color_type = self.umap_color_type.currentText()
    self.umap_canvas.fig.clear()
    ax = self.umap_canvas.fig.add_subplot(111)
    self.umap_canvas._style_ax(ax)

    if color_type == "Metadata column":
        col = self.umap_meta_col.currentText()
        if col != "(none)" and col in adata.obs.columns:
            cmap_name = self.umap_cmap_combo.currentText()
            labels = adata.obs[col]
            if _is_obs_categorical(adata, col):
                labels = labels.astype(str).values
                unique = sorted(set(labels))
                cmap = plt.get_cmap(cmap_name, len(unique))
                for i, cl in enumerate(unique):
                    mask = labels == cl
                    ax.scatter(coords[mask, 0], coords[mask, 1],
                               s=1, alpha=0.6, color=cmap(i),
                               label=cl, rasterized=True)
                ax.legend(markerscale=4, fontsize=6.5, framealpha=0.3,
                          facecolor=PALETTE["surface"],
                          labelcolor=PALETTE["text"],
                          bbox_to_anchor=(1, 1), loc="upper left")
            else:
                values = pd.to_numeric(labels, errors="coerce").astype(np.float32).to_numpy()
                valid = values[np.isfinite(values)]
                vmin = float(np.percentile(valid, 1)) if len(valid) else 0.0
                vmax = float(np.percentile(valid, 99)) if len(valid) else 1.0
                sc = ax.scatter(coords[:, 0], coords[:, 1],
                                c=values, cmap=cmap_name, s=1, alpha=0.6,
                                vmin=vmin, vmax=vmax, rasterized=True)
                self.umap_canvas.fig.colorbar(sc, ax=ax, shrink=0.8, pad=0.01)
            ax.set_title(f"UMAP — {col}", fontsize=8.5)
        else:
            ax.scatter(coords[:, 0], coords[:, 1], s=1, alpha=0.4,
                       color=PALETTE["accent"], rasterized=True)
    else:
        idx_g = self.umap_glycan_combo.currentIndex()
        # currentIndex() is -1 for an empty combo; do not wrap around to the last peak.
        if 0 <= idx_g < len(self.peaks):
            pk = self.peaks[idx_g]
            col_idx = self._get_peak_index(pk)
            if col_idx is not None:
                # Slice the single column first so a sparse X is never densified whole.
                column = adata.X[:, col_idx]
                if hasattr(column, "toarray"):
                    column = column.toarray()
                v = np.asarray(column, dtype=np.float32).ravel()
                valid = v[np.isfinite(v)]
                vmin = float(np.percentile(valid, 1)) if len(valid) else 0.0
                vmax = float(np.percentile(valid, 99)) if len(valid) else 1.0
                # Honour the colormap combo, which stays visible in glycan mode.
                cmap_name = self.umap_cmap_combo.currentText() or "inferno"
                sc = ax.scatter(coords[:, 0], coords[:, 1],
                                c=v, cmap=cmap_name, s=1, alpha=0.6,
                                vmin=vmin, vmax=vmax, rasterized=True)
                self.umap_canvas.fig.colorbar(sc, ax=ax, shrink=0.8, pad=0.01)
                ax.set_title(
                    f"UMAP — {self.glycan_names[idx_g].split('(')[0][:20]}",
                    fontsize=8.5
                )

    ax.set_xlabel("UMAP 1", fontsize=8)
    ax.set_ylabel("UMAP 2", fontsize=8)
    self.umap_canvas.fig.tight_layout(pad=0.5)
    self.umap_canvas.draw()

# ── Heatmap tab ───────────────────────────────────────────────────────
def _build_heatmap_tab(self) -> QWidget:
    """Create the heatmap tab: grouping, top-N, normalisation and glycan selection."""
    w = QWidget()
    layout = QVBoxLayout(w)
    layout.setContentsMargins(4, 4, 4, 4)

    ctrl_grp = QGroupBox("Options")
    ctrl_layout = QVBoxLayout(ctrl_grp)

    r1 = QHBoxLayout()
    r1.addWidget(QLabel("Group by:"))
    self.hmap_group_col = QComboBox()
    self._populate_obs_categoricals(self.hmap_group_col)
    r1.addWidget(self.hmap_group_col)
    ctrl_layout.addLayout(r1)

    r2 = QHBoxLayout()
    r2.addWidget(QLabel("Top N peaks:"))
    self.hmap_topn = QSpinBox()
    self.hmap_topn.setRange(5, 200)
    self.hmap_topn.setValue(30)
    r2.addWidget(self.hmap_topn)
    ctrl_layout.addLayout(r2)

    r3 = QHBoxLayout()
    r3.addWidget(QLabel("Normalise:"))
    self.hmap_norm_combo = QComboBox()
    self.hmap_norm_combo.addItems(["z-score", "min-max", "none"])
    r3.addWidget(self.hmap_norm_combo)
    ctrl_layout.addLayout(r3)

    glycan_row = QHBoxLayout()
    self.hmap_glycan_lbl = QLabel("Glycans: top N (default)")
    self.hmap_glycan_lbl.setStyleSheet(f"color: {PALETTE['text_dim']}; font-size: 9px;")
    self.hmap_glycan_lbl.setWordWrap(True)
    glycan_row.addWidget(self.hmap_glycan_lbl, stretch=1)
    glycan_btn = QPushButton("Select Glycans…")
    glycan_btn.clicked.connect(self._open_glycan_selection)
    glycan_row.addWidget(glycan_btn)
    ctrl_layout.addLayout(glycan_row)

    clear_btn = QPushButton("Clear Selection")
    clear_btn.clicked.connect(self._clear_glycan_selection)
    ctrl_layout.addWidget(clear_btn)
    layout.addWidget(ctrl_grp)

    plot_btn = QPushButton("Plot Heatmap")
    plot_btn.clicked.connect(self._draw_heatmap)
    layout.addWidget(plot_btn)

    self.hmap_canvas = MplCanvas(w, width=4.5, height=5.5, dpi=90)
    layout.addWidget(self.hmap_canvas)
    return w

def _open_glycan_selection(self):
    """Let the user pick explicit glycans for the heatmap and summarise the choice."""
    dlg = GlycanSelectionDialog(
        self.glycan_names, self.peaks, self.label_to_mz, parent=self,
    )
    if dlg.exec_() == QDialog.Accepted:
        chosen = dlg.selected_indices
        if not chosen:
            # An empty selection makes _draw_heatmap fall back to top N, so
            # show that state instead of a misleading "0 selected ()".
            self._clear_glycan_selection()
            return
        self.hmap_custom_indices = chosen
        n = len(self.hmap_custom_indices)
        preview = ", ".join(
            self.glycan_names[i].split("(")[0].strip()[:16]
            for i in self.hmap_custom_indices[:4]
        )
        suffix = f" … +{n - 4} more" if n > 4 else ""
        self.hmap_glycan_lbl.setText(f"Glycans: {n} selected ({preview}{suffix})")
        self.hmap_glycan_lbl.setStyleSheet(f"color: {PALETTE['text']}; font-size: 9px;")

def _clear_glycan_selection(self):
    """Drop the custom glycan selection and return to the top-N default."""
    self.hmap_custom_indices = None
    self.hmap_glycan_lbl.setText("Glycans: top N (default)")
    self.hmap_glycan_lbl.setStyleSheet(f"color: {PALETTE['text_dim']}; font-size: 9px;")

def _draw_heatmap(self):
    """Plot per-group mean intensities for the selected or most variable peaks."""
    adata = self._adata()
    group_col = self.hmap_group_col.currentText()
    top_n = self.hmap_topn.value()
    norm = self.hmap_norm_combo.currentText()

    if group_col == "(none)" or group_col not in adata.obs.columns:
        show_info("Select a valid grouping column first.")
        return

    X = adata.X
    # Sparse matrices expose toarray(); np.asarray alone cannot densify them.
    if hasattr(X, "toarray"):
        X = X.toarray()
    X = np.asarray(X, dtype=np.float32)
    labels = adata.obs[group_col].astype(str).values
    groups = sorted(set(labels))
    means = np.array([X[labels == g].mean(axis=0) for g in groups])

    if self.hmap_custom_indices:
        col_idx = self._resolve_heatmap_column_indices(adata)
        if not col_idx:
            return
        top_idx = np.array(col_idx, dtype=int)
    else:
        # Rank columns by between-group variance of the means, highest first.
        top_idx = np.argsort(means.var(axis=0))[::-1][:top_n]

    means_top = means[:, top_idx]
    if norm == "z-score":
        from scipy.stats import zscore
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            means_top = np.nan_to_num(zscore(means_top, axis=0))
    elif norm == "min-max":
        mn = means_top.min(axis=0, keepdims=True)
        mx = means_top.max(axis=0, keepdims=True)
        with np.errstate(divide="ignore", invalid="ignore"):
            means_top = np.where(mx > mn, (means_top - mn) / (mx - mn), 0.0)

    disp = _resolve_var_display_labels(adata)
    col_labels = [disp[i][:12] for i in top_idx]

    self.hmap_canvas.fig.clear()
    ax = self.hmap_canvas.fig.add_subplot(111)
    self.hmap_canvas._style_ax(ax)
    im = ax.imshow(means_top, aspect="auto", cmap="RdBu_r", interpolation="nearest")
    self.hmap_canvas.fig.colorbar(im, ax=ax, shrink=0.8, pad=0.01,
                                  label=norm if norm != "none" else "intensity")
    ax.set_yticks(range(len(groups)))
    ax.set_yticklabels(groups, fontsize=7)
    ax.set_xticks(range(len(col_labels)))
    ax.set_xticklabels(col_labels, rotation=90, fontsize=5.5, ha="center")
    ax.set_title(f"Mean intensity by {group_col}", fontsize=8.5)
    self.hmap_canvas.fig.tight_layout(pad=0.5)
    self.hmap_canvas.draw()

def _resolve_heatmap_column_indices(self, adata) -> list[int]:
    """Map selected peak indices to adata column indices."""
    col_indices: list[int] = []
    for peak_i in self.hmap_custom_indices or []:
        if peak_i < 0 or peak_i >= len(self.peaks):
            continue
        col_idx = self._get_peak_index(self.peaks[peak_i])
        if col_idx is not None:
            col_indices.append(col_idx)
    if not col_indices:
        show_info("Selected glycans could not be matched to data columns.")
    return col_indices

# ── Annotations tab ──────────────────────────────────────────────────
def _build_annotations_tab(self) -> QWidget:
    """Create the annotations tab: editable table plus the two-step "add annotation" flow."""
    w = QWidget()
    layout = QVBoxLayout(w)
    layout.setContentsMargins(4, 4, 4, 4)
    layout.setSpacing(6)

    info_lbl = QLabel("Edit annotation labels and colors")
    info_lbl.setStyleSheet(f"color: {PALETTE['text_dim']}; font-size: 9px;")
    layout.addWidget(info_lbl)

    # Create table with columns: ID (hidden), Classification, Colour
    self.annotations_table = QTableWidget()
    self.annotations_table.setColumnCount(3)
    self.annotations_table.setHorizontalHeaderLabels(["ID", "Classification", "Colour"])
    self.annotations_table.horizontalHeader().setStretchLastSection(False)
    self.annotations_table.setColumnHidden(0, True)
    self.annotations_table.setColumnWidth(1, 150)
    self.annotations_table.setColumnWidth(2, 80)
    layout.addWidget(self.annotations_table)

    # ── Outer container: New Annotation ────────────────────────────────
    new_annotation_label = QLabel("Add New Annotation")
    new_annotation_label.setStyleSheet(
        f"color: {PALETTE['text']}; font-weight: bold; font-size: 13px; margin-top: 6px;"
    )
    new_annotation_group = QWidget()
    new_annotation_layout = QVBoxLayout(new_annotation_group)
    new_annotation_layout.setContentsMargins(0, 0, 0, 0)
    new_annotation_layout.setSpacing(4)

    # ── Step 1: Create new annotation ─────────────────────────────────
    step1_group = QGroupBox("1. Assign new annotation label and colour")
    step1_group.setStyleSheet(
        f"QGroupBox {{ background-color: {PALETTE['surface']}; border: 2px solid {PALETTE['accent']}; "
        f"border-radius: 6px; margin-top: 8px; padding-top: 8px; }}"
        f"QGroupBox::title {{ subcontrol-origin: margin; left: 8px; padding: 0 4px; "
        f"color: {PALETTE['accent']}; font-weight: bold; }}"
    )
    step1_layout = QVBoxLayout(step1_group)
    step1_layout.setContentsMargins(4, 4, 4, 4)
    step1_layout.setSpacing(2)
    step1_group.adjustSize()
    step1_group.setSizePolicy(
        QSizePolicy.Preferred,
        QSizePolicy.Maximum
    )

    add_row_layout = QHBoxLayout()
    add_row_layout.addWidget(QLabel("Label:"))
    self.new_annotation_name = QPlainTextEdit()
    self.new_annotation_name.setPlaceholderText("e.g. Tumor region")
    self.new_annotation_name.setMaximumHeight(28)
    add_row_layout.addWidget(self.new_annotation_name, stretch=1)

    self._new_annotation_colour = [255, 165, 0]  # default orange
    new_colour_btn = QPushButton()
    new_colour_btn.setMaximumWidth(40)
    new_colour_btn.setStyleSheet(
        f"background-color: rgb(255,165,0); border: 1px solid {PALETTE['border']}; border-radius: 3px;"
    )
    new_colour_btn.clicked.connect(self._on_pick_new_annotation_colour)
    self._new_annotation_colour_btn = new_colour_btn
    add_row_layout.addWidget(new_colour_btn)
    step1_layout.addLayout(add_row_layout)

    add_btn = QPushButton("Construct Annotation")
    add_btn.setStyleSheet(
        f"QPushButton {{ background-color: {PALETTE['accent']}; color: {PALETTE['bg']}; "
        f"font-weight: bold; padding: 6px; border-radius: 4px; }}"
        f"QPushButton:hover {{ background-color: {PALETTE['highlight']}; }}"
    )
    add_btn.clicked.connect(self._on_add_annotation)
    step1_layout.addWidget(add_btn)
    new_annotation_layout.addWidget(step1_group)

    # ── Step 2: Draw annotation ───────────────────────────────────────
    step2_group = QGroupBox("2. Draw your Annotation (using the rectangle or polygon tool)")
    step2_group.setStyleSheet(
        f"QGroupBox {{ background-color: {PALETTE['surface']}; border: 2px solid {PALETTE['success']}; "
        f"border-radius: 6px; margin-top: 8px; padding-top: 8px; }}"
        f"QGroupBox::title {{ subcontrol-origin: margin; left: 8px; padding: 0 4px; "
        f"color: {PALETTE['success']}; font-weight: bold; }}"
    )
    step2_layout = QVBoxLayout(step2_group)
    step2_layout.setContentsMargins(8, 12, 8, 8)
    step2_layout.setSpacing(4)

    step2_hint = QLabel(
        "After clicking 'Construct Annotation' above, draw the region on the 'Annotations' "
        "layer in napari using the shape tools, then click 'Finish Annotation' below."
    )
    step2_hint.setStyleSheet(f"color: {PALETTE['text_dim']}; font-size: 10px;")
    step2_hint.setWordWrap(True)
    step2_layout.addWidget(step2_hint)

    register_btn = QPushButton("Finish Annotation")
    register_btn.setStyleSheet(
        f"QPushButton {{ background-color: {PALETTE['success']}; color: {PALETTE['bg']}; "
        f"font-weight: bold; padding: 6px; border-radius: 4px; }}"
        f"QPushButton:hover {{ background-color: {PALETTE['highlight']}; }}"
    )
    register_btn.clicked.connect(self._on_register_annotation)
    step2_layout.addWidget(register_btn)

    # Store references for state management
    self._step1_group = step1_group
    self._step2_group = step2_group
    self._step2_group.setEnabled(False)  # Disabled by default

    new_annotation_layout.addWidget(step2_group)
    layout.addWidget(new_annotation_label)
    layout.addWidget(new_annotation_group)

    # ── Save and reload buttons ───────────────────────────────────────
    button_layout = QHBoxLayout()
    save_btn = QPushButton("Save Changes")
    save_btn.clicked.connect(self._on_annotations_save)
    reload_btn = QPushButton("Reload")
    reload_btn.clicked.connect(self._on_annotations_reload)
    button_layout.addWidget(save_btn)
    button_layout.addWidget(reload_btn)
    layout.addLayout(button_layout)

    # Load annotations on tab init
    QTimer.singleShot(100, self._on_annotations_reload)
    return w

def _on_annotations_reload(self):
    """Load annotations from sdata.shapes['Annotations'] into the table.

    Classification/colour/id are read from the tabular columns when present,
    otherwise from the element's ``attrs`` dict. Any failure is reported via
    ``show_info`` rather than raised.
    """

    def _to_rgb(c):
        # Accept hex strings or [r,g,b] sequences. Sequences read back from
        # storage may arrive as numpy arrays, so those are accepted as well.
        # Anything else (None, NaN, ...) falls back to grey.
        if isinstance(c, str):
            return _hex_to_rgb(c)
        if isinstance(c, (list, tuple, np.ndarray)):
            return [int(c[i]) if len(c) > i else 128 for i in range(3)]
        return [128, 128, 128]

    def _reset_empty():
        # Clear the table and lock step 2 whenever nothing can be loaded.
        self.annotations_table.setRowCount(0)
        if hasattr(self, "_step2_group"):
            self._step2_group.setEnabled(False)

    try:
        if "Annotations" not in self.sdata.shapes:
            show_info("No Annotations layer found in SpatialData.")
            _reset_empty()
            return

        annotations = self.sdata.shapes["Annotations"]
        data = _shape_to_dataframe(annotations)
        if data is None:
            show_info("Annotations layer is not accessible as tabular data.")
            _reset_empty()
            return

        if "classification" in data.columns and "colour" in data.columns:
            classifications = data["classification"].astype(str).tolist()
            colours = [_to_rgb(c) for c in data["colour"].tolist()]
            ids = (
                data["id"].astype(str).tolist()
                if "id" in data.columns
                else [str(i) for i in range(len(classifications))]
            )
        else:
            attrs = getattr(annotations, "attrs", None)
            if not isinstance(attrs, dict) or "classification" not in attrs or "colour" not in attrs:
                show_info("Annotations missing 'classification' or 'colour' metadata.")
                _reset_empty()
                return
            classifications = attrs.get("classification", [])
            colours = [_to_rgb(c) for c in attrs.get("colour", [])]
            ids = attrs.get("id", [str(i) for i in range(len(classifications))])

        self.annotations_table.setRowCount(len(classifications))
        self._annotations_data = {
            "ids": ids,
            "classifications": list(classifications),
            "colours": colours,
        }

        for row, (ann_id, label, colour) in enumerate(
            zip(ids, classifications, colours)
        ):
            # ID column (hidden)
            id_item = QTableWidgetItem(str(ann_id))
            id_item.setFlags(id_item.flags() & ~Qt.ItemIsEditable)
            self.annotations_table.setItem(row, 0, id_item)

            # Classification column (editable)
            label_item = QTableWidgetItem(str(label))
            self.annotations_table.setItem(row, 1, label_item)

            # Colour column with color picker button
            color_widget = QWidget()
            color_layout = QHBoxLayout(color_widget)
            color_layout.setContentsMargins(2, 2, 2, 2)
            color_btn = QPushButton()
            color_btn.setMaximumWidth(60)
            rgb = colour[:3] if len(colour) >= 3 else [128, 128, 128]
            color_btn.setStyleSheet(
                f"background-color: rgb({int(rgb[0])}, {int(rgb[1])}, {int(rgb[2])}); "
                f"border: 1px solid {PALETTE['border']}; border-radius: 3px;"
            )
            # Bind the row as a default so each button opens its own picker.
            color_btn.clicked.connect(lambda checked, r=row: self._on_color_picker(r))
            color_layout.addWidget(color_btn)
            color_layout.addStretch()
            self.annotations_table.setCellWidget(row, 2, color_widget)

        # Reset step 2 to disabled when reloading
        if hasattr(self, "_step2_group"):
            self._step2_group.setEnabled(False)
    except Exception as e:
        show_info(f"Error loading annotations: {e}")

def _on_color_picker(self, row: int):
    """Open a color picker for a specific annotation."""
    if not hasattr(self, "_annotations_data"):
        return
    colours = self._annotations_data["colours"]
    if not 0 <= row < len(colours):
        return
    current_colour = colours[row]
    rgb = current_colour[:3] if len(current_colour) >= 3 else [0, 0, 0]
    color = QColorDialog.getColor(
        QColor(int(rgb[0]), int(rgb[1]), int(rgb[2])),
        self,
        "Choose annotation colour",
    )
    if color.isValid():
        r, g, b, _ = color.getRgb()
        colours[row] = [r, g, b]
        # Update button color
        color_widget = self.annotations_table.cellWidget(row, 2)
        if color_widget is not None:
            color_btn = color_widget.findChild(QPushButton)
            if color_btn is not None:
                # Same style as at creation time so the swatch keeps its rounded corners.
                color_btn.setStyleSheet(
                    f"background-color: rgb({r}, {g}, {b}); "
                    f"border: 1px solid {PALETTE['border']}; border-radius: 3px;"
                )

def _on_annotations_save(self):
    """Save annotation changes (classification + colour) back to sdata.

    Colours are stored as [r,g,b] integer lists and classification is
    preserved as a categorical column where possible. Any failure is
    reported via ``show_info`` rather than raised.
    """

    def _pending_geoms(start, count):
        # The pending frame is reset to None after every successful save, so
        # test the value rather than mere attribute presence.
        pending = getattr(self, "_pending_annotation_base_df", None)
        if pending is None:
            return []
        return pending["geometry"].iloc[start:start + count].tolist()

    try:
        self._collect_pending_geometries()
        if "Annotations" not in self.sdata.shapes:
            show_info("No Annotations layer to save.")
            return
        if not hasattr(self, "_annotations_data"):
            show_info("No annotation data loaded.")
            return

        annotations = self.sdata.shapes["Annotations"]
        pending_df = getattr(self, "_pending_annotation_base_df", None)
        if pending_df is not None:
            data = pending_df.copy()
        else:
            data = _shape_to_dataframe(annotations)

        n_rows = self.annotations_table.rowCount()
        known_ids = self._annotations_data.get("ids", [])
        updated_ids = [
            known_ids[row] if row < len(known_ids) else str(uuid.uuid4())
            for row in range(n_rows)
        ]

        updated_classifications = []
        updated_colours_rgb = []
        for row in range(n_rows):
            item = self.annotations_table.item(row, 1)
            if item is not None:
                updated_classifications.append(item.text())
            elif row < len(self._annotations_data["classifications"]):
                updated_classifications.append(self._annotations_data["classifications"][row])
            else:
                updated_classifications.append("new_annotation")

            if row < len(self._annotations_data["colours"]):
                rgb = self._annotations_data["colours"][row]
            else:
                rgb = [128, 128, 128]
            updated_colours_rgb.append([int(rgb[0]), int(rgb[1]), int(rgb[2])])

        if data is not None:
            if len(data) != n_rows:
                if len(data) < n_rows:
                    n_new = n_rows - len(data)
                    new_geoms = _pending_geoms(len(data), n_new)
                    if len(new_geoms) < n_new or any(g is None for g in new_geoms):
                        # Retry after pulling freshly drawn shapes from the live layer.
                        self._collect_pending_geometries()
                        new_geoms = _pending_geoms(len(data), n_new)
                    if len(new_geoms) < n_new:
                        show_info(
                            "Some new annotation rows have no drawn shape yet. "
                            "Draw the new shapes in napari, then save again."
                        )
                        new_geoms = new_geoms + [None] * (n_new - len(new_geoms))
                    import geopandas as gpd
                    extra = gpd.GeoDataFrame(
                        {"geometry": new_geoms},
                        crs=getattr(data, "crs", None)
                    )
                    data = pd.concat([data, extra], ignore_index=True)
                else:
                    data = data.iloc[:n_rows].reset_index(drop=True)

            data = data.copy()
            # Ordered, de-duplicated labels; existing categories are kept so
            # currently unused ones survive the round trip.
            categories = list(dict.fromkeys(updated_classifications))
            if "classification" in data.columns and isinstance(
                data["classification"].dtype, pd.CategoricalDtype
            ):
                for label in data["classification"].cat.categories:
                    if label not in categories:
                        categories.append(label)

            data["classification"] = pd.Categorical(
                updated_classifications,
                categories=categories,
            )
            data["colour"] = [list(rgb) for rgb in updated_colours_rgb]
            data["id"] = updated_ids

            if "objectType" in data.columns:
                obj_types = list(data["objectType"].astype(str).tolist())
                obj_types = obj_types[:len(data)] + ["annotation"] * max(0, n_rows - len(obj_types))
            else:
                obj_types = ["annotation"] * n_rows
            data["objectType"] = obj_types

            try:
                self.sdata.shapes["Annotations"] = data
            except Exception:
                # Element could not be replaced wholesale; fall back to
                # writing the columns onto the existing element.
                try:
                    annotations["classification"] = pd.Categorical(
                        updated_classifications,
                        categories=categories,
                    )
                    annotations["colour"] = [list(rgb) for rgb in updated_colours_rgb]
                    annotations["id"] = updated_ids
                    annotations["objectType"] = obj_types
                except Exception as e:
                    show_info(f"Could not write back annotations table: {e}")
                    return
        else:
            if not hasattr(annotations, "attrs"):
                annotations.attrs = {}
            annotations.attrs["classification"] = updated_classifications
            annotations.attrs["colour"] = [list(rgb) for rgb in updated_colours_rgb]
            annotations.attrs["id"] = updated_ids
            annotations.attrs["objectType"] = ["annotation"] * n_rows

        self._annotations_data["ids"] = updated_ids
        self._annotations_data["classifications"] = updated_classifications
        self._annotations_data["colours"] = updated_colours_rgb

        from napari.layers import Shapes
        for layer in self.viewer.layers:
            if isinstance(layer, Shapes) and layer.name == "Annotations":
                if updated_colours_rgb:
                    _apply_direct_shapes_colors(layer, updated_colours_rgb)
                break

        show_info("Annotations saved successfully.")
        self._pending_annotation_base_df = None
        self._pending_annotation_draw_start = None
        self._pending_annotation_original_count = None
        self._pending_annotations_registered = False

        if hasattr(self, "meta_shapes_combo") and self.meta_shapes_combo.currentText() == "Annotations":
            self._populate_meta_columns(self.meta_col_combo, "Annotations")
    except Exception as e:
        show_info(f"Error saving annotations: {e}")

def _on_pick_new_annotation_colour(self):
    """Pick the colour for the annotation about to be created."""
    rgb = self._new_annotation_colour
    color = QColorDialog.getColor(
        QColor(int(rgb[0]), int(rgb[1]), int(rgb[2])),
        self,
        "Choose colour for new annotation",
    )
    if color.isValid():
        r, g, b, _ = color.getRgb()
        self._new_annotation_colour = [r, g, b]
        # Same style as at creation time so the swatch keeps its rounded corners.
        self._new_annotation_colour_btn.setStyleSheet(
            f"background-color: rgb({r}, {g}, {b}); "
            f"border: 1px solid {PALETTE['border']}; border-radius: 3px;"
        )

def _on_add_annotation(self):
    """
    Append a pending annotation row with label and colour, and reserve a
    placeholder in the stored annotations GeoDataFrame. The actual geometry
    is registered later from the live napari layer.
    """
    if not hasattr(self, "_annotations_data"):
        self._annotations_data = {"ids": [], "classifications": [], "colours": []}

    label = self.new_annotation_name.toPlainText().strip()
    if not label:
        show_info("Enter a label for the new annotation.")
        return

    rgb = list(self._new_annotation_colour)
    layer = _find_shapes_layer(self.viewer, "Annotations")
    layer_count = len(layer.data) if layer is not None else 0
    # Remember where newly drawn shapes will start, once per pending batch.
    if getattr(self, "_pending_annotation_draw_start", None) is None:
        self._pending_annotation_draw_start = layer_count

    annotations = self.sdata.shapes.get("Annotations", None)
    data = _shape_to_dataframe(annotations)
    if data is None:
        try:
            import geopandas as gpd
            data = gpd.GeoDataFrame(
                {
                    "id": pd.Series(dtype=str),
                    "objectType": pd.Series(dtype=str),
                    "classification": pd.Series(dtype=str),
                    "colour": pd.Series(dtype=object),
                },
                geometry=pd.Series(dtype="geometry"),
            )
        except Exception:
            data = pd.DataFrame(
                columns=["id", "objectType", "classification", "colour", "geometry"]
            )

    if getattr(self, "_pending_annotation_original_count", None) is None:
        self._pending_annotation_original_count = len(data)

    columns = list(data.columns)
    for required in ["id", "objectType", "classification", "colour", "geometry"]:
        if required not in columns:
            columns.append(required)

    new_row = {col: None for col in columns}
    new_id = str(uuid.uuid4())
    new_row.update({
        "id": new_id,
        "objectType": "annotation",
        "classification": label,
        "colour": [int(rgb[0]), int(rgb[1]), int(rgb[2])],
        "geometry": None,
    })
    new_row_df = pd.DataFrame([new_row], columns=columns)
    if hasattr(data, "geometry") and "geometry" in data.columns:
        try:
            import geopandas as gpd
            new_row_df = gpd.GeoDataFrame(new_row_df, geometry="geometry", crs=getattr(data, "crs", None))
        except Exception:
            # Best effort: keep the plain DataFrame row if geopandas is unusable.
            pass

    self._pending_annotation_base_df = pd.concat([data, new_row_df], ignore_index=True)

    new_row_index = self.annotations_table.rowCount()
    self.annotations_table.insertRow(new_row_index)
    self._annotations_data["ids"].append(new_id)
    self._annotations_data["classifications"].append(label)
    self._annotations_data["colours"].append(rgb)

    id_item = QTableWidgetItem(new_id)
    id_item.setFlags(id_item.flags() & ~Qt.ItemIsEditable)
    self.annotations_table.setItem(new_row_index, 0, id_item)

    label_item = QTableWidgetItem(label)
    self.annotations_table.setItem(new_row_index, 1, label_item)

    color_widget = QWidget()
    color_layout = QHBoxLayout(color_widget)
    color_layout.setContentsMargins(2, 2, 2, 2)
    color_btn = QPushButton()
    color_btn.setMaximumWidth(60)
    color_btn.setStyleSheet(
        f"background-color: rgb({rgb[0]}, {rgb[1]}, {rgb[2]}); "
        f"border: 1px solid {PALETTE['border']}; border-radius: 3px;"
    )
    color_btn.clicked.connect(lambda checked, r=new_row_index: self._on_color_picker(r))
    color_layout.addWidget(color_btn)
    color_layout.addStretch()
    self.annotations_table.setCellWidget(new_row_index, 2, color_widget)

    self.new_annotation_name.clear()

    # Enable Step 2 for drawing
    if hasattr(self, "_step2_group"):
        self._step2_group.setEnabled(True)

    show_info(
        f"Added '{label}'. Now draw the corresponding shape on the "
        "'Annotations' layer in napari using the shape tools."
    )
def save_to_sdata(
    self,
    layers: list[Layer] | None = None,
    spatial_element_name: str | None = None,
    table_name: str | None = None,
    table_columns: list[str] | None = None,
    overwrite: bool = False,
) -> None:
    """
    Save exactly one napari layer into the SpatialData object.

    The layer is taken from ``layers`` when given, otherwise from the
    viewer's current selection. A layer that carries no ``"sdata"`` metadata
    yet is linked first, which only works when all layers in the viewer
    refer to one and the same SpatialData object.

    Notes
    -----
    Usage:

        - this function can be invoked by pressing Shift+E;
        - exactly one layer must be selected;
        - when several SpatialData objects are shown, link the layer to one
          of them first (select both layers and press Shift+L);
        - images and labels are currently not supported;
        - updating existing elements is currently not supported.

    Raises
    ------
    ValueError
        If not exactly one layer is given/selected, if the layer cannot be
        linked unambiguously, or if the layer type cannot be saved.
    NotImplementedError
        For image and labels layers.
    """
    candidates = layers or self.viewer.layers.selection
    if len(candidates) != 1:
        raise ValueError("Only one layer can be saved at a time.")
    selected = next(iter(candidates))

    if "sdata" not in selected.metadata:
        # Collect the SpatialData object of every layer that is already linked.
        linked = [
            lyr.metadata["sdata"]
            for lyr in self.viewer.layers
            if "sdata" in lyr.metadata
        ]
        if not linked:
            raise ValueError(
                "No SpatialData layers found in the viewer. Layer cannot be linked to SpatialData object."
            )
        reference = linked[0]
        if any(other is not reference for other in linked[1:]):
            raise ValueError(
                "Multiple different spatialdata object found in the viewer. Please link the layer to "
                "one of them by selecting both the layer to save and the layer containing the SpatialData object "
                "and then pressing Shift+L. Then select the layer to save and press Shift+E again."
            )
        # Only one SpatialData object is present, so link to it.
        self._inherit_metadata(self.viewer)
    assert selected.metadata["sdata"]

    # The layer is linked now; dispatch on its type.
    if isinstance(selected, Points):
        parsed, cs = self._save_points_to_sdata(selected, spatial_element_name, overwrite)
    elif isinstance(selected, Shapes):
        parsed, cs = self._save_shapes_to_sdata(selected, spatial_element_name, overwrite)
        if table_name:
            self._save_table_to_sdata(selected, table_name, spatial_element_name, table_columns, overwrite)
    elif isinstance(selected, (Image, Labels)):
        raise NotImplementedError
    else:
        raise ValueError(f"Layer of type {type(selected)} cannot be saved.")

    # Bookkeeping so later edits of the saved layer are tracked.
    self.layer_names.add(selected.name)
    self._layer_event_caches[selected.name] = []
    self._update_metadata(selected, parsed)
    selected.events.data.connect(self._update_cache_indices)
    selected.events.name.connect(self._validate_name)
    self.layer_saved.emit(cs)
    show_info("Layer saved")
def _store_pending_geometries(self, layer, start, original_count, total_new):
    """Convert pending shapes to polygons and write them into the pending table.

    Reads ``total_new`` shapes from ``layer.data`` beginning at ``start``.
    A shape whose coordinates are not an (N, 2) array, or whose conversion
    raises, is stored as ``None``.

    Returns ``True`` when the layer supplied all ``total_new`` shapes and they
    were stored, ``False`` (nothing written) when fewer shapes were available.
    """
    from shapely.geometry import Polygon

    geoms = []
    for coords in layer.data[start:start + total_new]:
        try:
            coords_arr = np.asarray(coords, dtype=float)
            if coords_arr.ndim == 2 and coords_arr.shape[1] == 2:
                # Column order is reversed before building the polygon
                # (presumably napari (y, x) -> (x, y); confirm).
                geoms.append(Polygon(coords_arr[:, ::-1]))
            else:
                geoms.append(None)
        except Exception:
            # Best effort: an unconvertible shape must not abort the others.
            geoms.append(None)

    if len(geoms) != total_new:
        return False

    base_df = self._pending_annotation_base_df
    if "geometry" not in base_df.columns:
        base_df["geometry"] = None
    # NOTE(review): label-based slice; assumes the pending table has a default
    # integer index so that labels coincide with positions.
    base_df.loc[original_count:, "geometry"] = geoms
    self._pending_annotations_registered = True
    return True


def _on_register_annotation(self):
    """Register newly drawn annotation shapes and try to save them to SpatialData.

    Each precondition failure (no pending table, no 'Annotations' layer, no
    draw-start index, no new rows) is reported through ``show_info`` and ends
    the call. After the geometries are stored, the layer is selected and saved
    through the napari-spatialdata viewer when one can be found; a failure of
    that save is reported but does not undo the registration.
    """
    if not hasattr(self, "_pending_annotation_base_df"):
        show_info("No pending annotations to register. Add an annotation first.")
        return

    layer = _find_shapes_layer(self.viewer, "Annotations")
    if layer is None:
        show_info("Annotations layer not found in napari.")
        return

    start = getattr(self, "_pending_annotation_draw_start", None)
    if start is None:
        show_info("No reference draw state found. Add an annotation first.")
        return

    # Resolve the original row count once so that the row arithmetic and the
    # later table write agree even when the attribute was never set.
    original_count = getattr(self, "_pending_annotation_original_count", 0)
    total_new = len(self._pending_annotation_base_df) - original_count
    if total_new <= 0:
        show_info("No pending new annotations to register.")
        return

    if not self._store_pending_geometries(layer, start, original_count, total_new):
        show_info(
            "Could not register all new geometries. Draw the new region(s) "
            "on the Annotations layer, then try Register Annotation again."
        )
        return

    sdata_viewer = self._get_sdata_viewer()
    if sdata_viewer is not None:
        try:
            # Select the layer (mirrors what the user does manually before Shift+E)
            self.viewer.layers.selection.active = layer
            self.viewer.layers.selection = {layer}
            # Ensure the layer is linked to the SpatialData object (Shift+L equivalent)
            if "sdata" not in layer.metadata:
                sdata_viewer._inherit_metadata(self.viewer)
            # Same call Shift+E triggers
            sdata_viewer.save_to_sdata([layer])
            show_info("New annotation geometry registered and saved to SpatialData.")
        except Exception as e:
            show_info(f"Geometry registered, but save_to_sdata failed: {e}")
    else:
        show_info(
            "New annotation geometry registered. Could not find the napari-spatialdata "
            "viewer widget; press Shift+E manually to save."
        )


def _get_sdata_viewer(self):
    """Return the viewer-model object reachable from the stored Interactive, or None.

    Looks up the ``"interactive"`` entry of ``_GOATPY_REFS``, then probes its
    ``_sdata_widget`` / ``_widget`` attributes for a ``_viewer_model`` or
    ``viewer_model`` attribute. Presumably this is napari-spatialdata's
    SpatialDataViewer; verify against the installed version.
    """
    interactive = _GOATPY_REFS.get("interactive")
    if interactive is None:
        return None
    for attr in ("_sdata_widget", "_widget"):
        widget = getattr(interactive, attr, None)
        if widget is not None:
            viewer_model = getattr(widget, "_viewer_model", None) or getattr(widget, "viewer_model", None)
            if viewer_model is not None:
                return viewer_model
    return None


def _collect_pending_geometries(self):
    """Pull newly-drawn annotation geometries from the live 'Annotations' layer.

    Silent counterpart of the register action: any unmet precondition simply
    ends the call without notifying the user.
    """
    if not hasattr(self, "_pending_annotation_base_df"):
        return

    layer = _find_shapes_layer(self.viewer, "Annotations")
    if layer is None:
        return

    start = getattr(self, "_pending_annotation_draw_start", None)
    if start is None:
        return

    original_count = getattr(self, "_pending_annotation_original_count", 0)
    total_new = len(self._pending_annotation_base_df) - original_count
    if total_new <= 0:
        return

    self._store_pending_geometries(layer, start, original_count, total_new)


# ── Stats tab ─────────────────────────────────────────────────────────

def _build_stats_tab(self) -> QWidget:
    """Build the statistics tab: a refresh button above a matplotlib canvas.

    The first draw is scheduled 400 ms after construction via a single-shot
    timer.
    """
    w = QWidget()
    layout = QVBoxLayout(w)
    layout.setContentsMargins(4, 4, 4, 4)

    plot_btn = QPushButton("Refresh Statistics")
    plot_btn.clicked.connect(self._draw_stats)
    layout.addWidget(plot_btn)

    self.stats_canvas = MplCanvas(w, width=4.5, height=5.5, dpi=90)
    layout.addWidget(self.stats_canvas)

    QTimer.singleShot(400, self._draw_stats)
    return w


def _draw_stats(self):
    """Redraw the 2x2 statistics figure from the current table.

    Panels: row sums of X ("TIC distribution"), count of positive entries per
    row ("Peaks / pixel"), percentage of rows with a positive entry per column
    ("Peak frequency (%)"), and a bar chart of category sizes for the first
    of ``GPCA_clusters``, ``leiden``, ``batch``, ``annotation`` found in
    ``adata.obs``.
    """
    adata = self._adata()

    X = adata.X
    if hasattr(X, "toarray"):
        # A sparse matrix is not densified by np.asarray; convert explicitly.
        X = X.toarray()
    X = np.asarray(X, dtype=np.float32)

    self.stats_canvas.fig.clear()
    axes = self.stats_canvas.fig.subplots(2, 2)
    for ax in axes.flat:
        self.stats_canvas._style_ax(ax)

    axes[0, 0].hist(X.sum(axis=1), bins=50, color=PALETTE["accent"], edgecolor="none", alpha=0.8)
    axes[0, 0].set_title("TIC distribution", fontsize=8)

    axes[0, 1].hist((X > 0).sum(axis=1), bins=40, color=PALETTE["accent2"], edgecolor="none", alpha=0.8)
    axes[0, 1].set_title("Peaks / pixel", fontsize=8)

    axes[1, 0].hist((X > 0).mean(axis=0) * 100, bins=40, color=PALETTE["success"], edgecolor="none", alpha=0.8)
    axes[1, 0].set_title("Peak frequency (%)", fontsize=8)

    cluster_drawn = False
    for col in ("GPCA_clusters", "leiden", "batch", "annotation"):
        if col in adata.obs.columns:
            cats = adata.obs[col].astype(str)
            counts = cats.value_counts().sort_index()
            cmap = plt.get_cmap("tab20", len(counts))
            axes[1, 1].bar(
                range(len(counts)),
                counts.values,
                color=[cmap(i) for i in range(len(counts))],
                edgecolor="none",
                alpha=0.85,
            )
            axes[1, 1].set_xticks(range(len(counts)))
            axes[1, 1].set_xticklabels(counts.index, rotation=45, ha="right", fontsize=6)
            axes[1, 1].set_title(f"Cluster sizes ({col})", fontsize=8)
            cluster_drawn = True
            break  # only the first matching column is plotted

    if not cluster_drawn:
        axes[1, 1].text(
            0.5, 0.5, "No cluster\ncolumn found",
            ha="center", va="center",
            transform=axes[1, 1].transAxes,
            color=PALETTE["text_dim"],
        )

    for ax in axes.flat:
        ax.tick_params(labelsize=6.5)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

    self.stats_canvas.fig.suptitle(
        f"{adata.n_obs:,} pixels × {adata.n_vars:,} peaks",
        fontsize=8.5,
        color=PALETTE["text"],
        y=1.01,
    )
    self.stats_canvas.fig.tight_layout(pad=0.5)
    self.stats_canvas.draw()
# ════════════════════════════════════════════════════════════════════════════
# Post-launch window setup
# ════════════════════════════════════════════════════════════════════════════

def _configure_viewer_window(viewer, screen_height: int = 1080):
    """Schedule a one-off layout pass on the napari main window.

    500 ms after the call the window is resized to 90 % of ``screen_height``
    at a 16:9 aspect ratio, a fixed set of dock widgets is hidden by window
    title, and the first vertical splitter that is large enough is
    re-proportioned so the lower pane(s) receive 28 % of the height each.
    """
    from qtpy.QtCore import QTimer

    def _setup():
        from qtpy.QtWidgets import QDockWidget
        from qtpy.QtWidgets import QSplitter
        from qtpy.QtCore import Qt

        main_window = viewer.window._qt_window

        # Resize: height is 90 % of the screen, width follows 16:9.
        target_h = int(screen_height * 0.90)
        target_w = int(target_h * 16 / 9)
        main_window.resize(target_w, target_h)

        # Hide the docks that are not needed, matched by window title.
        unwanted_titles = {"View (napari-spatialdata)", "slider", "colorbar"}
        for dock in main_window.findChildren(QDockWidget):
            if dock.windowTitle() in unwanted_titles:
                dock.hide()

        # Re-proportion the first suitable vertical splitter so the
        # bottom (spectrum) dock is visible.
        for splitter in main_window.findChildren(QSplitter):
            if splitter.orientation() != Qt.Vertical:
                continue
            current = splitter.sizes()
            available = sum(current)
            if available <= 100 or len(current) < 2:
                continue
            lower = int(available * 0.28)
            upper = available - lower
            splitter.setSizes([upper] + [lower] * (len(current) - 1))
            break

    QTimer.singleShot(500, _setup)


# ════════════════════════════════════════════════════════════════════════════
# 3. MAIN LAUNCH FUNCTION
# ════════════════════════════════════════════════════════════════════════════
def launch_goatpy_gui(
    sdata: SpatialData,
    peaks: Optional[list[float]] = None,
    glycan_csv: Optional[str] = None,
    table_name: str = "maldi_adata",
    viewer: Optional[napari.Viewer] = None,
    applied_tolerance: float = 0.1,
) -> napari.Viewer:
    """Open the goatpy analysis GUI for ``sdata`` and return its napari viewer.

    Parameters
    ----------
    sdata
        SpatialData object to visualise; ``sdata.tables[table_name]`` is used
        for the peak list and the ready message.
    peaks
        m/z values. When ``None`` they are resolved from the table; if that
        fails an empty list is used.
    glycan_csv
        Optional path to a CSV file of glycans. Read errors are printed and
        the GUI starts without a glycan table.
    table_name
        Key of the table inside ``sdata.tables``.
    viewer
        Accepted for backward compatibility only. The layers are loaded with
        ``viewer=None`` and the viewer of the resulting Interactive object is
        used, so a viewer passed here is ignored (a warning is emitted).
    applied_tolerance
        Forwarded to the spectrum widget.

    Returns
    -------
    napari.Viewer
        The viewer taken from the Interactive object, with the "Spectrum"
        (bottom) and "Analysis" (right) dock widgets attached.
    """
    if viewer is not None:
        # The argument used to be silently discarded; make that visible.
        warnings.warn(
            "launch_goatpy_gui: the 'viewer' argument is ignored; a viewer owned "
            "by napari-spatialdata is created and returned instead.",
            stacklevel=2,
        )

    if peaks is None:
        try:
            peaks = list(_resolve_mz_array(sdata.tables[table_name]))
        except Exception:
            # Best effort: the GUI can still start without a peak list.
            peaks = []

    glycan_df = None
    if glycan_csv:
        try:
            raw = pd.read_csv(glycan_csv)
            # str() guards against non-string headers, which would otherwise
            # discard the whole table through the except below.
            raw.columns = [str(c).strip() for c in raw.columns]
            glycan_df = raw
        except Exception as e:
            print(f"[goatpy GUI] glycan CSV error: {e}")

    # ── Load layers via napari-spatialdata (handles aligned CS correctly) ──
    interactive = _add_spatialdata_layers(
        viewer=None,  # not used anymore — Interactive owns its viewer
        sdata=sdata,
        target_cs="aligned",
    )

    # Get the actual napari Viewer from Interactive
    viewer = interactive._viewer
    viewer.title = "goatpy — Spatial Glycomics Analysis"

    # ── Keep Interactive alive (don't let it be GC'd) ─────────────────────
    _GOATPY_REFS["interactive"] = interactive

    # ── Spectrum widget ────────────────────────────────────────────────────
    spectrum_widget = SpectrumWidget(
        sdata=sdata,
        peaks=peaks,
        glycan_df=glycan_df,
        table_name=table_name,
        applied_tolerance=applied_tolerance,
    )
    viewer.window.add_dock_widget(spectrum_widget, area="bottom", name="Spectrum")

    # ── Sidebar widget ─────────────────────────────────────────────────────
    sidebar = AnalysisSidebar(
        sdata=sdata,
        peaks=peaks,
        viewer=viewer,
        glycan_df=glycan_df,
        table_name=table_name,
    )
    viewer.window.add_dock_widget(sidebar, area="right", name="Analysis")

    # ── Wiring ─────────────────────────────────────────────────────────────
    sidebar.glycan_selected.connect(
        lambda mz, lbl: spectrum_widget.highlight_glycan(mz, lbl)
    )
    spectrum_widget.peak_clicked.connect(sidebar.select_peak_from_spectrum)
    spectrum_widget.unregistered_peak_display.connect(sidebar.display_unregistered_ion_image)

    print(
        f"[goatpy GUI] Ready\n"
        f"{len(sdata.tables[table_name])} pixels · {len(peaks)} peaks"
    )

    _configure_viewer_window(viewer)
    return viewer
# ════════════════════════════════════════════════════════════════════════════
# Standalone test with dummy data
# ════════════════════════════════════════════════════════════════════════════

def _make_dummy_sdata() -> SpatialData:
    """Build a small synthetic SpatialData object for trying out the GUI.

    The object contains 400 square "pixel" shapes, their centroids as points,
    a ``maldi_adata`` table with 40 peaks annotating the shapes, and a random
    RGB image stored as ``he_image``.

    A private ``RandomState(42)`` is used instead of ``np.random.seed(42)``:
    it produces the same values while leaving the process-wide NumPy random
    state untouched for the caller.
    """
    import anndata as ad
    import geopandas as gpd
    from shapely.geometry import box
    from spatialdata.models import Image2DModel, PointsModel, ShapesModel, TableModel
    from spatialdata.transformations import Identity

    # The order of the draws below is kept fixed so the data stay reproducible.
    rng = np.random.RandomState(42)
    n, n_peaks = 400, 40

    peaks_mz = np.sort(rng.uniform(900, 2800, n_peaks))
    X = np.abs(rng.randn(n, n_peaks)).astype(np.float32)
    he_x = rng.uniform(0, 1000, n)
    he_y = rng.uniform(0, 800, n)

    obs = pd.DataFrame({
        "he_x": he_x,
        "he_y": he_y,
        "GPCA_clusters": pd.Categorical(rng.choice(["0", "1", "2", "3"], n)),
        "MPI": X.sum(axis=1),
    })

    adata = ad.AnnData(X=X, obs=obs)
    adata.var_names = [f"{m:.4f}" for m in peaks_mz]
    adata.obsm["spatial"] = obs[["he_x", "he_y"]].values
    adata.obsm["GraphPCA"] = rng.randn(n, 10).astype(np.float32)

    # One 4x4 square per pixel, centred on its coordinates.
    geoms = [box(x - 2, y - 2, x + 2, y + 2) for x, y in zip(he_x, he_y)]
    gdf = gpd.GeoDataFrame({"cell_id": np.arange(n).astype(str)}, geometry=geoms)
    shapes = ShapesModel.parse(gdf, transformations={"global": Identity()})

    pts = pd.DataFrame({"x": he_x, "y": he_y, "cell_id": np.arange(n).astype(str)})
    centroids = PointsModel.parse(pts)

    sdata = SpatialData(shapes={"pixels": shapes}, points={"centroids": centroids})

    # Link every table row to its shape through region / instance keys.
    adata.obs["instance_id"] = gdf.index
    adata.obs["region"] = pd.Categorical(["pixels"] * n)
    table = TableModel.parse(
        adata, region="pixels", region_key="region", instance_key="instance_id"
    )
    sdata["maldi_adata"] = table

    # Dummy H&E image: 3 x 800 x 1000 uint8 noise with values in [200, 255).
    he = rng.randint(200, 255, (3, 800, 1000), dtype=np.uint8)
    sdata["he_image"] = Image2DModel.parse(
        he, dims=("c", "y", "x"), transformations={"global": Identity()}
    )
    return sdata