# Source code for goatpy.auto_align

"""
auto_align.py
=============
Auto-registration of a MALDI imzML dataset against a whole-slide H&E image.

Key design decisions
--------------------
1.  Registration is performed at MALDI native pixel size (default 10 um/px)
    so the H&E thumbnail is tiny and peak RAM stays manageable.

2.  Matching uses normalised grayscale cross-correlation (TM_CCOEFF_NORMED)
    on the TIC image vs inverted H&E intensity.

3.  A two-pass rotation search (coarse 0-360, then fine) finds the correct
    slide orientation without assuming any starting angle.

4.  Registration, canvas building, and annotation transforms all use
    cv2.warpAffine with the same analytically derived affine matrix.
    This guarantees that best_idx, canvas pixels, and annotation coordinates
    all live in the same coordinate system.

Coordinate system
-----------------
All three steps share the same affine matrix M_stored:

    M_stored maps: reg-resolution H&E coords -> canvas coords

    _match_at_rotation  uses cv2.warpAffine(M_stored) to build the search canvas
    _build_affine_and_canvas uses cv2.warpAffine(M_stored) to build output canvas
    _transform_geojson  applies M_up @ M_stored @ M_scale to annotation vertices

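Because the same matrix construction is used everywhere, best_idx from
matchTemplate is valid as the MALDI placement offset in the output canvas
whenever the search canvas and the output canvas have the same placement
offsets. The search canvas is only enlarged beyond the output canvas when the
MALDI template would not otherwise fit inside it.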
"""

from __future__ import annotations

import os
import gc
import json
from functools import partial
from pathlib import Path
from typing import Optional, Union

import numpy as np
import pandas as pd
import anndata as ad
import geopandas as gpd
import cv2 as cv
from PIL import Image
from shapely.geometry import box, shape
from shapely import transform as shapely_transform

from spatialdata import SpatialData
from spatialdata.models import Image2DModel, PointsModel, ShapesModel, TableModel
from spatialdata.transformations import Identity

from .io import parmap, getimage, rd_peaks, rd_peaks_from_package
from pyimzml.ImzMLParser import ImzMLParser


# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------

def _log(msg: str) -> None:
    """
    Print a progress message to stdout.

    When psutil can be imported, the message is prefixed with the resident
    set size of the current process in GB, e.g. "[1.23GB] msg". Without
    psutil the bare message is printed.
    """
    try:
        import psutil
    except ImportError:
        # psutil is optional; fall back to the plain message.
        print(msg)
        return

    rss_gb = psutil.Process(os.getpid()).memory_info().rss / 1e9
    print(f"[{rss_gb:.2f}GB] {msg}")


# ---------------------------------------------------------------------------
# H&E loading
# ---------------------------------------------------------------------------

def _read_native_mpp(he_path: str) -> Optional[float]:
    """
    Read native microns-per-pixel from file metadata without loading pixels.

    Parameters
    ----------
    he_path : path to the H&E image, opened with openslide.

    Returns
    -------
    float or None
        Mean of the X and Y MPP properties reported by openslide, or None
        when openslide is unavailable, the file cannot be opened, or either
        property is missing/empty.
    """
    try:
        import openslide
        slide = openslide.OpenSlide(he_path)
        try:
            mpp_x = slide.properties.get(openslide.PROPERTY_NAME_MPP_X)
            mpp_y = slide.properties.get(openslide.PROPERTY_NAME_MPP_Y)
        finally:
            # Release the file handle even if reading the properties fails.
            slide.close()
        if mpp_x and mpp_y:
            return (float(mpp_x) + float(mpp_y)) / 2.0
    except Exception:
        # Deliberately best-effort: any failure means "resolution unknown"
        # and the caller falls back to other sources.
        pass
    return None


def _load_he_at_resolution(he_path: str,
                            target_mpp: float,
                            native_mpp: float) -> tuple[Image.Image, float]:
    """
    Load H&E at (or just finer than) target_mpp using openslide pyramid
    selection, then resize to exactly target_mpp.

    Parameters
    ----------
    he_path    : path to the H&E image
    target_mpp : desired resolution in um/px
    native_mpp : resolution of pyramid level 0 (or of the plain image) in um/px

    Returns
    -------
    img : PIL RGB image
    mpp : resolution of the returned image in um/px. On the openslide path
          this is the chosen level's resolution when it is within 2% of
          target_mpp (no resize is done), otherwise target_mpp.

    Raises
    ------
    RuntimeError
        If the openslide path fails for a whole-slide extension
        (.svs, .ndpi, .scn, .czi, .mrxs). For any other extension a failure
        falls through silently to a plain PIL load.
    """
    ext = os.path.splitext(he_path)[1].lower()
    wsi_exts = {'.svs', '.ndpi', '.scn', '.czi', '.mrxs'}

    try:
        import openslide
        slide = openslide.OpenSlide(he_path)
        try:
            # Keep the last (coarsest) level that is still no more than 5%
            # coarser than the target; level 0 is the fallback.
            best_level = 0
            best_mpp   = native_mpp
            for lvl in range(slide.level_count):
                lvl_mpp = native_mpp * slide.level_downsamples[lvl]
                if lvl_mpp <= target_mpp * 1.05:
                    best_level = lvl
                    best_mpp   = lvl_mpp

            dims   = slide.level_dimensions[best_level]
            region = slide.read_region((0, 0), best_level, dims)
            img    = region.convert('RGB')
        finally:
            # Always release the slide handle, including when level
            # selection or read_region raises.
            slide.close()
        _log(f"  openslide level {best_level}: {dims[0]}x{dims[1]}  "
             f"{best_mpp:.3f} um/px  ({img.width*img.height*3/1e6:.0f} MB)")

        # Only resample when the level is more than 2% away from the target.
        if abs(best_mpp - target_mpp) / target_mpp > 0.02:
            scale = best_mpp / target_mpp
            nw    = max(1, round(img.width  * scale))
            nh    = max(1, round(img.height * scale))
            img   = img.resize((nw, nh), Image.Resampling.LANCZOS)
            best_mpp = target_mpp
            _log(f"  Resized to {nw}x{nh}  {target_mpp:.3f} um/px")

        return img, best_mpp

    except Exception as e:
        if ext in wsi_exts:
            raise RuntimeError(
                f"\nFailed to open '{he_path}' with openslide: {e}\n\n"
                "SVS/NDPI files require openslide:\n"
                "  conda install -c conda-forge openslide openslide-python\n"
            ) from e
        # Non-WSI formats: fall through to the PIL loader below.

    img = Image.open(he_path).convert('RGB')
    _log(f"  PIL: {img.width}x{img.height}")
    scale = native_mpp / target_mpp
    nw    = max(1, round(img.width  * scale))
    nh    = max(1, round(img.height * scale))
    img   = img.resize((nw, nh), Image.Resampling.LANCZOS)
    _log(f"  Resized to {nw}x{nh}  {target_mpp:.3f} um/px  ({nw*nh*3/1e6:.0f} MB)")
    return img, target_mpp


# ---------------------------------------------------------------------------
# MALDI loading
# ---------------------------------------------------------------------------

def _load_spectra(imzml_path: str,
                  peaks: list[float],
                  chunk_size: int = 10,
                  crop_r: int = 0,
                  crop_c: int = 0) -> np.ndarray:
    """
    Load one ion image per peak into a single float32 cube.

    The image size is probed once with pyimzml's getionimage on the first
    peak; the remaining images are fetched in batches of chunk_size via
    parmap(getimage) using at most 4 processes per batch.

    Parameters
    ----------
    imzml_path : path to the .imzML file
    peaks      : m/z values, one output plane per entry
    chunk_size : number of ion images requested per parmap call
    crop_r     : number of leading rows dropped from every ion image
    crop_c     : number of leading columns dropped from every ion image

    Returns
    -------
    np.ndarray of shape (rows - crop_r, cols - crop_c, len(peaks)), float32
    """
    from pyimzml.ImzMLParser import getionimage

    parser = ImzMLParser(imzml_path)
    first  = getionimage(parser, peaks[0], tol=0.1, reduce_func=max)
    n_rows = first.shape[0] - crop_r
    n_cols = first.shape[1] - crop_c
    del first

    n_peaks  = len(peaks)
    cube     = np.zeros((n_rows, n_cols, n_peaks), dtype=np.float32)
    load_one = partial(getimage, path=imzml_path)

    for start in range(0, n_peaks, chunk_size):
        stop  = min(start + chunk_size, n_peaks)
        batch = peaks[start:stop]
        planes = parmap(load_one, batch, nprocs=min(len(batch), 4))
        for offset, plane in enumerate(planes):
            cube[:, :, start + offset] = plane[crop_r:, crop_c:]
        del planes
        _log(f"  Peaks {start+1}-{stop} / {n_peaks}")

    return cube


def _read_maldi_pixel_size(imzml_path: str) -> Optional[float]:
    """
    Look up the MALDI pixel size in the imzML metadata dictionary.

    The keys 'pixel size (x)', 'pixel size x' and 'pixel size' are tried in
    that order; the first one holding a non-None value is converted to float.
    Returns None when no key is present or when anything goes wrong while
    parsing (best-effort lookup).
    """
    candidate_keys = ('pixel size (x)', 'pixel size x', 'pixel size')
    try:
        metadata = ImzMLParser(imzml_path).imzmldict
        found = next(
            (metadata.get(k) for k in candidate_keys
             if metadata.get(k) is not None),
            None,
        )
        return None if found is None else float(found)
    except Exception:
        return None


def _last_leading_empty(profile: np.ndarray, cutoff: float) -> int:
    """Return the index of the last line in the leading run below cutoff (0 if none)."""
    is_empty = np.asarray(profile) < cutoff
    if is_empty.size == 0 or not is_empty[0]:
        return 0
    filled = np.flatnonzero(~is_empty)
    # Length of the contiguous run of empty lines starting at index 0.
    run_len = int(filled[0]) if filled.size else int(is_empty.size)
    return run_len - 1


def _crop_offsets(spectra_sum: np.ndarray, cutoff: float = 0.5) -> tuple[int, int]:
    """
    Find how many leading rows/columns of a summed ion image can be cropped.

    A row (column) counts as empty when its sum is below cutoff. For each
    axis the result is the index of the last line in the contiguous run of
    empty lines that starts at index 0, or 0 when the first line is not
    empty. Callers slice with img[crop_r:, crop_c:].

    The two axes are evaluated independently, so an axis without empty
    lines no longer discards the crop found on the other axis. Only the
    leading run is considered, so empty lines in the interior or at the far
    edge cannot crop away real data.

    Parameters
    ----------
    spectra_sum : 2-D array, summed intensity per pixel
    cutoff      : sums strictly below this value are treated as empty

    Returns
    -------
    (crop_r, crop_c) : row and column offsets; (0, 0) if the input cannot be
                       reduced along both axes (e.g. it is not 2-D).
    """
    try:
        col_profile = np.sum(spectra_sum, axis=0)
        row_profile = np.sum(spectra_sum, axis=1)
    except (ValueError, IndexError):
        return 0, 0
    crop_c = _last_leading_empty(col_profile, cutoff)
    crop_r = _last_leading_empty(row_profile, cutoff)
    return crop_r, crop_c


# ---------------------------------------------------------------------------
# Image preparation
# ---------------------------------------------------------------------------

def _maldi_to_grayscale(tic: np.ndarray) -> np.ndarray:
    """
    Turn a MALDI intensity image into a float32 matching template.

    The image is smoothed with a 3x3 Gaussian blur and min-max scaled; a
    constant image yields all zeros.
    """
    smoothed = cv.GaussianBlur(tic, (3, 3), 0)
    lo, hi = smoothed.min(), smoothed.max()
    if hi > lo:
        scaled = (smoothed - lo) / (hi - lo)
    else:
        # Flat image: avoid dividing by zero.
        scaled = smoothed * 0.0
    _log(f"  MALDI grayscale: {scaled.shape}  mean={scaled.mean():.3f}")
    return scaled.astype(np.float32)


def _he_to_grayscale(he_img: Image.Image) -> np.ndarray:
    """
    Convert an H&E image to an inverted, min-max scaled float32 image.

    The image is converted to 8-bit luminance and inverted (255 - value) so
    dark pixels become bright, then scaled; a constant image yields all
    zeros.
    """
    luminance = np.array(he_img.convert('L'), dtype=np.float32)
    inverted  = 255.0 - luminance
    lo, hi = inverted.min(), inverted.max()
    if hi > lo:
        scaled = (inverted - lo) / (hi - lo)
    else:
        # Flat image: avoid dividing by zero.
        scaled = inverted * 0.0
    _log(f"  H&E grayscale: {scaled.shape}  mean={scaled.mean():.3f}")
    return scaled.astype(np.float32)


# ---------------------------------------------------------------------------
# Affine matrix construction (shared by registration and canvas building)
# ---------------------------------------------------------------------------

def _build_affine_matrix(
    src_w: int,
    src_h: int,
    rotation_deg: float,
    buffer_px: int,
    min_w: int = 0,
    min_h: int = 0,
) -> tuple[np.ndarray, int, int, int, int]:
    """
    Build the forward affine (source coords -> canvas coords) for a rotated,
    padded copy of an image.

    The transform is composed of:
        1. a rotation about the image centre ((src_w-1)/2, (src_h-1)/2) using
           the matrix [[cos, -sin], [sin, cos]]. This is counter-clockwise in
           a y-up frame; in image coordinates (y pointing down) a positive
           angle therefore appears clockwise.
        2. a shift that moves the rotated bounding box to the origin
        3. a shift of (canvas_pc, canvas_pr) that centres the box in the
           canvas.

    The bounding box is measured between the outermost rotated corner pixel
    centres, so at 0 degrees its width is src_w - 1.

    Parameters
    ----------
    src_w, src_h  : source image dimensions in pixels
    rotation_deg  : rotation angle in degrees
    buffer_px     : total padding added to each canvas dimension
    min_w, min_h  : lower bounds for the canvas dimensions; callers pass the
                    template size so matchTemplate always has room

    Returns
    -------
    M_stored  : np.ndarray (3, 3) float64 -- source coords -> canvas coords
    canvas_w  : int
    canvas_h  : int
    canvas_pc : int -- column offset of the rotated box inside the canvas
    canvas_pr : int -- row offset of the rotated box inside the canvas
    """
    angle = np.deg2rad(rotation_deg)
    cos_t, sin_t = np.cos(angle), np.sin(angle)

    last_col = src_w - 1
    last_row = src_h - 1
    cx = last_col / 2.0
    cy = last_row / 2.0

    # Corner pixel centres, clockwise from the top-left.
    corner_x = np.array([0.0, last_col, last_col, 0.0], dtype=np.float64)
    corner_y = np.array([0.0, 0.0, last_row, last_row], dtype=np.float64)

    dx = corner_x - cx
    dy = corner_y - cy
    rot_x = cos_t * dx - sin_t * dy + cx
    rot_y = sin_t * dx + cos_t * dy + cy

    min_x = rot_x.min()
    min_y = rot_y.min()
    rot_w = int(np.ceil(rot_x.max() - min_x))
    rot_h = int(np.ceil(rot_y.max() - min_y))

    # Enforce the minimum canvas size, then centre the rotated box in it.
    canvas_w  = max(rot_w + buffer_px, min_w)
    canvas_h  = max(rot_h + buffer_px, min_h)
    canvas_pc = (canvas_w - rot_w) // 2
    canvas_pr = (canvas_h - rot_h) // 2

    tx = -cos_t * cx + sin_t * cy + cx - min_x + canvas_pc
    ty = -sin_t * cx - cos_t * cy + cy - min_y + canvas_pr

    M_stored = np.eye(3, dtype=np.float64)
    M_stored[0, :] = (cos_t, -sin_t, tx)
    M_stored[1, :] = (sin_t, cos_t, ty)

    return M_stored, canvas_w, canvas_h, canvas_pc, canvas_pr


# ---------------------------------------------------------------------------
# Registration — now using cv2 so coordinate system matches canvas building
# ---------------------------------------------------------------------------

def _match_at_rotation(
    he_gray: np.ndarray,
    maldi_gray: np.ndarray,
    rotation: float,
    buffer_px: int,
) -> tuple[float, tuple[int, int]]:
    """
    Score one candidate rotation by template matching.

    The H&E image is warped into a padded canvas with the matrix from
    _build_affine_matrix (the template size is passed as the minimum canvas
    size so the template always fits), then the MALDI template is matched
    against it with TM_CCOEFF_NORMED.

    Parameters
    ----------
    he_gray    : float32 grayscale H&E at registration resolution
    maldi_gray : float32 grayscale MALDI template
    rotation   : rotation angle in degrees, as understood by
                 _build_affine_matrix
    buffer_px  : canvas padding forwarded to _build_affine_matrix

    Returns
    -------
    score : float, maximum of the matchTemplate response
    loc   : (row, col) top-left corner of the best match, in the coordinates
            of the search canvas built here
    """
    template_h, template_w = maldi_gray.shape
    image_h, image_w = he_gray.shape

    affine, out_w, out_h, *_ = _build_affine_matrix(
        image_w, image_h, rotation, buffer_px,
        min_w=template_w, min_h=template_h,
    )

    # warpAffine takes the top two rows of the 3x3 forward matrix.
    search_canvas = cv.warpAffine(
        he_gray,
        affine[:2, :],
        (out_w, out_h),
        flags=cv.INTER_LINEAR,
        borderValue=0.0,
    )

    response = cv.matchTemplate(search_canvas, maldi_gray, cv.TM_CCOEFF_NORMED)
    _min_val, peak_score, _min_loc, peak_xy = cv.minMaxLoc(response)

    # minMaxLoc reports (x, y); callers expect (row, col).
    peak_col, peak_row = peak_xy
    return float(peak_score), (int(peak_row), int(peak_col))


def _register(
    he_gray: np.ndarray,
    maldi_gray: np.ndarray,
    src_w: int,
    src_h: int,
    coarse_step: int = 15,
    fine_range: float = 5.0,
    fine_step: float = 1.0,
    buffer_px: int = 150,
) -> tuple[float, tuple[int, int]]:
    """
    Find the rotation and placement that best match the MALDI template.

    Pass 1 scores every angle in range(0, 360, coarse_step). Pass 2 scores
    angles within +/- fine_range of the coarse winner in steps of fine_step
    (rounded to 0.1 degree, the winner itself excluded). The overall best
    score wins.

    Parameters
    ----------
    he_gray, maldi_gray : images forwarded to _match_at_rotation
    src_w, src_h        : accepted for interface symmetry with the canvas
                          builder; not read by this function
    coarse_step         : degrees between coarse candidates
    fine_range          : half-width of the fine search window in degrees
    fine_step           : degrees between fine candidates
    buffer_px           : canvas padding forwarded to _match_at_rotation

    Returns
    -------
    best_rot : rotation in degrees of the best match
    best_idx : (row, col) of the best match in the search canvas
    """
    coarse_angles = list(range(0, 360, coarse_step))
    _log(f"  Coarse: {len(coarse_angles)} rotations (0-360 step {coarse_step}) ...")

    top_score, top_rot, top_loc = -np.inf, 0.0, (0, 0)

    for angle in coarse_angles:
        score, loc = _match_at_rotation(he_gray, maldi_gray, angle, buffer_px)
        _log(f"    {angle:5.1f}  score={score:.4f}")
        if score > top_score:
            top_score, top_rot, top_loc = score, float(angle), loc

    _log(f"  Best coarse: {top_rot}  score={top_score:.4f}")

    # Offsets around the coarse winner; the zero offset was already scored.
    offsets = np.arange(-fine_range, fine_range + fine_step, fine_step)
    fine_angles = sorted({
        round(top_rot + off, 1) for off in offsets if abs(off) > 1e-6
    })
    _log(f"  Fine: {len(fine_angles)} rotations (+-{fine_range} step {fine_step}) ...")

    for angle in fine_angles:
        score, loc = _match_at_rotation(he_gray, maldi_gray, angle, buffer_px)
        _log(f"    {angle:5.1f}  score={score:.4f}")
        if score > top_score:
            top_score, top_rot, top_loc = score, angle, loc

    _log(f"  Final: {top_rot}  score={top_score:.4f}  offset={top_loc}")
    return top_rot, top_loc


# ---------------------------------------------------------------------------
# Build H&E output canvas
# ---------------------------------------------------------------------------

def _build_affine_and_canvas(
    he_img: Image.Image,
    src_w: int,
    src_h: int,
    rotation_deg: float,
    buffer_px: int,
    min_w: int = 0,
    min_h: int = 0,
) -> tuple[np.ndarray, np.ndarray, int, int]:
    """
    Build the H&E output canvas using cv2.warpAffine.

    The matrix comes from _build_affine_matrix. _match_at_rotation also
    passes the MALDI template size as min_w/min_h, which enlarges the search
    canvas when the template would not otherwise fit. To reproduce that
    search canvas exactly, pass the same min_w/min_h here. With the defaults
    (0, 0) the canvas is rot + buffer_px in each dimension and matches the
    search canvas only when no enlargement was needed.

    Parameters
    ----------
    he_img       : RGB H&E image at registration resolution
    src_w, src_h : dimensions of he_img in pixels
    rotation_deg : rotation angle in degrees, as understood by
                   _build_affine_matrix
    buffer_px    : canvas padding forwarded to _build_affine_matrix
    min_w, min_h : minimum canvas dimensions forwarded to
                   _build_affine_matrix (default 0 = no minimum)

    Returns
    -------
    canvas    : np.ndarray (H, W, 3) uint8
    M_stored  : np.ndarray (3, 3) float64 -- reg-res coords -> canvas coords
    canvas_pr : int
    canvas_pc : int
    """
    M_stored, canvas_w, canvas_h, canvas_pc, canvas_pr = _build_affine_matrix(
        src_w, src_h, rotation_deg, buffer_px,
        min_w=min_w, min_h=min_h,
    )

    img_np = np.array(he_img, dtype=np.uint8)
    M_cv   = M_stored[:2, :]   # warpAffine takes the 2x3 part of the matrix
    canvas = cv.warpAffine(
        img_np,
        M_cv,
        (canvas_w, canvas_h),
        flags=cv.INTER_LINEAR,
        borderValue=(0, 0, 0),
    )

    _log(f"  H&E canvas (cv2): {canvas_w}x{canvas_h}  "
         f"pr={canvas_pr}, pc={canvas_pc}  rotation={rotation_deg}")

    return canvas, M_stored, canvas_pr, canvas_pc


# ---------------------------------------------------------------------------
# Annotation transform
# ---------------------------------------------------------------------------

def _transform_geojson(
    geojson_path: Union[str, Path],
    he_pixel_um: float,
    reg_mpp: float,
    M_stored: np.ndarray,
    img_upscaling: int,
    classification_key: str = "classification",
) -> gpd.GeoDataFrame:
    """
    Transform QuPath GeoJSON annotations (native H&E pixel coords) into the
    final upscaled canvas coordinate system.

    Full pipeline:
        M_annotations = M_up @ M_stored @ M_scale

    where:
        M_scale  : native H&E pixels -> reg-resolution pixels
        M_stored : reg-res coords -> canvas coords  (from _build_affine_matrix)
        M_up     : canvas coords -> upscaled canvas coords

    Parameters
    ----------
    geojson_path       : path to the GeoJSON file. Accepted top-level forms
                         are a list of features, a FeatureCollection, or a
                         single Feature.
    he_pixel_um        : native H&E pixel size in um
    reg_mpp            : registration resolution in um/px
    M_stored           : (3, 3) reg-res -> canvas matrix
    img_upscaling      : integer factor from canvas to output coordinates
    classification_key : name of the output column holding the class label

    Returns
    -------
    GeoDataFrame with columns classification_key, "name" and the transformed
    geometry. Features without a geometry are skipped.

    Raises
    ------
    ValueError
        If the file contains no features.
    """
    # GeoJSON is specified as UTF-8; do not depend on the platform default.
    with open(geojson_path, "r", encoding="utf-8") as f:
        geojson = json.load(f)

    if isinstance(geojson, list):
        features = geojson
    elif geojson.get("type") == "Feature":
        # A single exported object arrives as one bare Feature.
        features = [geojson]
    else:
        features = geojson.get("features", [])
    if not features:
        raise ValueError(f"No features found in {geojson_path}")

    scale_to_reg = he_pixel_um / reg_mpp
    us           = float(img_upscaling)

    M_scale = np.array([
        [scale_to_reg, 0.0,          0.0],
        [0.0,          scale_to_reg, 0.0],
        [0.0,          0.0,          1.0],
    ], dtype=np.float64)

    M_up = np.array([
        [us,  0.0, 0.0],
        [0.0, us,  0.0],
        [0.0, 0.0, 1.0],
    ], dtype=np.float64)

    M = M_up @ M_stored @ M_scale

    a, b, tx = M[0, 0], M[0, 1], M[0, 2]
    d, e, ty = M[1, 0], M[1, 1], M[1, 2]

    def _apply(coords: np.ndarray) -> np.ndarray:
        # coords is an (N, 2) array of x, y vertex positions.
        x, y = coords[:, 0], coords[:, 1]
        return np.column_stack([
            a * x + b * y + tx,
            d * x + e * y + ty,
        ])

    geoms, labels, names = [], [], []
    for feat in features:
        geom_raw = feat.get("geometry")
        if geom_raw is None:
            continue
        geoms.append(shapely_transform(shape(geom_raw), _apply))
        props  = feat.get("properties") or {}
        clf    = props.get("classification") or {}
        # The classification is read as a dict with a "name" entry when it
        # is one; any other truthy value is stringified.
        labels.append(clf.get("name", "unknown") if isinstance(clf, dict) else str(clf))
        names.append(props.get("name", ""))

    return gpd.GeoDataFrame(
        {classification_key: labels, "name": names},
        geometry=geoms,
    )


# ---------------------------------------------------------------------------
# SpatialData construction
# ---------------------------------------------------------------------------

def _build_spatialdata(spectra_all: np.ndarray,
                       peaks: list[float],
                       maldi_pixel_um: float,
                       he_canvas: np.ndarray,
                       maldi_offset_in_canvas: tuple[int, int],
                       reg_mpp: float,
                       crop_r: int,
                       crop_c: int,
                       img_upscaling: int = 10,
                       library_id: str = "spatial") -> SpatialData:
    """
    Assemble the SpatialData object from the registered inputs.

    Parameters
    ----------
    spectra_all            : (rows, cols, n_peaks) ion intensities
    peaks                  : m/z values; formatted with one decimal as
                             variable names
    maldi_pixel_um         : MALDI pixel size in um
    he_canvas              : (H, W, 3) uint8 H&E canvas at reg resolution
    maldi_offset_in_canvas : (row, col) of the MALDI top-left corner in
                             canvas pixels
    reg_mpp                : canvas resolution in um/px
    crop_r, crop_c         : crop offsets added back to the per-pixel
                             "y"/"x" columns
    img_upscaling          : integer factor by which the canvas is enlarged
                             (nearest neighbour) for the output image
    library_id             : key used under adata.uns["spatial"]

    Returns
    -------
    SpatialData with images['he_image'], points['centroids'],
    shapes['pixels'] and the table 'maldi_adata'.
    """
    maldi_h, maldi_w, n_peaks = spectra_all.shape
    scale         = maldi_pixel_um / reg_mpp   # canvas px per MALDI px
    local_off_r, local_off_c = maldi_offset_in_canvas

    us      = img_upscaling
    he_up_h = he_canvas.shape[0] * us
    he_up_w = he_canvas.shape[1] * us
    he_up   = np.array(
        Image.fromarray(he_canvas).resize(
            (he_up_w, he_up_h), Image.Resampling.NEAREST
        ),
        dtype=np.uint8,
    )
    _log(f"  H&E upscaled {us}x: {he_up_w}x{he_up_h}  ({he_up.nbytes/1e6:.0f} MB)")

    # Centre of every MALDI pixel, expressed in upscaled canvas coordinates.
    grid_r, grid_c = np.mgrid[0: maldi_h, 0: maldi_w]
    he_r = ((local_off_r + (grid_r.flatten() + 0.5) * scale) * us)
    he_c = ((local_off_c + (grid_c.flatten() + 0.5) * scale) * us)

    # Cast explicitly rather than via AnnData's deprecated dtype argument;
    # astype also returns a fresh array, as the previous .copy() did.
    adata = ad.AnnData(spectra_all.reshape(-1, n_peaks).astype(np.float32))
    adata.var_names = np.array(["%.1f" % pk for pk in peaks])
    adata.obs_names = np.array([str(i) for i in range(maldi_h * maldi_w)])

    # Pixel indices in the uncropped MALDI grid.
    yy, xx = np.mgrid[crop_r: maldi_h + crop_r, crop_c: maldi_w + crop_c]
    adata.obs["x"]   = xx.flatten()
    adata.obs["y"]   = yy.flatten()
    adata.obs["MPI"] = np.ravel(adata.X.sum(axis=1))

    adata.obsm["spatial"] = np.column_stack([he_c, he_r])
    adata.obs["he_x"]     = he_c
    adata.obs["he_y"]     = he_r

    adata.uns["spatial"] = {
        library_id: {
            "images": {"hires": he_up},
            "use_quality": "hires",
            "scalefactors": {
                "tissue_hires_scalef": 1.0,
                "spot_diameter_fullres": float(us),
            },
        }
    }

    # One axis-aligned square of side `us` centred on each MALDI pixel.
    pixel_idx = np.arange(maldi_h * maldi_w).astype(str)
    half      = us / 2.0
    geoms     = [
        box(float(c) - half, float(r) - half,
            float(c) + half, float(r) + half)
        for r, c in zip(he_r, he_c)
    ]
    gdf    = gpd.GeoDataFrame({"cell_id": pixel_idx}, geometry=geoms)
    shapes = ShapesModel.parse(gdf, transformations={"global": Identity()})

    pts_df    = pd.DataFrame({"x": he_c, "y": he_r, "cell_id": pixel_idx})
    centroids = PointsModel.parse(pts_df)

    # Image2DModel is given channel-first data.
    image_cyx = np.transpose(he_up, (2, 0, 1))
    img_model = Image2DModel.parse(
        image_cyx, dims=("c", "y", "x"),
        transformations={"global": Identity()},
    )

    sdata = SpatialData(
        images={"he_image": img_model},
        points={"centroids": centroids},
        shapes={"pixels": shapes},
    )

    # Link every table row to its square in shapes['pixels'].
    adata.obs["instance_id"] = sdata["pixels"].index
    adata.obs["region"]      = "pixels"
    adata.obs["region"]      = adata.obs["region"].astype("category")

    table = TableModel.parse(
        adata, region="pixels",
        region_key="region", instance_key="instance_id",
    )
    sdata["maldi_adata"] = table
    return sdata


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def load_and_align(
    imzml_path: str,
    he_path: str,
    peaks_path: Optional[str] = None,
    geojson_path: Optional[Union[str, Path]] = None,
    geojson_shapes_key: str = "annotations",
    geojson_classification_key: str = "classification",
    maldi_pixel_um: Optional[float] = None,
    he_pixel_um: Optional[float] = None,
    spectra_chunk_size: int = 10,
    coarse_rotation_step: int = 15,
    fine_rotation_range: float = 5.0,
    fine_rotation_step: float = 1.0,
    buffer_px: int = 150,
    img_upscaling: int = 10,
) -> SpatialData:
    """
    Load a MALDI imzML dataset and an H&E image, auto-register them, and
    return a merged SpatialData object.

    The registration search, the output canvas and the annotation transform
    all derive their geometry from _build_affine_matrix. The search canvas
    may be enlarged to fit the MALDI template; when that happens the match
    offset is re-expressed in output-canvas coordinates before use.

    Parameters
    ----------
    imzml_path : str
        Path to the .imzML file.
    he_path : str
        Path to the H&E image. SVS/NDPI require openslide.
    peaks_path : str or None
        Path to peaks CSV. Uses bundled PEAKS.csv when None.
    geojson_path : str, Path, or None
        Optional path to a QuPath GeoJSON annotation export.
    geojson_shapes_key : str, default "annotations"
        Key under which annotations are stored in sdata.shapes.
    geojson_classification_key : str, default "classification"
        Column name for the QuPath class label in the GeoDataFrame.
    maldi_pixel_um : float or None
        Native MALDI pixel size in um. Auto-read from imzML when None.
    he_pixel_um : float or None
        Native H&E pixel size in um. Auto-read from metadata when None.
    spectra_chunk_size : int
        Ion images loaded in parallel at once.
    coarse_rotation_step : int
        Degrees between candidates in the 0-360 coarse sweep.
    fine_rotation_range : float
        +/- degrees searched around the best coarse angle.
    fine_rotation_step : float
        Degree increment for fine search.
    buffer_px : int
        Extra canvas padding (px at reg resolution).
    img_upscaling : int
        Each MALDI pixel is upscaled to img_upscaling x img_upscaling
        canvas pixels in the output.

    Returns
    -------
    SpatialData with:
        images['he_image']          -- full rotated H&E canvas
        shapes['pixels']            -- one square per MALDI pixel
        shapes[geojson_shapes_key]  -- annotations (if geojson_path given)
        points['centroids']         -- centroid of each MALDI pixel
        tables['maldi_adata']       -- AnnData with ion intensities
    """
    # ------------------------------------------------------------------
    # 0. Peaks
    # ------------------------------------------------------------------
    _log("Loading peaks ...")
    peaks = rd_peaks(peaks_path) if peaks_path else rd_peaks_from_package()
    _log(f" {len(peaks)} peaks")

    # ------------------------------------------------------------------
    # 0b. H&E native pixel size
    # ------------------------------------------------------------------
    if he_pixel_um is None:
        he_pixel_um = _read_native_mpp(he_path)
        if he_pixel_um is None:
            # Fall back to the TIFF resolution tags
            # (282 = XResolution, 296 = ResolutionUnit; 3 = cm, else inch).
            try:
                with Image.open(he_path) as _img:
                    tag_info = getattr(_img, 'tag_v2', {})
                    xres = tag_info.get(282)
                    unit = tag_info.get(296, 2)
                if xres is not None:
                    xres = xres[0] / xres[1] if isinstance(xres, tuple) else float(xres)
                    he_pixel_um = (10000.0 / xres) if unit == 3 else (25400.0 / xres)
                    _log(f" H&E pixel size from TIFF tags: {he_pixel_um:.4f} um/px")
            except Exception:
                # Best-effort probe; the default below covers failures.
                pass
        if he_pixel_um is None:
            he_pixel_um = 0.2527
            _log(f" WARNING: H&E pixel size unknown, assuming {he_pixel_um} um/px.")
        else:
            _log(f" H&E native pixel size: {he_pixel_um:.4f} um/px")

    # Native H&E dimensions, used only to sanity-check the MALDI pixel size.
    try:
        with Image.open(he_path) as _he_probe:
            _he_native_w, _he_native_h = _he_probe.size
    except Exception:
        _he_native_w, _he_native_h = 10000, 10000
    he_phys_w_um = _he_native_w * he_pixel_um
    he_phys_h_um = _he_native_h * he_pixel_um

    # ------------------------------------------------------------------
    # 0c. MALDI pixel size
    # ------------------------------------------------------------------
    if maldi_pixel_um is None:
        detected = _read_maldi_pixel_size(imzml_path)
        if detected is not None:
            _log(f" MALDI pixel size from imzML metadata: {detected} um/px")
            _p_probe = ImzMLParser(imzml_path)
            _maldi_h = _p_probe.imzmldict.get('max count of pixels y', 1)
            _maldi_w = _p_probe.imzmldict.get('max count of pixels x', 1)
            _he_thumb_w = he_phys_w_um / detected
            _he_thumb_h = he_phys_h_um / detected
            # The H&E thumbnail must be at least as large as the MALDI grid
            # for the detected size to be plausible.
            if _he_thumb_w >= _maldi_w and _he_thumb_h >= _maldi_h:
                maldi_pixel_um = detected
                _log(f" Validated: H&E thumbnail ({_he_thumb_w:.0f}x{_he_thumb_h:.0f} px) "
                     f">= MALDI ({_maldi_w}x{_maldi_h} px)")
            else:
                _log(f" WARNING: imzML pixel size {detected} um makes H&E thumbnail "
                     f"({_he_thumb_w:.0f}x{_he_thumb_h:.0f} px) smaller than MALDI "
                     f"({_maldi_w}x{_maldi_h} px) -- likely wrong.")
                for candidate in [10.0, 20.0, 50.0, 100.0, 200.0]:
                    _cw = he_phys_w_um / candidate
                    _ch = he_phys_h_um / candidate
                    if _cw >= _maldi_w * 0.5 and _ch >= _maldi_h * 0.5:
                        maldi_pixel_um = candidate
                        _log(f" Auto-selected maldi_pixel_um={candidate} um/px")
                        break
                if maldi_pixel_um is None:
                    maldi_pixel_um = 10.0
                    _log(f" Falling back to maldi_pixel_um=10.0 um/px.")
        else:
            maldi_pixel_um = 10.0
            _log(f" Pixel size not in imzML, defaulting to {maldi_pixel_um} um/px.")
    else:
        _log(f" MALDI pixel size (supplied): {maldi_pixel_um} um/px")
    _log(f" maldi_pixel_um={maldi_pixel_um} he_pixel_um={he_pixel_um:.4f}")

    # ------------------------------------------------------------------
    # 2. MALDI crop offsets
    # ------------------------------------------------------------------
    _log("Computing MALDI crop offsets ...")
    tic_probe = np.nansum(
        np.stack([getimage(pk, path=imzml_path) for pk in peaks[:5]], axis=-1),
        axis=-1,
    )
    crop_r, crop_c = _crop_offsets(tic_probe)
    _log(f" Crop: row={crop_r}, col={crop_c}")
    del tic_probe
    gc.collect()

    # ------------------------------------------------------------------
    # 3. Load spectra in chunks
    # ------------------------------------------------------------------
    _log(f"Loading {len(peaks)} ion images (chunk={spectra_chunk_size}) ...")
    spectra_all = _load_spectra(
        imzml_path, peaks,
        chunk_size=spectra_chunk_size,
        crop_r=crop_r, crop_c=crop_c,
    )
    _log(f" spectra_all: {spectra_all.shape} ({spectra_all.nbytes/1e6:.0f} MB)")

    # ------------------------------------------------------------------
    # 4. MALDI registration image
    # ------------------------------------------------------------------
    _log("Preparing MALDI template ...")
    maldi_tic = spectra_all.sum(axis=-1).astype(np.float32)
    maldi_gray = _maldi_to_grayscale(maldi_tic)
    del maldi_tic
    gc.collect()

    # ------------------------------------------------------------------
    # 5. H&E at MALDI native resolution
    # ------------------------------------------------------------------
    _log(f"Loading H&E at {maldi_pixel_um} um/px ...")
    he_img, loaded_mpp = _load_he_at_resolution(he_path, maldi_pixel_um, he_pixel_um)
    _log(f" H&E: {he_img.width}x{he_img.height} ({he_img.width*he_img.height*3/1e6:.0f} MB)")
    he_reg_w = he_img.width
    he_reg_h = he_img.height

    # ------------------------------------------------------------------
    # 6. H&E registration image
    # ------------------------------------------------------------------
    _log("Preparing H&E search image ...")
    he_gray = _he_to_grayscale(he_img)

    # ------------------------------------------------------------------
    # 7. Two-pass rotation + translation search (cv2-based).
    #    best_idx is expressed in the search-canvas coordinate system.
    # ------------------------------------------------------------------
    _log("Running registration ...")
    best_rot, best_idx = _register(
        he_gray, maldi_gray,
        src_w=he_reg_w,
        src_h=he_reg_h,
        coarse_step=coarse_rotation_step,
        fine_range=fine_rotation_range,
        fine_step=fine_rotation_step,
        buffer_px=buffer_px,
    )
    del he_gray, maldi_gray
    gc.collect()

    # ------------------------------------------------------------------
    # 8. Build H&E output canvas.
    # ------------------------------------------------------------------
    _log("Building H&E output canvas ...")
    he_canvas, M_stored, canvas_pr, canvas_pc = _build_affine_and_canvas(
        he_img=he_img,
        src_w=he_reg_w,
        src_h=he_reg_h,
        rotation_deg=best_rot,
        buffer_px=buffer_px,
    )
    del he_img
    gc.collect()

    # The search canvas is enlarged to at least the template size, which
    # changes its placement offsets relative to the output canvas. The two
    # affines differ only by that placement translation, so subtracting the
    # difference re-expresses best_idx in output-canvas coordinates. When
    # both canvases agree this is a no-op.
    tmpl_h, tmpl_w = spectra_all.shape[:2]   # same shape as the template
    _, _, _, search_pc, search_pr = _build_affine_matrix(
        he_reg_w, he_reg_h, best_rot, buffer_px,
        min_w=tmpl_w, min_h=tmpl_h,
    )
    shift_r = search_pr - canvas_pr
    shift_c = search_pc - canvas_pc
    if shift_r or shift_c:
        best_idx = (best_idx[0] - shift_r, best_idx[1] - shift_c)
        _log(f" Search canvas was enlarged; offset shifted by "
             f"({-shift_r}, {-shift_c}) -> {best_idx}")

    # ------------------------------------------------------------------
    # 9. Transform annotations using the same affine matrix.
    # ------------------------------------------------------------------
    annotation_gdf = None
    if geojson_path is not None:
        _log(f"Transforming annotations: {geojson_path} ...")
        annotation_gdf = _transform_geojson(
            geojson_path=geojson_path,
            he_pixel_um=he_pixel_um,
            reg_mpp=loaded_mpp,
            M_stored=M_stored,
            img_upscaling=img_upscaling,
            classification_key=geojson_classification_key,
        )
        unique = annotation_gdf[geojson_classification_key].unique().tolist()
        _log(f" {len(annotation_gdf)} annotations | classes: {unique}")

    # ------------------------------------------------------------------
    # 10. Assemble SpatialData
    # ------------------------------------------------------------------
    _log("Building SpatialData ...")
    sdata = _build_spatialdata(
        spectra_all=spectra_all,
        peaks=peaks,
        maldi_pixel_um=maldi_pixel_um,
        he_canvas=he_canvas,
        maldi_offset_in_canvas=best_idx,
        reg_mpp=loaded_mpp,
        crop_r=crop_r,
        crop_c=crop_c,
        img_upscaling=img_upscaling,
    )

    if annotation_gdf is not None:
        ann_shapes = ShapesModel.parse(
            annotation_gdf,
            transformations={"global": Identity()},
        )
        ann_shapes[geojson_classification_key] = ann_shapes[geojson_classification_key].astype("category")
        sdata.shapes[geojson_shapes_key] = ann_shapes
        _log(f" Annotations added -> sdata.shapes['{geojson_shapes_key}']")

    # ------------------------------------------------------------------
    # 11. Store registration transform.
    # ------------------------------------------------------------------
    sdata["maldi_adata"].uns["he_transform"] = {
        "rotation_deg": float(best_rot),
        "maldi_offset": [int(best_idx[0]), int(best_idx[1])],
        "he_pixel_um": float(he_pixel_um),
        "maldi_pixel_um": float(maldi_pixel_um),
        "reg_mpp": float(loaded_mpp),
        "buffer_px": int(buffer_px),
        "img_upscaling": int(img_upscaling),
        "canvas_shape": list(he_canvas.shape[:2]),
        "he_reg_size": [int(he_reg_h), int(he_reg_w)],
        "canvas_placement": [int(canvas_pr), int(canvas_pc)],
        "affine_matrix": M_stored.tolist(),
    }
    _log(f" Transform stored: rotation={best_rot} "
         f"he_reg_size={[he_reg_h, he_reg_w]} "
         f"canvas_placement={[canvas_pr, canvas_pc]}")
    _log("Done.")
    return sdata