import numpy as np
import anndata as ad
import pandas as pd
import geopandas as gpd
import tifffile as tiff
import csv
import pkg_resources
import os
import warnings
from pyimzml.ImzMLParser import ImzMLParser, getionimage
from joblib import Parallel, delayed
from functools import partial
from shapely.geometry import box
from anndata import AnnData
from spatialdata import SpatialData
from spatialdata.models import PointsModel, Image2DModel, TableModel, ShapesModel
from spatialdata.transformations import Identity
[docs]
def parmap(f, X, nprocs=None):
    """
    Apply ``f`` to every element of ``X`` in parallel using joblib.

    The ``loky`` backend is used because it is more robust than plain
    multiprocessing when running inside Jupyter.

    Parameters
    ----------
    f : callable
        Function applied to each element.
    X : iterable
        Input elements.
    nprocs : int, optional
        Number of worker processes. ``None`` (the default) is translated to
        ``-1``, which joblib treats as "use all CPUs".

    Returns
    -------
    list
        ``f(x)`` for each ``x`` in ``X``, in input order.
    """
    n_jobs = -1 if nprocs is None else nprocs
    # Lazy generator: joblib pulls tasks from it as workers become free.
    tasks = (delayed(f)(item) for item in X)
    runner = Parallel(n_jobs=n_jobs, backend='loky')
    return runner(tasks)
[docs]
def getimage(peak, path, tol=0.1, reduce_func=sum):
    """
    Return the ion image of a single ``peak`` from the imzML file at ``path``.

    A fresh ``ImzMLParser`` is opened on every call so that parallel workers
    never share a file pointer (sharing one corrupts parsing).

    Parameters
    ----------
    peak : float
        m/z value to extract.
    path : str
        Path to the ``.imzML`` file.
    tol : float
        m/z tolerance forwarded to ``pyimzml.ImzMLParser.getionimage``.
    reduce_func : callable
        Reduction applied to the intensities inside the tolerance window.

    Returns
    -------
    The array produced by ``getionimage`` for this peak.
    """
    return getionimage(ImzMLParser(path), peak, tol=tol, reduce_func=reduce_func)
[docs]
def rd_peaks(fn):
    """
    Read a list of peak m/z values from a delimited text file.

    The delimiter (comma, tab or space) is sniffed from the first 4096
    characters; if sniffing fails the Excel (comma) dialect is used. The
    column holding the peaks is chosen as follows:

    1. the first header cell containing ``'m/z'`` (case-insensitive);
    2. otherwise column 0 for single-column files;
    3. otherwise column 1.

    When no ``'m/z'`` header is present, the first row may itself be data
    (e.g. a bare list of numbers), so it is kept if its selected cell parses
    as a float. Cells that are missing or non-numeric (after stripping
    whitespace and double quotes) are skipped.

    Parameters
    ----------
    fn : str or path-like
        Path to the peaks file, read as UTF-8 (a leading BOM is tolerated).

    Returns
    -------
    list of float
        Peak values in file order; empty for an empty file.

    Warns
    -----
    UserWarning
        When no ``'m/z'`` header cell is found and a positional fallback
        column is used.
    """
    with open(fn, newline='', encoding='utf-8-sig') as f:
        # Detect the delimiter from a leading sample, then rewind.
        sample = f.read(4096)
        f.seek(0)
        try:
            dialect = csv.Sniffer().sniff(sample, delimiters=',\t ')
        except csv.Error:
            dialect = csv.excel  # fall back to comma-separated

        reader = csv.reader(f, dialect)
        header = next(reader, None)
        if header is None:
            # Empty file: previously a bare StopIteration escaped here.
            return []

        data = []
        col_idx = next(
            (i for i, h in enumerate(header) if 'm/z' in h.lower()),
            None,
        )
        if col_idx is None:
            col_idx = 0 if len(header) == 1 else 1
            warnings.warn(
                f"No 'm/z' column found; falling back to column index {col_idx}."
            )
            # Without an m/z header the "header" row may actually be data;
            # keep it if it parses, instead of silently dropping the first peak.
            first = _parse_peak_cell(header, col_idx)
            if first is not None:
                data.append(first)

        for row in reader:
            value = _parse_peak_cell(row, col_idx)
            if value is not None:
                data.append(value)
    return data


def _parse_peak_cell(row, col_idx):
    """Return ``row[col_idx]`` as a float, or ``None`` if missing or non-numeric."""
    if col_idx >= len(row):
        return None
    try:
        return float(row[col_idx].strip().strip('"'))
    except ValueError:
        return None
