Skip to content

Commit be3ddf3

Browse files
timtreisclaude
andcommitted
Fix multiscale resolution selection picking wrong scale level (#589)
The heuristic in `_multiscale_to_spatial_image` used `min()` on indices from an ascending-sorted scale list, which selected *lower* resolution when x and y optimal pixel counts disagreed. It also relied on `searchsorted` over y_dims that may not be sorted after x-based reordering. Replace with a direct loop that finds the lowest-resolution scale where both x and y dimensions meet the target pixel count, falling back to the highest available resolution. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 5cfedc7 commit be3ddf3

2 files changed

Lines changed: 87 additions & 12 deletions

File tree

src/spatialdata_plot/pl/utils.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2071,7 +2071,7 @@ def _multiscale_to_spatial_image(
20712071
# use scale with highest resolution
20722072
optimal_scale = scales[np.argmax(x_dims)]
20732073
else:
2074-
# ensure that lists are sorted
2074+
# sort scales ascending by x resolution
20752075
order = np.argsort(x_dims)
20762076
scales = [scales[i] for i in order]
20772077
x_dims = [x_dims[i] for i in order]
@@ -2080,17 +2080,13 @@ def _multiscale_to_spatial_image(
20802080
optimal_x = width * dpi
20812081
optimal_y = height * dpi
20822082

2083-
# get scale where the dimensions are close to the optimal values
2084-
# when possible, pick higher resolution (worst case: downscaled afterwards)
2085-
optimal_index_y = np.searchsorted(y_dims, optimal_y)
2086-
if optimal_index_y == len(y_dims):
2087-
optimal_index_y -= 1
2088-
optimal_index_x = np.searchsorted(x_dims, optimal_x)
2089-
if optimal_index_x == len(x_dims):
2090-
optimal_index_x -= 1
2091-
2092-
# pick the scale with higher resolution (worst case: downscaled afterwards)
2093-
optimal_scale = scales[min(int(optimal_index_x), int(optimal_index_y))]
2083+
# Pick the lowest-resolution scale where both x and y are >= the
2084+
# target pixel count. Falls back to highest available resolution.
2085+
optimal_scale = scales[-1]
2086+
for i, (xd, yd) in enumerate(zip(x_dims, y_dims, strict=True)):
2087+
if xd >= optimal_x and yd >= optimal_y:
2088+
optimal_scale = scales[i]
2089+
break
20942090

20952091
# NOTE: problematic if there are cases with > 1 data variable
20962092
data_var_keys = list(multiscale_image[optimal_scale].data_vars)

tests/pl/test_utils.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,3 +283,82 @@ def test_utils_get_subplots_produces_correct_axs_layout(input_output):
283283

284284
assert len_axs == len(axs.flatten())
285285
assert axs_visible == [ax.axison for ax in axs.flatten()]
286+
287+
288+
class TestMultiscaleToSpatialImage:
289+
"""Regression tests for #589: multiscale resolution selection."""
290+
291+
@staticmethod
292+
def _make_multiscale(shape, scale_factors):
293+
from spatialdata.models import Image2DModel
294+
295+
rng = np.random.default_rng(42)
296+
return Image2DModel.parse(
297+
rng.normal(size=shape),
298+
scale_factors=scale_factors,
299+
dims=("c", "y", "x"),
300+
c_coords=["r", "g", "b"],
301+
)
302+
303+
def test_larger_figure_never_picks_lower_resolution(self):
304+
"""Increasing figure size must select equal or higher resolution."""
305+
from spatialdata_plot.pl.utils import _multiscale_to_spatial_image
306+
307+
multiscale = self._make_multiscale((3, 1024, 1024), [2, 2])
308+
dpi = 100.0
309+
prev_x = 0
310+
for size in [3, 4, 5, 6, 7, 8, 10, 12]:
311+
result = _multiscale_to_spatial_image(multiscale, dpi, float(size), float(size))
312+
cur_x = result.sizes["x"]
313+
assert cur_x >= prev_x, (
314+
f"figsize {size} selected x={cur_x} which is lower than x={prev_x} from a smaller figure"
315+
)
316+
prev_x = cur_x
317+
318+
def test_asymmetric_image_picks_sufficient_resolution(self):
319+
"""When image aspect ratio differs from figure, both axes must be covered."""
320+
from spatialdata_plot.pl.utils import _multiscale_to_spatial_image
321+
322+
multiscale = self._make_multiscale((3, 400, 1200), [2, 2])
323+
scales_info = {
324+
leaf.name: (multiscale[leaf.name].dims["x"], multiscale[leaf.name].dims["y"]) for leaf in multiscale.leaves
325+
}
326+
max_x = max(x for x, _ in scales_info.values())
327+
max_y = max(y for _, y in scales_info.values())
328+
329+
dpi = 100.0
330+
for w, h in [(5, 5), (3, 10), (10, 3), (7, 4)]:
331+
result = _multiscale_to_spatial_image(multiscale, dpi, float(w), float(h))
332+
sel_x, sel_y = result.sizes["x"], result.sizes["y"]
333+
opt_x, opt_y = w * dpi, h * dpi
334+
assert sel_x >= opt_x or sel_x == max_x, (
335+
f"figsize {w}x{h}: x={sel_x} < optimal {opt_x} and not the maximum available"
336+
)
337+
assert sel_y >= opt_y or sel_y == max_y, (
338+
f"figsize {w}x{h}: y={sel_y} < optimal {opt_y} and not the maximum available"
339+
)
340+
341+
def test_all_scales_too_small_picks_highest_resolution(self):
342+
"""When no scale is large enough, the highest resolution is selected."""
343+
from spatialdata_plot.pl.utils import _multiscale_to_spatial_image
344+
345+
multiscale = self._make_multiscale((3, 64, 64), [2, 2])
346+
result = _multiscale_to_spatial_image(multiscale, dpi=100.0, width=20.0, height=20.0)
347+
assert result.sizes["x"] == 64
348+
349+
def test_single_scale_level(self):
350+
"""A single-level multiscale image always returns that level."""
351+
from spatialdata_plot.pl.utils import _multiscale_to_spatial_image
352+
353+
multiscale = self._make_multiscale((3, 512, 512), [2])
354+
for size in [2, 5, 10]:
355+
result = _multiscale_to_spatial_image(multiscale, dpi=100.0, width=float(size), height=float(size))
356+
assert result.sizes["x"] in (512, 256)
357+
358+
def test_exact_match_selects_that_scale(self):
359+
"""When optimal pixels exactly match a scale's dimensions, that scale is selected."""
360+
from spatialdata_plot.pl.utils import _multiscale_to_spatial_image
361+
362+
multiscale = self._make_multiscale((3, 500, 500), [2, 2])
363+
result = _multiscale_to_spatial_image(multiscale, dpi=100.0, width=2.5, height=2.5)
364+
assert result.sizes["x"] == 250

0 commit comments

Comments
 (0)