Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,11 @@ _version.py

# pixi
pixi.lock

# Local sandbox data + historical prototype notebooks (Sandbox.ipynb is the
# active one and is tracked; the others are kept locally for reference only).
/sandbox_data/
/Sandbox.anywidget-v0.ipynb
/Sandbox.ipympl-v0.ipynb
/verify_ssh_annotate.ipynb
Sandbox.ipynb
67 changes: 46 additions & 21 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ dependencies = [
"scikit-learn",
"spatialdata>=0.3",
]
optional-dependencies.interactive = [
"anybioimage>=0.3,<0.4",
"anywidget",
"ipykernel",
"ipywidgets",
]
urls.Documentation = "https://spatialdata.scverse.org/projects/plot/en/latest/index.html"
urls.Home-page = "https://github.com/scverse/spatialdata-plot.git"
urls.Source = "https://github.com/scverse/spatialdata-plot.git"
Expand Down Expand Up @@ -61,7 +67,6 @@ doc = [
"sphinxcontrib-katex",
"sphinxext-opengraph",
]

[tool.hatch]
build.hooks.vcs.version-file = "_version.py"
build.targets.wheel.packages = [ "src/spatialdata_plot" ]
Expand All @@ -86,29 +91,49 @@ envs.hatch-test.scripts.cov-report = [ "coverage report", "coverage xml -o cover
metadata.allow-direct-references = true
version.source = "vcs"

[tool.pixi]
workspace.channels = [ "conda-forge" ]
workspace.platforms = [ "linux-64", "osx-arm64" ]
dependencies.python = ">=3.11"
pypi-dependencies.spatialdata-plot = { path = ".", editable = true }
tasks.format = "ruff format ."
tasks.kernel-install = 'python -m ipykernel install --user --name pixi-dev --display-name "sdata-plot (dev)"'
tasks.lab = "jupyter lab"
tasks.lint = "ruff check ."
tasks.pre-commit-install = "pre-commit install"
tasks.pre-commit-run = "pre-commit run --all-files"
tasks.test = "pytest -v --color=yes --tb=short --durations=10"
[tool.pixi.workspace]
channels = [ "conda-forge" ]
platforms = [ "linux-64", "osx-arm64" ]

[tool.pixi.dependencies]
python = ">=3.11"

[tool.pixi.pypi-dependencies]
spatialdata-plot = { path = ".", editable = true }

# When the `interactive` feature is active, install the package with the
# `interactive` PyPI extra (anywidget, ipykernel, ipywidgets) so the pixi
# env mirrors what `pip install spatialdata-plot[interactive]` would give.
[tool.pixi.feature.interactive.pypi-dependencies]
spatialdata-plot = { path = ".", editable = true, extras = [ "interactive" ] }

[tool.pixi.tasks]
format = "ruff format ."
kernel-install = 'python -m ipykernel install --user --name pixi-dev --display-name "sdata-plot (dev)"'
kernel-install-interactive = 'python -m ipykernel install --user --name sdata-plot-interactive --display-name "sdata-plot (interactive)"'
lab = "jupyter lab"
lint = "ruff check ."
pre-commit-install = "pre-commit install"
pre-commit-run = "pre-commit run --all-files"
test = "pytest -v --color=yes --tb=short --durations=10"

# for gh-actions
feature.py311.dependencies.python = "3.11.*"
feature.py313.dependencies.python = "3.13.*"
[tool.pixi.feature.py311.dependencies]
python = "3.11.*"

[tool.pixi.feature.py313.dependencies]
python = "3.13.*"

[tool.pixi.environments]
# 3.13 lane
environments.default = { features = [ "py313" ], solve-group = "py313" }
default = { features = [ "py313" ], solve-group = "py313" }
# 3.11 lane (for gh-actions)
environments.dev-py311 = { features = [ "dev", "test", "py311" ], solve-group = "py311" }
environments.dev-py313 = { features = [ "dev", "test", "py313" ], solve-group = "py313" }
environments.docs-py311 = { features = [ "doc", "py311" ], solve-group = "py311" }
environments.docs-py313 = { features = [ "doc", "py313" ], solve-group = "py313" }
environments.test-py313 = { features = [ "test", "py313" ], solve-group = "py313" }
dev-py311 = { features = [ "dev", "test", "py311" ], solve-group = "py311" }
dev-py313 = { features = [ "dev", "test", "py313" ], solve-group = "py313" }
dev-interactive-py313 = { features = [ "dev", "test", "interactive", "py313" ], solve-group = "py313" }
docs-py311 = { features = [ "doc", "py311" ], solve-group = "py311" }
docs-py313 = { features = [ "doc", "py313" ], solve-group = "py313" }
test-py313 = { features = [ "test", "py313" ], solve-group = "py313" }

[tool.ruff]
line-length = 120
Expand Down
132 changes: 132 additions & 0 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,141 @@ def _copy(
tables=self._sdata.tables if tables is None else tables,
)
sdata.plotting_tree = self._sdata.plotting_tree if hasattr(self._sdata, "plotting_tree") else OrderedDict()
sdata._source_sdata = getattr(self._sdata, "_source_sdata", self._sdata)

return sdata

def annotate(
self,
*,
coordinate_systems: str | None = None,
point_radius_frac: float = 0.005,
figsize: tuple[float, float] = (7, 7),
dpi: int = 120,
) -> Any:
"""Terminal step on a render chain: drop the plot into an interactive annotator.

Renders the accumulated ``plotting_tree`` (so any ``render_images`` /
``render_shapes`` / ``render_points`` / ``render_labels`` overlays composed
upstream of this call appear in the annotation canvas), then hands the
rasterised figure to a ``BioImageViewer`` widget. The user draws
rectangles, polygons, and points on the canvas, types a name, and clicks
*Save* — the shapes are converted from canvas-pixel space to the chosen
coordinate system and stored in ``sdata.shapes[<name>]`` with an
``Identity`` transformation in that CS. Points are stored as small
circle polygons (radius = ``point_radius_frac`` of the rendered image's
CS extent) so the resulting ``ShapesModel`` is uniform-type.

Single coordinate system only. If the chain spans more than one CS, or
none can be inferred, raises ``ValueError``.

Requires the ``interactive`` extra: ``pip install 'spatialdata-plot[interactive]'``.

Parameters
----------
coordinate_systems :
Coordinate system to render and resolve drawn shapes against.
Drawn shapes are stored with an ``Identity`` transformation in this
CS. If ``None`` and the SpatialData has exactly one CS, that one is
used; otherwise this argument is required.
point_radius_frac :
Radius of the circle polygon used to store each point, expressed as
a fraction of the rendered image's CS extent. Default 0.005 (0.5%).
figsize :
Matplotlib figure size used for the underlying rasterisation. The
same value affects the canvas resolution alongside ``dpi``.
dpi :
DPI of the rasterised figure. Combined with ``figsize`` this sets
the pixel resolution the annotator works in.

Returns
-------
InteractiveSession
The session object, with the widget already displayed. Holding the
reference keeps the underlying ``BioImageViewer`` alive across cell
re-runs; usually you can ignore the return value.

Raises
------
ValueError
If no single coordinate system can be resolved.
ImportError
If the ``interactive`` extra is not installed.

Examples
--------
>>> import spatialdata_plot # noqa: F401 registers .pl
>>> (
... sdata.pl
... .render_images(element="he")
... .pl.render_shapes(element="cells", outline_color="red")
... .pl.annotate()
... )
>>> # ... user draws and clicks Save with name "tumor" ...
>>> sdata.shapes["tumor"]
"""
try:
from spatialdata_plot.pl.interactive._session import _InteractiveSession
except ImportError as exc:
raise ImportError(
"sdata.pl.annotate() requires the `interactive` extra. "
"Install with: pip install 'spatialdata-plot[interactive]'"
) from exc

import io as _io

from PIL import Image as _Image

available_cs = list(self._sdata.coordinate_systems)
if coordinate_systems is None:
if len(available_cs) != 1:
raise ValueError(
"annotate() needs exactly one coordinate system. "
f"SpatialData has {len(available_cs)}: {available_cs!r}. "
"Pass coordinate_systems=<name> explicitly."
)
cs = available_cs[0]
else:
if isinstance(coordinate_systems, list):
if len(coordinate_systems) != 1:
raise ValueError(f"annotate() supports a single coordinate system; got {coordinate_systems!r}.")
cs = coordinate_systems[0]
else:
cs = coordinate_systems
if cs not in available_cs:
raise ValueError(f"Unknown coordinate system {cs!r}. Available: {available_cs!r}")

fig = plt.figure(figsize=figsize, dpi=dpi)
try:
ax = fig.add_axes([0, 0, 1, 1])
self.show(coordinate_systems=cs, ax=ax)
xlim = ax.get_xlim()
ylim = ax.get_ylim()
ax.set_axis_off()
# set_aspect("equal") inside show() can shrink the axes box so the
# figure has blank padding around the data. Crop the saved PNG to
# the axes bbox so PNG pixels map 1:1 to (xlim, ylim) and the
# px→cs transform in _commit.py stays correct.
fig.canvas.draw()
bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
buf = _io.BytesIO()
fig.savefig(buf, format="png", dpi=dpi, bbox_inches=bbox, pad_inches=0)
finally:
plt.close(fig)
rgb = np.asarray(_Image.open(buf).convert("RGB"))

target_sdata = getattr(self._sdata, "_source_sdata", self._sdata)
session = _InteractiveSession(
sdata=target_sdata,
coordinate_system=cs,
rgb=rgb,
xlim=tuple(xlim),
ylim=tuple(ylim),
point_radius_frac=point_radius_frac,
)
session.show()
return session

@_deprecation_alias(elements="element", version="0.3.0")
def render_shapes(
self,
Expand Down
11 changes: 11 additions & 0 deletions src/spatialdata_plot/pl/interactive/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Interactive region selection on top of a rendered spatialdata-plot figure.

Use via :meth:`spatialdata_plot.pl.basic.PlotAccessor.annotate`:

>>> import spatialdata_plot # noqa: F401 registers .pl
>>> sdata.pl.render_images(element="he").pl.annotate()
"""

from __future__ import annotations

__all__: list[str] = []
97 changes: 97 additions & 0 deletions src/spatialdata_plot/pl/interactive/_commit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""Convert anybioimage canvas shapes into CS-coord shapely geometries."""

from __future__ import annotations

from collections.abc import Callable
from typing import Any

from shapely.geometry import Point, Polygon, box

PxToCs = Callable[[float, float], tuple[float, float]]


def _make_px_to_cs(xmin: float, xmax: float, y_lo: float, y_hi: float, image_w: int, image_h: int) -> PxToCs:
"""Build an affine mapping (px_x, px_y) → (cs_x, cs_y).

The y_lo/y_hi are the sorted ylim values; image_h pixels map linearly
between them. matplotlib image axes with ``origin='upper'`` return
reversed ylim — sorting normalises that.
"""
dx = xmax - xmin
dy = y_hi - y_lo

def px_to_cs(x_px: float, y_px: float) -> tuple[float, float]:
return (xmin + (x_px / image_w) * dx, y_lo + (y_px / image_h) * dy)

return px_to_cs


def _roi_to_polygon(roi: dict[str, Any], px_to_cs: PxToCs) -> Polygon | None:
"""ROI dict ``{x, y, width, height}`` → axis-aligned rectangle Polygon."""
try:
x0, y0 = px_to_cs(float(roi["x"]), float(roi["y"]))
x1, y1 = px_to_cs(float(roi["x"]) + float(roi["width"]), float(roi["y"]) + float(roi["height"]))
except (KeyError, TypeError, ValueError):
return None
poly = box(min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1))
return poly if not poly.is_empty else None


def _polygon_to_polygon(poly: dict[str, Any], px_to_cs: PxToCs) -> Polygon | None:
"""Polygon dict ``{id, points: [{x, y}, ...]}`` → shapely Polygon (≥3 verts)."""
pts = poly.get("points") or []
try:
cs_verts = [px_to_cs(float(p["x"]), float(p["y"])) for p in pts]
except (KeyError, TypeError, ValueError):
return None
if len(cs_verts) < 3:
return None
geom = Polygon(cs_verts)
return geom if not geom.is_empty else None


def _point_to_circle(pt: dict[str, Any], px_to_cs: PxToCs, radius: float) -> Polygon | None:
"""Point dict ``{x, y}`` → circle Polygon of the given CS-units radius.

Stored as a polygon so the resulting ShapesModel is uniform-type and
doesn't need a ``radius`` column.
"""
try:
cx, cy = px_to_cs(float(pt["x"]), float(pt["y"]))
except (KeyError, TypeError, ValueError):
return None
geom = Point(cx, cy).buffer(radius)
return geom if not geom.is_empty else None


def collect_geoms_from_viewer(
viewer: Any,
*,
xmin: float,
xmax: float,
y_lo: float,
y_hi: float,
image_w: int,
image_h: int,
point_radius: float,
) -> list[Polygon]:
"""Read the viewer's three shape stores and convert each entry to a CS-coord Polygon.

Order of returned geometries: ROIs first, then polygons, then points. Invalid
entries (missing keys, degenerate geometry) are silently skipped.
"""
px_to_cs = _make_px_to_cs(xmin, xmax, y_lo, y_hi, image_w, image_h)
geoms: list[Polygon] = []
for roi in viewer._rois_data or []:
g = _roi_to_polygon(roi, px_to_cs)
if g is not None:
geoms.append(g)
for poly in viewer._polygons_data or []:
g = _polygon_to_polygon(poly, px_to_cs)
if g is not None:
geoms.append(g)
for pt in viewer._points_data or []:
g = _point_to_circle(pt, px_to_cs, point_radius)
if g is not None:
geoms.append(g)
return geoms
Loading
Loading