[docs]
def rd_peaks_from_package():
    """
    Read the peak list bundled with the ``goatpy`` package (``data/PEAKS.csv``).

    Parsing rules (taken directly from the code below):

    * the first line is a header and is skipped;
    * blank lines are ignored;
    * rows whose first whitespace-separated token (double quotes stripped)
      equals ``M`` are skipped;
    * the second whitespace-separated token of every other row is parsed
      as a float.

    NOTE(review): columns are split on whitespace, not commas, despite the
    ``.csv`` extension — confirm this matches the shipped file.

    Returns
    -------
    list of float
        Peak values in file order.

    Raises
    ------
    FileNotFoundError
        If the bundled file cannot be located.
    """
    peaks_path = pkg_resources.resource_filename('goatpy', 'data/PEAKS.csv')
    if not os.path.exists(peaks_path):
        raise FileNotFoundError(f"PEAKS.csv not found at {peaks_path}")

    data = []
    with open(peaks_path, 'r', encoding='utf-8') as f:
        f.readline()  # skip header
        for line in f:
            ss = line.split()
            if not ss:
                # Blank line: previously ss[0] raised IndexError.
                continue
            if ss[0].strip('"') == 'M':
                continue
            data.append(float(ss[1]))
    return data
[docs]
def glyco_spatialdata(imzml_path, peaks_path=None, tol=0.1, pixel_size=20,
                      reduce_func=sum, nprocs=10):
    """
    Build a SpatialData object from a MALDI imzML file and a list of peaks.

    Parameters
    ----------
    imzml_path : str
        Path to the ``.imzML`` file (read with ``pyimzml.ImzMLParser``).
    peaks_path : str, optional
        Peaks table read with ``rd_peaks``. When ``None`` the ``PEAKS.csv``
        bundled with ``goatpy`` is read via ``rd_peaks_from_package``.
    tol : float
        m/z tolerance forwarded to ``getionimage``.
    pixel_size : int or float
        Currently not used by this function (kept for API compatibility);
        pixel squares are built in pixel-index units with side length 1.
    reduce_func : callable
        Reduction applied to intensities inside each tolerance window.
    nprocs : int
        Number of joblib workers used to extract ion images (default 10,
        the value that was previously hard-coded).

    Returns
    -------
    SpatialData
        With points ``"centroids"``, shapes ``"pixels"`` and a table
        ``"maldi_adata"`` annotating ``"pixels"``; the table's
        ``uns["maldi_path"]`` holds ``imzml_path``.

    Raises
    ------
    ValueError
        If no peaks could be loaded.
    """
    # --- 1. Load peaks (sorted, so variables are in ascending order) ---
    if peaks_path is None:
        peaks = rd_peaks_from_package()
    else:
        peaks = rd_peaks(peaks_path)
    peaks = sorted(peaks)
    if not peaks:
        raise ValueError("No peaks were loaded; cannot build ion images.")

    # --- 2. One ion image per peak, stacked along the last axis ---
    getimg = partial(getimage, path=imzml_path, tol=tol, reduce_func=reduce_func)
    spectra_all = np.stack(parmap(getimg, peaks, nprocs), axis=-1)

    # --- 3. Pixel coordinates, converted from 1-based to 0-based ---
    parser = ImzMLParser(imzml_path)
    coords = np.array(parser.coordinates)[:, :2] - 1  # columns: (x, y)
    # getionimage stores pixel (x, y) at im[y - 1, x - 1]. The coordinates are
    # already 0-based here, so index with [y, x] directly (subtracting 1 again
    # shifted every spectrum and wrapped row/column 0 to the opposite edge).
    spectra_flat = spectra_all[coords[:, 1], coords[:, 0], :]

    # --- 4. Intermediate AnnData holding per-pixel metadata ---
    raw = ad.AnnData(spectra_flat.astype(np.float32))
    raw.var_names = pd.Index([str(pk) for pk in peaks])
    raw.obs_names = pd.Index([str(i) for i in range(len(coords))])
    raw.obs["full_x"] = coords[:, 0]
    raw.obs["full_y"] = coords[:, 1]
    # Shift so the tissue's bounding box starts at (0, 0).
    raw.obs["x"] = raw.obs["full_x"] - raw.obs["full_x"].min()
    raw.obs["y"] = raw.obs["full_y"] - raw.obs["full_y"].min()
    # Per-pixel sum of intensities over the selected peaks.
    raw.obs["MPI"] = np.ravel(raw.X.sum(axis=1))

    # --- 5. SpatialData elements: centroid points and unit pixel squares ---
    centroids = pd.DataFrame({
        "x": raw.obs["x"].to_numpy(),
        "y": raw.obs["y"].to_numpy(),
    })
    centroids["instance_id"] = centroids.index.astype(str)  # unique per pixel
    centroids["region"] = "pixels"  # must exist for TableModel
    mz_names = "mz-" + raw.var.index
    df = pd.concat(
        [centroids, pd.DataFrame(raw.X, columns=mz_names)],
        axis=1,
    )
    points = PointsModel.parse(df)
    gdf = centroids_to_pixel_squares(df, x_col="x", y_col="y", pixel_size=1.0)
    shapes = ShapesModel.parse(
        gdf[["instance_id", "geometry"]],
        transformations={"global": Identity()},
    )
    sdata = SpatialData(points={"centroids": points},
                        shapes={"pixels": shapes})

    # --- 6. Annotation table linked to the "pixels" shapes ---
    adata = AnnData(
        X=raw.X,
        obs=centroids,  # x, y, instance_id, region
        var=pd.DataFrame(index=mz_names),
    )
    adata.obs = pd.concat(
        [
            adata.obs.reset_index(drop=True),
            raw.obs.drop(columns=["x", "y"]).reset_index(drop=True),
        ],
        axis=1,
    )
    adata.obsm["spatial"] = np.array(adata.obs[["x", "y"]])
    # Instance ids must match the index of the "pixels" shapes element.
    adata.obs["instance_id"] = sdata["pixels"].index
    adata.obs["region"] = "pixels"
    # astype returns a new Series; the result must be assigned back
    # (previously it was discarded, making this line a no-op).
    adata.obs["region"] = adata.obs["region"].astype("category")
    table = TableModel.parse(
        adata,
        region="pixels",             # the shapes element this table annotates
        region_key="region",         # must exist in adata.obs
        instance_key="instance_id",  # unique per row
    )

    # --- 7. Add the table to SpatialData ---
    sdata.tables["maldi_adata"] = table
    sdata["maldi_adata"].uns["maldi_path"] = imzml_path
    return sdata
