diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index efd31ef3..1635bdd2 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -19,7 +19,7 @@ jobs: fail-fast: false matrix: include: - - {os: windows-latest, python: "3.11", dask-version: "2025.2.0", name: "Dask 2025.2.0"} + - {os: windows-latest, python: "3.11", dask-version: "2025.12.0", name: "Dask 2025.12.0"} - {os: windows-latest, python: "3.13", dask-version: "latest", name: "Dask latest"} - {os: ubuntu-latest, python: "3.11", dask-version: "latest", name: "Dask latest"} - {os: ubuntu-latest, python: "3.13", dask-version: "latest", name: "Dask latest"} diff --git a/docs/index.md b/docs/index.md index 73cd1b5e..8e614fa0 100644 --- a/docs/index.md +++ b/docs/index.md @@ -14,6 +14,23 @@ SpatialData is a data framework that comprises a FAIR storage format and a colle Please see our publication {cite}`marconatoSpatialDataOpenUniversal2024` for citation and to learn more. +:::{note} +With dask >= 2025.2.0, users may encounter an error as described in [#1064](https://github.com/scverse/spatialdata/issues/1064). While we have implemented fixes in SpatialData, it can still happen that +users perform operations on the `Points` data themselves and get this error. In order to prevent it, users can use a context manager we provide. + +```python +from spatialdata import disable_dask_tune_optimization +import contextlib +... + +with disable_dask_tune_optimization() if data.npartitions > 1 else contextlib.nullcontext(): +    ... +``` + +This will disable dask graph optimization if the dataframe has more than 1 partition and otherwise keep it enabled. This solves +the problem discussed in this [dask issue](https://github.com/dask/dask/issues/12193). We are looking into an upstream fix. 
+::: + [//]: # "numfocus-fiscal-sponsor-attribution" spatialdata is part of the scverse® project ([website](https://scverse.org), [governance](https://scverse.org/about/roles)) and is fiscally sponsored by [NumFOCUS](https://numfocus.org/). diff --git a/pyproject.toml b/pyproject.toml index fb06b861..b14cf0d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "annsel>=0.1.2", "click", "dask-image", - "dask>=2025.2.0,<2026.1.2", + "dask>=2025.12.0,<2026.1.2", "distributed<2026.1.2", "datashader", "fsspec[s3,http]", diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index bb24f04e..22d8f491 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -9,6 +9,7 @@ "transformations", "datasets", "dataloader", + "disable_dask_tune_optimization", "concatenate", "rasterize", "rasterize_bins", @@ -72,5 +73,5 @@ from spatialdata._io._utils import get_dask_backing_files from spatialdata._io.format import SpatialDataFormatType from spatialdata._io.io_zarr import read_zarr -from spatialdata._utils import get_pyramid_levels, unpad_raster +from spatialdata._utils import disable_dask_tune_optimization, get_pyramid_levels, unpad_raster from spatialdata.config import settings diff --git a/src/spatialdata/_core/operations/transform.py b/src/spatialdata/_core/operations/transform.py index e821edcf..d29af709 100644 --- a/src/spatialdata/_core/operations/transform.py +++ b/src/spatialdata/_core/operations/transform.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib import itertools import warnings from functools import singledispatch @@ -17,6 +18,7 @@ from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike +from spatialdata._utils import disable_dask_tune_optimization from spatialdata.models import SpatialElement, get_axes_names, get_model from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM, get_channel_names from 
spatialdata.transformations._utils import _get_scale, compute_coordinates, scale_radii @@ -439,8 +441,15 @@ def _( ) axes = get_axes_names(data) arrays = [] - for ax in axes: - arrays.append(data[ax].to_dask_array(lengths=True).reshape(-1, 1)) + + # Workaround to prevent partition collapse and missing dependency problem for now. + with disable_dask_tune_optimization() if data.npartitions > 1 else contextlib.nullcontext(): + for ax in axes: + # TODO We have to pass on the lengths explicitly as automatic determination with dask graph optimization + # leads to collapse of the partitions. However, this causes a missing dependency problem, which for now is + # prevented by setting the optimization to False when performing this operation. + arrays.append(data[ax].to_dask_array(lengths=[len(part) for part in data.partitions]).reshape(-1, 1)) + xdata = DataArray(da.concatenate(arrays, axis=1), coords={"points": range(len(data)), "dim": list(axes)}) xtransformed = transformation._transform_coordinates(xdata) transformed = data.drop(columns=list(axes)).copy() diff --git a/src/spatialdata/_utils.py b/src/spatialdata/_utils.py index 64dd7638..64fb43e7 100644 --- a/src/spatialdata/_utils.py +++ b/src/spatialdata/_utils.py @@ -2,6 +2,7 @@ import re import warnings from collections.abc import Callable, Generator +from contextlib import contextmanager from itertools import islice from typing import Any, TypeVar @@ -9,6 +10,7 @@ import pandas as pd from anndata import AnnData from dask import array as da +from dask import config from dask.array import Array as DaskArray from xarray import DataArray, Dataset, DataTree @@ -20,6 +22,17 @@ RT = TypeVar("RT") +@contextmanager +def disable_dask_tune_optimization() -> Generator[None, None, None]: + """Prevent dask graph optimization when performing operations on dask dataframes with npartitions > 1.""" + old_setting = config.config["optimization"]["tune"]["active"] + config.set({"optimization.tune.active": False}) + try: + yield + finally: + 
config.set({"optimization.tune.active": old_setting}) + + def _parse_list_into_array(array: list[Number] | ArrayLike) -> ArrayLike: if isinstance(array, list): array = np.array(array) diff --git a/tests/core/operations/test_transform.py b/tests/core/operations/test_transform.py index 1bb494fb..c216703f 100644 --- a/tests/core/operations/test_transform.py +++ b/tests/core/operations/test_transform.py @@ -1,16 +1,18 @@ +import contextlib import math import tempfile from pathlib import Path import numpy as np import pytest +from dask import config from geopandas.testing import geom_almost_equals from xarray import DataArray, DataTree from spatialdata import transform from spatialdata._core.data_extent import are_extents_equal, get_extent from spatialdata._core.spatialdata import SpatialData -from spatialdata._utils import unpad_raster +from spatialdata._utils import disable_dask_tune_optimization, unpad_raster from spatialdata.models import Image2DModel, PointsModel, ShapesModel, get_axes_names from spatialdata.transformations.operations import ( align_elements_using_landmarks, @@ -586,6 +588,44 @@ def test_transform_elements_and_entire_spatial_data_object(full_sdata: SpatialDa _ = full_sdata.transform_to_coordinate_system("my_space", maintain_positioning=maintain_positioning) +def test_transform_points_with_multiple_partitions(full_sdata: SpatialData, tmp_path: str): + tmpdir = Path(tmp_path) / "tmp.zarr" + points_memory = full_sdata["points_0"].compute() + full_sdata["points_0"] = PointsModel.parse( + full_sdata["points_0"].repartition(npartitions=4), + transformations={"global": get_transformation(full_sdata["points_0"])}, + ) + assert points_memory.equals(full_sdata["points_0"].compute()) + + full_sdata.write(tmpdir) + + full_sdata = SpatialData.read(tmpdir) + + # This just needs to run without error + data = transform(full_sdata["points_0"], to_coordinate_system="global") + + # test that data still can be computed + data.compute() + + +@pytest.mark.parametrize( + 
"tune,partition", + [ + (True, None), + (False, 4), + ], +) +def test_dask_tune_contextmanager(full_sdata: SpatialData, partition: int | None, tune: bool): + if partition: + full_sdata["points_0"] = PointsModel.parse( + full_sdata["points_0"].repartition(npartitions=4), + transformations={"global": get_transformation(full_sdata["points_0"])}, + ) + + with disable_dask_tune_optimization() if full_sdata["points_0"].npartitions > 1 else contextlib.nullcontext(): + assert config.config["optimization"]["tune"]["active"] is tune + + @pytest.mark.parametrize("maintain_positioning", [True, False]) def test_transform_elements_and_entire_spatial_data_object_multi_hop( full_sdata: SpatialData, maintain_positioning: bool