Skip to content

Commit 855ed32

Browse files
timtreisclaude
andcommitted
Implement render_graph for spatial connectivity visualization
Adds `sdata.pl.render_graph()` to render spatial graph edges from adjacency matrices stored in table.obsp, using element centroids for coordinates. Supports shapes, points, and labels elements via spatialdata.get_centroids(). Key features: - connectivity_key accepts full obsp key or prefix (auto-resolves) - element and table_name auto-discovered when unambiguous - groups + group_key filtering (both-endpoints semantics) - Rasterized LineCollection for performance - No networkx dependency (direct sparse matrix to line segments) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 433df0f commit 855ed32

7 files changed

Lines changed: 335 additions & 3 deletions

File tree

src/spatialdata_plot/pl/basic.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from spatialdata_plot._logging import _log_context, logger
3232
from spatialdata_plot.pl.render import (
3333
_draw_channel_legend,
34+
_render_graph,
3435
_render_images,
3536
_render_labels,
3637
_render_points,
@@ -44,6 +45,7 @@
4445
ChannelLegendEntry,
4546
CmapParams,
4647
ColorbarSpec,
48+
GraphRenderParams,
4749
ImageRenderParams,
4850
LabelsRenderParams,
4951
LegendParams,
@@ -63,6 +65,7 @@
6365
_prepare_cmap_norm,
6466
_prepare_params_plot,
6567
_set_outline,
68+
_validate_graph_render_params,
6669
_validate_image_render_params,
6770
_validate_label_render_params,
6871
_validate_points_render_params,
@@ -861,6 +864,82 @@ def render_labels(
861864
n_steps += 1
862865
return sdata
863866

867+
def render_graph(
868+
self,
869+
element: str | None = None,
870+
color: ColorLike | None = "grey",
871+
*,
872+
connectivity_key: str = "spatial",
873+
groups: list[str] | str | None = None,
874+
group_key: str | None = None,
875+
edge_width: float = 1.0,
876+
edge_alpha: float = 1.0,
877+
table_name: str | None = None,
878+
**kwargs: Any,
879+
) -> sd.SpatialData:
880+
"""Render spatial graph edges between observations.
881+
882+
Draws edges from a connectivity matrix stored in a table's ``obsp``,
883+
using centroid coordinates of the linked spatial element.
884+
885+
Parameters
886+
----------
887+
element : str | None, optional
888+
Name of the spatial element (shapes, points, or labels) whose
889+
observations the graph connects. Auto-resolved from the table
890+
if not given.
891+
color : ColorLike | None, default "grey"
892+
Edge color as a color-like value (e.g. ``"red"``, ``"#aabbcc"``).
893+
connectivity_key : str, default "spatial"
894+
Key prefix in ``table.obsp``. Tries ``obsp[key]`` first, then
895+
``obsp[f"{key}_connectivities"]``.
896+
groups : list[str] | str | None, optional
897+
Show only edges where **both** endpoints belong to the specified
898+
groups. Requires ``group_key``.
899+
group_key : str | None, optional
900+
Column in ``table.obs`` used for group filtering.
901+
edge_width : float, default 1.0
902+
Line width for edges.
903+
edge_alpha : float, default 1.0
904+
Transparency for edges (0 = invisible, 1 = opaque).
905+
table_name : str | None, optional
906+
Table containing the graph. Auto-discovered if not given.
907+
**kwargs
908+
Forwarded to :class:`matplotlib.collections.LineCollection`.
909+
910+
Returns
911+
-------
912+
sd.SpatialData
913+
Copy with rendering parameters stored in the plotting tree.
914+
"""
915+
params = _validate_graph_render_params(
916+
self._sdata,
917+
element=element,
918+
connectivity_key=connectivity_key,
919+
table_name=table_name,
920+
color=color,
921+
edge_width=edge_width,
922+
edge_alpha=edge_alpha,
923+
groups=groups,
924+
group_key=group_key,
925+
)
926+
927+
sdata = self._copy()
928+
sdata = _verify_plotting_tree(sdata)
929+
n_steps = len(sdata.plotting_tree.keys())
930+
sdata.plotting_tree[f"{n_steps + 1}_render_graph"] = GraphRenderParams(
931+
element=params["element"],
932+
connectivity_key=params["obsp_key"],
933+
table_name=params["table_name"],
934+
color=params["color"],
935+
groups=params["groups"],
936+
group_key=params["group_key"],
937+
edge_width=params["edge_width"],
938+
edge_alpha=params["edge_alpha"],
939+
zorder=n_steps,
940+
)
941+
return sdata
942+
864943
def show(
865944
self,
866945
coordinate_systems: list[str] | str | None = None,
@@ -1001,6 +1080,7 @@ def show(
10011080
"render_shapes",
10021081
"render_labels",
10031082
"render_points",
1083+
"render_graph",
10041084
]
10051085

10061086
# prepare rendering params
@@ -1311,6 +1391,23 @@ def _draw_colorbar(
13111391
rasterize=rasterize,
13121392
)
13131393

1394+
elif cmd == "render_graph":
1395+
# Graph rendering: resolve which element the graph connects,
1396+
# check if that element exists in this CS.
1397+
graph_element = params_copy.element
1398+
element_in_cs = (
1399+
(graph_element in sdata.shapes and has_shapes)
1400+
or (graph_element in sdata.points and has_points)
1401+
or (graph_element in sdata.labels and has_labels)
1402+
)
1403+
if element_in_cs:
1404+
_render_graph(
1405+
sdata=sdata,
1406+
render_params=params_copy,
1407+
coordinate_system=cs,
1408+
ax=ax,
1409+
)
1410+
13141411
if title is None:
13151412
t = cs
13161413
elif len(title) == 1:

src/spatialdata_plot/pl/render.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
Color,
5050
ColorbarSpec,
5151
FigParams,
52+
GraphRenderParams,
5253
ImageRenderParams,
5354
LabelsRenderParams,
5455
LegendParams,
@@ -1815,3 +1816,117 @@ def _draw_labels(
18151816
scalebar_units=scalebar_params.scalebar_units,
18161817
# scalebar_kwargs=scalebar_params.scalebar_kwargs,
18171818
)
1819+
1820+
1821+
def _render_graph(
1822+
sdata: sd.SpatialData,
1823+
render_params: GraphRenderParams,
1824+
coordinate_system: str,
1825+
ax: matplotlib.axes.SubplotBase,
1826+
**kwargs: Any,
1827+
) -> None:
1828+
"""Render spatial graph edges as a LineCollection on the given axes."""
1829+
from matplotlib.collections import LineCollection
1830+
from scipy.sparse import triu
1831+
1832+
_log_context.set("render_graph")
1833+
element_name = render_params.element
1834+
table_name = render_params.table_name
1835+
1836+
# Get table and adjacency matrix
1837+
table = sdata[table_name]
1838+
obsp_key = render_params.connectivity_key
1839+
# _validate_graph_render_params already resolved the actual obsp key,
1840+
# but we stored the prefix in render_params — re-resolve here
1841+
if obsp_key not in table.obsp:
1842+
suffixed = f"{obsp_key}_connectivities"
1843+
if suffixed in table.obsp:
1844+
obsp_key = suffixed
1845+
else:
1846+
logger.warning(f"Connectivity key '{obsp_key}' not found in table obsp. Skipping graph rendering.")
1847+
return
1848+
1849+
adjacency = table.obsp[obsp_key]
1850+
1851+
# Get the spatial element
1852+
if element_name in sdata.shapes:
1853+
element = sdata.shapes[element_name]
1854+
elif element_name in sdata.points:
1855+
element = sdata.points[element_name]
1856+
elif element_name in sdata.labels:
1857+
element = sdata.labels[element_name]
1858+
else:
1859+
logger.warning(f"Element '{element_name}' not found in sdata. Skipping graph rendering.")
1860+
return
1861+
1862+
# Get centroids in the target coordinate system
1863+
centroids_df = sd.get_centroids(element, coordinate_system=coordinate_system)
1864+
if hasattr(centroids_df, "compute"):
1865+
centroids_df = centroids_df.compute()
1866+
1867+
centroid_coords = np.column_stack([centroids_df["x"].values, centroids_df["y"].values])
1868+
1869+
# Align table observations to centroid positions
1870+
# The table's instance_key maps obs rows to spatial element instances.
1871+
# Centroids are ordered by element instance (e.g., label ID or GeoDataFrame index).
1872+
_, region_key, instance_key = get_table_keys(table)
1873+
1874+
# Filter table to only rows annotating this element
1875+
element_mask = table.obs[region_key] == element_name if region_key is not None else np.ones(table.n_obs, dtype=bool)
1876+
table_subset_indices = np.where(element_mask)[0]
1877+
instance_ids = table.obs[instance_key].values[element_mask]
1878+
1879+
# Build mapping from instance_id to centroid row index
1880+
# For shapes/points, centroids follow the GeoDataFrame/DataFrame index order.
1881+
# For labels, centroids follow unique label IDs (excluding background).
1882+
centroid_ids = centroids_df.index.values if hasattr(centroids_df, "index") else np.arange(len(centroids_df))
1883+
1884+
id_to_centroid_row = {}
1885+
for row, cid in enumerate(centroid_ids):
1886+
id_to_centroid_row[cid] = row
1887+
1888+
# Map each table obs (that annotates this element) to a centroid coordinate
1889+
obs_to_coord = {}
1890+
for table_row, iid in zip(table_subset_indices, instance_ids, strict=True):
1891+
if iid in id_to_centroid_row:
1892+
obs_to_coord[table_row] = centroid_coords[id_to_centroid_row[iid]]
1893+
1894+
# Apply group filtering
1895+
groups = render_params.groups
1896+
group_key = render_params.group_key
1897+
if groups is not None and group_key is not None:
1898+
group_values = table.obs[group_key].values
1899+
group_set = set(groups)
1900+
obs_in_groups = {idx for idx in obs_to_coord if group_values[idx] in group_set}
1901+
else:
1902+
obs_in_groups = set(obs_to_coord.keys())
1903+
1904+
# Extract edges from upper triangle (undirected graph — draw each edge once)
1905+
adj_upper = triu(adjacency, k=0)
1906+
rows, cols = adj_upper.nonzero()
1907+
1908+
# Build line segments for edges where both endpoints are valid
1909+
segments = []
1910+
for r, c in zip(rows, cols, strict=True):
1911+
if r == c:
1912+
continue # skip self-loops
1913+
if r in obs_in_groups and c in obs_in_groups and r in obs_to_coord and c in obs_to_coord:
1914+
segments.append([obs_to_coord[r], obs_to_coord[c]])
1915+
1916+
if not segments:
1917+
return
1918+
1919+
segments_arr = np.array(segments)
1920+
1921+
edge_color = render_params.color.get_hex() if render_params.color is not None else "#808080"
1922+
1923+
lc = LineCollection(
1924+
segments_arr,
1925+
linewidths=render_params.edge_width,
1926+
colors=edge_color,
1927+
alpha=render_params.edge_alpha,
1928+
zorder=render_params.zorder,
1929+
**kwargs,
1930+
)
1931+
lc.set_rasterized(True)
1932+
ax.add_collection(lc)

src/spatialdata_plot/pl/render_params.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,3 +307,18 @@ class LabelsRenderParams:
307307
zorder: int = 0
308308
colorbar: bool | str | None = "auto"
309309
colorbar_params: dict[str, object] | None = None
310+
311+
312+
@dataclass
313+
class GraphRenderParams:
314+
"""Graph render parameters."""
315+
316+
element: str
317+
connectivity_key: str = "spatial"
318+
table_name: str | None = None
319+
color: Color | None = None
320+
groups: list[str] | str | None = None
321+
group_key: str | None = None
322+
edge_width: float = 1.0
323+
edge_alpha: float = 1.0
324+
zorder: int = 0

0 commit comments

Comments
 (0)