[docs]
def centroids_to_pixel_squares(df, x_col="x", y_col="y", pixel_size=1.0):
    """
    Turn point centroids into axis-aligned square polygons.

    Parameters
    ----------
    df : pandas.DataFrame
        Table holding the centroid coordinates; it is copied, not modified.
    x_col, y_col : str
        Names of the columns holding the centroid x / y coordinates.
    pixel_size : float
        Side length of each square, centred on its centroid.

    Returns
    -------
    geopandas.GeoDataFrame
        A copy of ``df`` with one square geometry per row.
    """
    radius = pixel_size / 2

    def _square(cx, cy):
        # shapely box signature: (minx, miny, maxx, maxy)
        return box(cx - radius, cy - radius, cx + radius, cy + radius)

    squares = list(map(_square, df[x_col], df[y_col]))
    return gpd.GeoDataFrame(df.copy(), geometry=squares)
[docs]
def ihc_spatialdata(
    ihc_image_path,
    channel_names=None,
    image_key="ihc_image",
):
    """
    Generate a SpatialData object from a multichannel IHC TIFF/OME-TIFF image.

    Array layouts handled (after leading singleton axes are dropped from
    arrays with more than three dimensions, e.g. ``(1, 1, C, Y, X)``):

    * ``(Y, X, 3)`` or ``(Y, X, 4)``: treated as RGB(A), transposed to
      ``(C, Y, X)``;
    * any other 3-D array: assumed to already be ``(C, Y, X)``;
    * ``(Y, X)``: a single channel.

    Parameters
    ----------
    ihc_image_path : str
        Path to TIFF/OME-TIFF image.
    channel_names : list, optional
        Names of image channels; must match the number of channels.
        Example:
        ["sc_405", "CF_488", "CF_561", "DIC"]
        Defaults to ``["channel_0", "channel_1", ...]``.
    image_key : str
        Name of image key in SpatialData object.

    Returns
    -------
    SpatialData
        With a single image element under ``image_key`` and an Identity
        ``"global"`` transformation.

    Raises
    ------
    RuntimeError
        Wrapping any failure (read error, unsupported shape, channel-name
        count mismatch, parse error); the original exception is chained as
        ``__cause__``.
    """
    try:
        # Read TIFF / OME-TIFF
        img = tiff.imread(ihc_image_path)
        print(f"Loaded image with shape: {img.shape}")

        # Drop leading singleton axes (e.g. T/Z of size 1) so the layout
        # checks below apply.
        while img.ndim > 3 and img.shape[0] == 1:
            img = img[0]

        if img.ndim == 3 and img.shape[-1] in [3, 4]:
            # RGB(A) image (Y, X, C) -> channel-first (C, Y, X)
            image_cyx = np.transpose(img, (2, 0, 1))
        elif img.ndim == 3:
            # Multichannel image already channel-first (C, Y, X)
            image_cyx = img
        elif img.ndim == 2:
            # Single-channel image -> add a channel axis
            image_cyx = img[np.newaxis, :, :]
        else:
            raise ValueError(
                f"Unsupported image dimensions: {img.shape}"
            )

        n_channels = image_cyx.shape[0]
        if channel_names is None:
            channel_names = [f"channel_{i}" for i in range(n_channels)]
        else:
            channel_names = list(channel_names)
            # Fail early with a clear message instead of deep in the parser.
            if len(channel_names) != n_channels:
                raise ValueError(
                    f"Got {len(channel_names)} channel names for an image "
                    f"with {n_channels} channels"
                )

        img_model = Image2DModel.parse(
            image_cyx,
            dims=("c", "y", "x"),
            c_coords=channel_names,
            transformations={"global": Identity()},
        )
        sdata = SpatialData(
            images={image_key: img_model}
        )
        print("Successfully generated SpatialData object")
        return sdata
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(
            f"Failed to generate SpatialData object: {str(e)}"
        ) from e