# Source code for goatpy.batch_correction

"""
batch_correction.py
===================

Batch correction for MALDI SpatialData objects.

Supports two methods:

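* Harmony
    Scales each batch independently, computes a PCA embedding and
    corrects that embedding with Harmony.
    Corrected embeddings are stored in:

        adata.obsm["X_pca_harmony"]

    The uncorrected PCA embedding is stored in adata.obsm["X_pca"].
    The raw intensity matrix (adata.X) is NOT modified.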

* ComBat
    Operates directly on adata.X.
    Corrected intensities overwrite:

        adata.X

    Negative values produced by ComBat are clipped to 0.


---------------------------------------------------------------------------
IMPORTANT
---------------------------------------------------------------------------

batch_correction assumes the data has already been normalised
(e.g. TIC or RMS normalisation).

Applying batch correction to raw counts will produce misleading results.


---------------------------------------------------------------------------
USAGE
---------------------------------------------------------------------------

# Option 1 — merge inside batch_correction
merged = batch_correction(
    sdatas=[sdata1, sdata2],
    batch_names=["A", "B"],
    method="harmony",
)

# Option 2 — use an already merged object (corrected in place and returned)
merged = batch_correction(
    pre_merged=merged_sdata,
    method="combat",
)

"""

from __future__ import annotations

import os
import warnings

from typing import Literal, Optional

import numpy as np
import pandas as pd

from spatialdata import SpatialData

import goatpy as gp


# -----------------------------------------------------------------------------
# Logging helper
# -----------------------------------------------------------------------------

def _log(msg: str) -> None:
    """
    Print a progress message to stdout.

    If psutil can be imported, the message is prefixed with the current
    process's resident memory (``memory_info().rss``) in GB, e.g.
    ``[1.23GB] msg``; otherwise the message is printed unchanged.
    """
    try:
        import psutil
    except ImportError:
        # psutil is optional: fall back to a bare message.
        print(msg)
    else:
        resident_gb = psutil.Process(os.getpid()).memory_info().rss / 1e9
        print(f"[{resident_gb:.2f}GB] {msg}")


# -----------------------------------------------------------------------------
# Harmony
# -----------------------------------------------------------------------------

def _run_harmony(
    merged: SpatialData,
    table_name: str,
    pcs: int,
    batch_col: str,
    random_state: int,
) -> SpatialData:
    """
    Run per-batch scaling, PCA and Harmony on ``merged.tables[table_name]``.

    Steps
    -----
    1. Build a dense local copy of ``adata.X`` (``adata.X`` itself is never
       written to). Non-finite values in the copy are replaced with 0 and a
       warning is issued, so PCA does not fail on NaN/inf.
    2. Standard-scale every batch independently with ``StandardScaler``.
    3. Compute a PCA embedding of the scaled matrix.
    4. Correct that embedding with Harmony, grouping pixels by ``batch_col``.

    Results are written in place to ``adata.obsm["X_pca"]`` (uncorrected
    PCA) and ``adata.obsm["X_pca_harmony"]`` (Harmony-corrected, float32).

    Parameters
    ----------
    merged
        SpatialData object holding the table; modified in place.
    table_name
        Key of the AnnData table in ``merged.tables``.
    pcs
        Requested number of principal components. Reduced (with a warning)
        to ``min(n_obs, n_vars) - 1`` when the table is too small.
    batch_col
        Column of ``adata.obs`` holding batch labels. Labels are compared
        as strings.
    random_state
        Seed forwarded to both the PCA and Harmony.

    Returns
    -------
    SpatialData
        The same ``merged`` object.

    Raises
    ------
    ImportError
        If harmonypy is not installed.
    ValueError
        If the table has fewer than 2 pixels or 2 features, so that not even
        one principal component can be computed.
    """

    try:
        import harmonypy
    except ImportError as e:
        raise ImportError(
            "harmonypy is required for Harmony batch correction.\n"
            "Install with:\n"
            "pip install harmonypy"
        ) from e

    import anndata as ad
    import scanpy as sc

    from scipy.sparse import issparse
    from sklearn.preprocessing import StandardScaler

    adata = merged.tables[table_name]

    _log(
        f"Harmony input: "
        f"{adata.n_obs:,} pixels × {adata.n_vars:,} features"
    )

    # -------------------------------------------------------------------------
    # Convert matrix (local copy only; adata.X is never written to)
    # -------------------------------------------------------------------------

    X_raw = (
        adata.X.toarray()
        if issparse(adata.X)
        else np.asarray(adata.X, dtype=np.float64)
    )

    # NaN/inf survive StandardScaler and make PCA fail. Mirror the ComBat
    # path and zero them, but only in the local copy.
    n_bad = int(X_raw.size - np.count_nonzero(np.isfinite(X_raw)))
    if n_bad:
        warnings.warn(
            f"Found {n_bad:,} non-finite values in adata.X.\n"
            "Replacing with 0 before Harmony (adata.X is not modified).",
            stacklevel=2,
        )
        X_raw = np.nan_to_num(X_raw, nan=0.0, posinf=0.0, neginf=0.0)

    # -------------------------------------------------------------------------
    # Scale each batch independently
    # -------------------------------------------------------------------------

    # Convert the labels once, outside the loop; string comparison makes
    # numeric and categorical batch columns behave the same way.
    label_series = adata.obs[batch_col].astype(str)
    label_values = label_series.to_numpy()
    unique_batches = label_series.unique()

    _log(f"Detected {len(unique_batches)} batches")

    # Allocated as float32 directly: the scaled matrix is handed to PCA as
    # float32 anyway, so this halves peak memory without changing values.
    X_scaled = np.zeros(X_raw.shape, dtype=np.float32)

    for batch in unique_batches:
        idx = np.flatnonzero(label_values == batch)
        X_scaled[idx] = StandardScaler().fit_transform(X_raw[idx])
        _log(
            f"Scaled batch '{batch}' "
            f"({idx.size:,} pixels)"
        )

    # Drop the local reference before PCA to keep peak memory down.
    del X_raw

    # -------------------------------------------------------------------------
    # PCA object
    # -------------------------------------------------------------------------

    obs = adata.obs.copy()
    # Harmony dummy-encodes the batch column. A fresh string categorical
    # (no numeric dtype, no unused categories) keeps that encoding consistent.
    obs[batch_col] = pd.Categorical(label_values)

    pca_adata = ad.AnnData(X=X_scaled, obs=obs)

    # Some PCA solvers (e.g. ARPACK) require n_comps < min(n_obs, n_vars).
    # Clamp instead of failing on feature-poor tables.
    max_comps = min(pca_adata.n_obs, pca_adata.n_vars) - 1
    if max_comps < 1:
        raise ValueError(
            f"Cannot run PCA on a {pca_adata.n_obs} × {pca_adata.n_vars} "
            "matrix; at least 2 pixels and 2 features are required."
        )

    n_comps = pcs
    if pcs > max_comps:
        warnings.warn(
            f"Requested {pcs} PCs but the data supports at most "
            f"{max_comps}; using {max_comps}.",
            stacklevel=2,
        )
        n_comps = max_comps

    _log(f"Running PCA ({n_comps} PCs)")

    sc.pp.pca(
        pca_adata,
        n_comps=n_comps,
        random_state=random_state,
    )

    # -------------------------------------------------------------------------
    # Harmony
    # -------------------------------------------------------------------------

    _log("Running Harmony")

    harmony_out = harmonypy.run_harmony(
        pca_adata.obsm["X_pca"].astype(np.float64),
        pca_adata.obs,
        batch_col,
        max_iter_harmony=20,
        # Previously only PCA was seeded; forward the seed to Harmony too.
        random_state=random_state,
    )

    corrected = np.asarray(harmony_out.Z_corr)

    # harmonypy returns PCs × pixels; orient as pixels × PCs. This is
    # unambiguous because n_comps <= n_obs - 1 after clamping.
    if corrected.shape[0] != adata.n_obs:
        corrected = corrected.T

    # -------------------------------------------------------------------------
    # Store embeddings
    # -------------------------------------------------------------------------

    adata.obsm["X_pca"] = pca_adata.obsm["X_pca"]

    adata.obsm["X_pca_harmony"] = corrected.astype(np.float32)

    _log(
        "Harmony complete "
        f"(embedding shape={corrected.shape})"
    )

    return merged


# -----------------------------------------------------------------------------
# ComBat
# -----------------------------------------------------------------------------

def _run_combat(
    merged: SpatialData,
    table_name: str,
    batch_col: str,
    covariates: Optional[list[str]],
) -> SpatialData:
    """
    Run ComBat on ``merged.tables[table_name].X`` in place.

    Steps
    -----
    1. If ``adata.X`` (dense or scipy sparse) contains NaN/inf, replace
       them with 0, write the cleaned float32 matrix back and warn.
    2. Run ``scanpy.pp.combat`` grouped by ``batch_col``, optionally
       preserving ``covariates``.
    3. If the corrected matrix contains NaN/inf, replace them with 0 and
       warn. Then clip negative values to 0.

    Parameters
    ----------
    merged
        SpatialData object holding the table; modified in place.
    table_name
        Key of the AnnData table in ``merged.tables``.
    batch_col
        Column of ``adata.obs`` holding batch labels.
    covariates
        ``adata.obs`` columns whose effect ComBat should preserve, or None.

    Returns
    -------
    SpatialData
        The same ``merged`` object.

    Raises
    ------
    ImportError
        If scanpy is not installed.
    KeyError
        If any requested covariate is missing from ``adata.obs``.
    """

    try:
        import scanpy as sc
    except ImportError as e:
        raise ImportError(
            "scanpy is required for ComBat batch correction.\n"
            "Install with:\n"
            "pip install scanpy"
        ) from e

    from scipy.sparse import issparse

    adata = merged.tables[table_name]

    # np.asarray cannot turn a scipy sparse matrix into a numeric array, so
    # densify explicitly for the validation pass.
    X = np.asarray(
        adata.X.toarray() if issparse(adata.X) else adata.X,
        dtype=np.float64,
    )

    # -------------------------------------------------------------------------
    # Replace invalid values
    # -------------------------------------------------------------------------

    n_bad = int(X.size - np.count_nonzero(np.isfinite(X)))

    if n_bad:

        warnings.warn(
            f"Found {n_bad:,} non-finite values in adata.X.\n"
            "Replacing with 0 before ComBat.",
            stacklevel=2,
        )

        X = np.nan_to_num(
            X,
            nan=0.0,
            posinf=0.0,
            neginf=0.0,
        )

        adata.X = X.astype(np.float32)

    # -------------------------------------------------------------------------
    # Validate covariates
    # -------------------------------------------------------------------------

    covariates_arg = None

    if covariates is not None:

        missing = [
            c for c in covariates
            if c not in adata.obs.columns
        ]

        if missing:
            raise KeyError(
                "Missing covariates in adata.obs:\n"
                f"{missing}"
            )

        covariates_arg = covariates

        _log(
            f"Preserving covariates: "
            f"{covariates_arg}"
        )

    # -------------------------------------------------------------------------
    # Run ComBat
    # -------------------------------------------------------------------------

    _log(
        f"Running ComBat "
        f"(matrix shape={X.shape})"
    )

    sc.pp.combat(
        adata,
        key=batch_col,
        covariates=covariates_arg,
        inplace=True,
    )

    # -------------------------------------------------------------------------
    # Post-process: non-finite values, then negatives
    # -------------------------------------------------------------------------

    X_corrected = np.asarray(
        adata.X.toarray() if issparse(adata.X) else adata.X,
        dtype=np.float32,
    )
    needs_write_back = False

    # ComBat can emit NaN/inf (e.g. for features with no variance); np.clip
    # below would leave NaN untouched, so sanitise first.
    n_nonfinite = int(
        X_corrected.size - np.count_nonzero(np.isfinite(X_corrected))
    )

    if n_nonfinite:

        warnings.warn(
            f"ComBat produced {n_nonfinite:,} non-finite values "
            "(possibly from zero-variance features).\n"
            "Replacing with 0.",
            stacklevel=2,
        )

        X_corrected = np.nan_to_num(
            X_corrected,
            nan=0.0,
            posinf=0.0,
            neginf=0.0,
        )
        needs_write_back = True

    n_neg = int((X_corrected < 0).sum())

    if n_neg > 0:

        _log(
            f"Clipping {n_neg:,} negative values to 0"
        )

        X_corrected = np.clip(
            X_corrected,
            0.0,
            None,
        )
        needs_write_back = True

    if needs_write_back:
        adata.X = X_corrected

    _log("ComBat complete")

    return merged


# -----------------------------------------------------------------------------
# Public API
# -----------------------------------------------------------------------------

def batch_correction(
    sdatas: Optional[list[SpatialData]] = None,
    batch_names: Optional[list[str]] = None,
    pre_merged: Optional[SpatialData] = None,
    method: Literal["harmony", "combat"] = "harmony",
    pcs: int = 30,
    covariates: Optional[list[str]] = None,
    table_name: str = "maldi_adata",
    batch_col: str = "batch",
    feature_join: str = "inner",
    random_state: int = 42,
) -> SpatialData:
    """
    Merge SpatialData objects and apply batch correction.

    Either pass ``sdatas`` + ``batch_names`` (merged here with
    ``gp.merge_spatialdata``), or pass an already merged ``pre_merged``
    object. If both are given, ``pre_merged`` is used and ``sdatas`` /
    ``batch_names`` are ignored, with a warning.

    ``pre_merged`` is NOT copied: its table is corrected in place and the
    same object is returned.

    Parameters
    ----------
    sdatas
        SpatialData objects to merge. Must be non-empty when ``pre_merged``
        is not given.
    batch_names
        Batch labels corresponding to sdatas (same length).
    pre_merged
        Already merged SpatialData object. If provided, merge_spatialdata
        is skipped.
    method
        "harmony" or "combat" (case-insensitive).
    pcs
        Number of PCs for Harmony; must be >= 1.
    covariates
        adata.obs columns to preserve during ComBat.
    table_name
        AnnData table name.
    batch_col
        obs column containing batch labels.
    feature_join
        "inner" or "outer"; forwarded to merge_spatialdata. It has no
        effect when ``pre_merged`` is used.
    random_state
        Random seed forwarded to the Harmony pipeline.

    Returns
    -------
    SpatialData
        The merged (or ``pre_merged``) object. Its table's
        ``uns["batch_correction"]`` records the parameters of this run.

    Raises
    ------
    ValueError
        Unknown ``method``, ``pcs < 1`` with Harmony, missing inputs,
        empty ``sdatas``, or ``sdatas``/``batch_names`` length mismatch.
    TypeError
        If ``pre_merged`` is not a SpatialData object.
    KeyError
        If ``table_name`` is not in ``merged.tables`` or ``batch_col`` is
        not in ``adata.obs``.
    """

    # -------------------------------------------------------------------------
    # Warning
    # -------------------------------------------------------------------------

    warnings.warn(
        "\n"
        "batch_correction assumes the data has already "
        "been normalised.\n"
        "Running on raw counts may produce misleading results.",
        UserWarning,
        stacklevel=2,
    )

    # -------------------------------------------------------------------------
    # Validate method and method-specific parameters
    # -------------------------------------------------------------------------

    # Non-string values (e.g. None) fall through to the ValueError below
    # instead of raising AttributeError on .lower().
    method_key = method.lower() if isinstance(method, str) else None

    if method_key not in ("harmony", "combat"):
        raise ValueError(
            "method must be either:\n"
            "'harmony' or 'combat'"
        )

    method = method_key

    if method == "harmony" and pcs < 1:
        raise ValueError(
            f"'pcs' must be >= 1 for Harmony (got {pcs})."
        )

    # -------------------------------------------------------------------------
    # Validate inputs
    # -------------------------------------------------------------------------

    if pre_merged is None:

        if sdatas is None:
            raise ValueError(
                "Either 'sdatas' or 'pre_merged' "
                "must be provided."
            )

        if batch_names is None:
            raise ValueError(
                "'batch_names' must be provided "
                "when using 'sdatas'."
            )

        if len(sdatas) != len(batch_names):
            raise ValueError(
                "'sdatas' and 'batch_names' "
                "must have the same length."
            )

        if len(sdatas) == 0:
            raise ValueError(
                "'sdatas' must contain at least one "
                "SpatialData object."
            )

    elif sdatas is not None:
        # Previously ignored silently; make the precedence explicit.
        warnings.warn(
            "Both 'pre_merged' and 'sdatas' were provided; "
            "using 'pre_merged' and ignoring 'sdatas'/'batch_names'.",
            UserWarning,
            stacklevel=2,
        )

    # -------------------------------------------------------------------------
    # Merge or use pre-merged
    # -------------------------------------------------------------------------

    if pre_merged is not None:

        if not isinstance(pre_merged, SpatialData):
            raise TypeError(
                "pre_merged must be a SpatialData object."
            )

        merged = pre_merged

        _log(
            "Using pre-merged SpatialData object"
        )

    else:

        _log(
            f"Merging {len(sdatas)} SpatialData objects"
        )

        merged = gp.merge_spatialdata(
            sdatas=sdatas,
            batch_names=batch_names,
            table_name=table_name,
            feature_join=feature_join,
        )

    # -------------------------------------------------------------------------
    # Validate table and batch column
    # -------------------------------------------------------------------------

    if table_name not in merged.tables:
        raise KeyError(
            f"'{table_name}' not found in merged.tables"
        )

    adata = merged.tables[table_name]

    if batch_col not in adata.obs.columns:
        raise KeyError(
            f"'{batch_col}' not found in adata.obs"
        )

    # -------------------------------------------------------------------------
    # Summary
    # -------------------------------------------------------------------------

    _log(
        f"Input matrix: "
        f"{adata.n_obs:,} pixels × "
        f"{adata.n_vars:,} features"
    )

    unique_batches = (
        adata.obs[batch_col]
        .astype(str)
        .unique()
    )

    _log(
        f"Detected {len(unique_batches)} batches"
    )

    if len(unique_batches) < 2:
        warnings.warn(
            f"Only {len(unique_batches)} distinct batch label(s) found in "
            f"'{batch_col}'; batch correction needs at least 2 batches "
            "to have any effect.",
            UserWarning,
            stacklevel=2,
        )

    # -------------------------------------------------------------------------
    # Run correction
    # -------------------------------------------------------------------------

    if method == "harmony":
        merged = _run_harmony(
            merged=merged,
            table_name=table_name,
            pcs=pcs,
            batch_col=batch_col,
            random_state=random_state,
        )
    else:
        merged = _run_combat(
            merged=merged,
            table_name=table_name,
            batch_col=batch_col,
            covariates=covariates,
        )

    # -------------------------------------------------------------------------
    # Provenance
    # -------------------------------------------------------------------------

    # Re-fetch the table from the returned object rather than relying on the
    # reference taken before correction.
    merged.tables[table_name].uns["batch_correction"] = {
        "method": method,
        "pcs": pcs if method == "harmony" else None,
        "covariates": (
            list(covariates)
            if method == "combat" and covariates is not None
            else None
        ),
        "batch_col": batch_col,
        # feature_join is only applied when this function performs the merge.
        "feature_join": feature_join if pre_merged is None else None,
        "random_state": random_state if method == "harmony" else None,
        "n_batches": len(unique_batches),
        "used_pre_merged": pre_merged is not None,
    }

    _log(
        f"batch_correction complete "
        f"(method='{method}')"
    )

    return merged