diff --git a/pyproject.toml b/pyproject.toml index 03ecbae..8abba6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,13 +33,13 @@ classifiers = [# https://pypi.python.org/pypi?%3Aaction=list_classifiers dependencies = [ "adjustText", "anndata>=0.12.4", # https://github.com/scverse/anndata/issues/2166 - "bioio<2", + "bioio", "bioio-nd2", "bioio-tifffile", "centrosome", "cp-measure>=0.1.16", "dask-image", - "dask<=2025.11.0", + "dask!=2026.1.1,!=2026.1.0,!=2025.12.0", # ignores keyword arguments in da.to_zarr "decorator", "filelock", "flox", @@ -54,7 +54,7 @@ dependencies = [ "natsort", "numcodecs", "numpy", - "ome-zarr<0.12.0", + "ome-zarr", "pandas", "pint", "psutil", @@ -68,9 +68,9 @@ dependencies = [ "stardist", "statsmodels", "tensorflow", - "tifffile<=2025.5.10", + "tifffile", "xarray", - "zarr<3" + "zarr>=3" ] [project.optional-dependencies] diff --git a/requirements.doc.txt b/requirements.doc.txt index 6c56034..b75a0dd 100644 --- a/requirements.doc.txt +++ b/requirements.doc.txt @@ -1,4 +1,4 @@ -ipython==9.9.0 +ipython==9.10.0 nbsphinx==0.9.8 sphinx-copybutton==0.5.2 sphinx==9.1.0 diff --git a/requirements.txt b/requirements.txt index 9bc7858..969d52a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,15 @@ anndata==0.12.10 adjustText==1.3.0 -bioio-nd2==1.3.0 +bioio==3.3.0 +bioio-nd2==1.7.0 +bioio-ome-tiff==1.4.0 +bioio-ome-zarr==3.3.0 bioio-tifffile==1.3.0 -bioio==1.6.1 centrosome==1.3.3 cp-measure==0.1.18 cython==3.2.4 dask-image==2025.11.0 -dask==2025.11.0 +dask==2026.3.0 decorator==5.2.1 filelock==3.29.0 flox==0.11.2 @@ -20,9 +22,9 @@ kneed==0.8.6 mahotas==1.4.18 matplotlib==3.10.8 natsort==8.4.0 -numcodecs==0.15.1 +numcodecs==0.16.5 numpy==2.4.4 -ome-zarr==0.11.1 +ome-zarr==0.16.0 pandas==2.3.3 pint==0.25.3 psutil==7.2.2 @@ -36,6 +38,6 @@ shapely==2.1.2 stardist==0.9.2 statsmodels==0.14.6 tensorflow==2.21.0 -tifffile==2025.5.10 -xarray==2026.2.0 -zarr==2.18.7 +tifffile==2026.4.11 +xarray==2026.4.0 +zarr==3.1.6 diff --git a/scallops/_bioio_zarr_reader.py b/scallops/_bioio_zarr_reader.py deleted file mode 100644 index 6118c46..0000000 --- a/scallops/_bioio_zarr_reader.py +++ /dev/null @@ -1,263 +0,0 @@ -import logging -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple - -import xarray as xr -from bioio_base import constants, dimensions, exceptions, io, reader, types -from fsspec.spec import AbstractFileSystem -from ome_zarr.io import parse_url -from ome_zarr.reader import Reader as ZarrReader - -logger = logging.getLogger("scallops") - - -# Same as https://github.com/bioio-devs/bioio-ome-zarr/blob/main/bioio_ome_zarr/reader.py but fixes bug in channel names -# Also checks to see if zarr path is {zarr_path}/images/image1 with only 1 image -# See https://github.com/bioio-devs/bioio-ome-zarr/pull/22 -class ScallopsZarrReader(reader.Reader): - """The main class of each reader plugin. This class is subclass of the abstract class reader - (BaseReader) in bioio-base. - - Parameters - ---------- - image: types.PathLike - String or Path to the ZARR root - fs_kwargs: Dict[str, Any] - Ignored - """ - - _xarray_dask_data: Optional["xr.DataArray"] = None - _xarray_data: Optional["xr.DataArray"] = None - _mosaic_xarray_dask_data: Optional["xr.DataArray"] = None - _mosaic_xarray_data: Optional["xr.DataArray"] = None - _dims: Optional[dimensions.Dimensions] = None - _metadata: Optional[Any] = None - _scenes: Optional[Tuple[str, ...]] = None - _current_scene_index: int = 0 - # Do not provide default value because - # they may not need to be used by your reader (i.e. input param is an array) - _fs: "AbstractFileSystem" - _path: str - - # Required Methods - - def __init__( - self, - image: types.PathLike, - fs_kwargs: Dict[str, Any] = {}, - ): - # Expand details of provided image - self._fs, self._path = io.pathlike_to_fs( - image, - enforce_exists=False, - fs_kwargs=fs_kwargs, - ) - - # Enforce valid image - if not self._is_supported_image(self._fs, self._path): - raise exceptions.UnsupportedFileFormatError( - self.__class__.__name__, - self._path, - "Could not find a .zgroup or .zarray file at the provided path.", - ) - - self._zarr = get_zarr_reader(self._fs, self._path).zarr - self._physical_pixel_sizes: Optional[types.PhysicalPixelSizes] = None - self._channel_names: Optional[List[str]] = None - - @staticmethod - def _is_supported_image(fs: AbstractFileSystem, path: str, **kwargs: Any) -> bool: - try: - get_zarr_reader(fs, path) - return True - except AttributeError: - return False - - @classmethod - def is_supported_image( - cls, - image: types.ImageLike, - fs_kwargs: Dict[str, Any] = {}, - **kwargs: Any, - ) -> bool: - if isinstance(image, (str, Path)): - return cls._is_supported_image(None, str(image), **kwargs) - else: - return reader.Reader.is_supported_image( - cls, image, fs_kwargs=fs_kwargs, **kwargs - ) - - @property - def scenes(self) -> Tuple[str, ...]: - if self._scenes is None: - scenes = self._zarr.root_attrs["multiscales"] - - # if (each scene has a name) and (that name is unique) use name. - # otherwise generate scene names. - if all("name" in scene for scene in scenes) and ( - len({scene["name"] for scene in scenes}) == len(scenes) - ): - self._scenes = tuple(str(scene["name"]) for scene in scenes) - else: - self._scenes = tuple( - f"scene_{i}" - for i in range(len(self._zarr.root_attrs["multiscales"])) - ) - return self._scenes - - @property - def resolution_levels(self) -> Tuple[int, ...]: - """ - Returns - ------- - resolution_levels: Tuple[str, ...] - Return the available resolution levels for the current scene. - By default these are ordered from highest resolution to lowest - resolution. - """ - return tuple( - rl - for rl in range( - len( - self._zarr.root_attrs["multiscales"][self.current_scene_index][ - "datasets" - ] - ) - ) - ) - - def _read_delayed(self) -> xr.DataArray: - return self._xarr_format(delayed=True) - - def _read_immediate(self) -> xr.DataArray: - return self._xarr_format(delayed=False) - - def _xarr_format(self, delayed: bool) -> xr.DataArray: - data_path = self._zarr.root_attrs["multiscales"][self.current_scene_index][ - "datasets" - ][self.current_resolution_level]["path"] - image_data = self._zarr.load(data_path) - - axes = self._zarr.root_attrs["multiscales"][self.current_scene_index].get( - "axes" - ) - if axes: - dims = [sub["name"].upper() for sub in axes] - else: - dims = list(reader.Reader._guess_dim_order(image_data.shape)) - - if not delayed: - image_data = image_data.compute() - - coords = self._get_coords( - dims, - image_data.shape, - scene=self.current_scene, - channel_names=self.channel_names, - ) - - return xr.DataArray( - image_data, - dims=dims, - coords=coords, - attrs={constants.METADATA_UNPROCESSED: self._zarr.root_attrs}, - ) - - # Optional Methods - @property - def physical_pixel_sizes(self) -> types.PhysicalPixelSizes: - """Return the physical pixel sizes of the image.""" - if self._physical_pixel_sizes is None: - try: - z_size, y_size, x_size = self._get_pixel_size( - list(self.dims.order), - ) - except Exception as e: - logger.warning(f"Could not parse zarr pixel size: {e}") - z_size, y_size, x_size = None, None, None - - self._physical_pixel_sizes = types.PhysicalPixelSizes( - z_size, y_size, x_size - ) - return self._physical_pixel_sizes - - def _get_pixel_size( - self, - dims: List[str], - ) -> Tuple[Optional[float], Optional[float], Optional[float]]: - # OmeZarr file may contain an additional set of "coordinateTransformations" - # these coefficents are applied to all resolution levels. - if ( - "coordinateTransformations" - in self._zarr.root_attrs["multiscales"][self.current_scene_index] - ): - universal_res_consts = self._zarr.root_attrs["multiscales"][ - self.current_scene_index - ]["coordinateTransformations"][0]["scale"] - else: - universal_res_consts = [1.0 for _ in range(len(dims))] - - coord_transform = self._zarr.root_attrs["multiscales"][ - self.current_scene_index - ]["datasets"][self.current_resolution_level]["coordinateTransformations"] - - spatial_coeffs = {} - - for dim in [ - dimensions.DimensionNames.SpatialX, - dimensions.DimensionNames.SpatialY, - dimensions.DimensionNames.SpatialZ, - ]: - if dim in dims: - dim_index = dims.index(dim) - spatial_coeffs[dim] = ( - coord_transform[0]["scale"][dim_index] - * universal_res_consts[dim_index] - ) - else: - spatial_coeffs[dim] = None - - return ( - spatial_coeffs[dimensions.DimensionNames.SpatialZ], - spatial_coeffs[dimensions.DimensionNames.SpatialY], - spatial_coeffs[dimensions.DimensionNames.SpatialX], - ) - - @property - def channel_names(self) -> Optional[List[str]]: - if self._channel_names is None: - if "omero" in self._zarr.root_attrs: - self._channel_names = [ - str(channel["label"]) - for channel in self._zarr.root_attrs["omero"]["channels"] - ] - return self._channel_names - - @staticmethod - def _get_coords( - dims: List[str], - shape: Tuple[int, ...], - scene: str, - channel_names: Optional[List[str]], - ) -> Dict[str, Any]: - coords: Dict[str, Any] = {} - - # Use dims for coord determination - if dimensions.DimensionNames.Channel in dims: - # Generate channel names if no existing channel names - if channel_names is None: - coords[dimensions.DimensionNames.Channel] = [ - f"channel_{i}" - for i in range(shape[dims.index(dimensions.DimensionNames.Channel)]) - ] - else: - coords[dimensions.DimensionNames.Channel] = channel_names - - return coords - - -def get_zarr_reader(fs: AbstractFileSystem, path: str) -> ZarrReader: - if fs is not None: - path = fs.unstrip_protocol(path) - - return ZarrReader(parse_url(path, mode="r")) diff --git a/scallops/cli/pooled_if_sbs.py b/scallops/cli/pooled_if_sbs.py index 3522679..0ea00df 100644 --- a/scallops/cli/pooled_if_sbs.py +++ b/scallops/cli/pooled_if_sbs.py @@ -25,7 +25,7 @@ import pyarrow.parquet as pq import xarray as xr import zarr -from dask.delayed import Delayed, delayed +from dask.delayed import delayed from matplotlib import pyplot as plt from skimage.segmentation import expand_labels @@ -75,6 +75,7 @@ from scallops.zarr_io import ( _get_fs, _get_sep, + _get_store_path, _write_zarr_image, is_anndata_zarr, open_ome_zarr, @@ -178,8 +179,7 @@ def _peaks_to_bases( def spot_detection_pipeline( image_tuple: tuple[tuple[str, ...], list[str], dict], iss_channels: list[int], - file_separator: str, - root: zarr.Group | str, + output: str, max_filter_width: int, sigma_log: float | list[float], z_index: int | str, @@ -196,7 +196,7 @@ def spot_detection_pipeline( spot_detection_method: Literal["log", "spotiflow", "u-fish", "piscis"] = "log", spot_detection_n_cycles: int | None = None, expected_cycles: int | None = None, -) -> list[Delayed]: +): """Run the spot detection pipeline. This function processes a set of images, performs spot detection, and saves the @@ -204,8 +204,7 @@ def spot_detection_pipeline( :param image_tuple: A tuple containing information about the images. :param iss_channels: List of channel indices used for ISS sequencing. - :param file_separator: Separator used in file paths. - :param root: Root path or zarr group where the results will be stored. + :param output: Root path to where the results will be stored. :param max_filter_width: Maximum filter width used in spot detection. :param z_index: Either 'max' or z-index :param sigma_log: Sigma parameter for log transformation in spot detection. @@ -225,15 +224,19 @@ def spot_detection_pipeline( """ _, file_list, metadata = image_tuple image_key = metadata["id"] + output_fs = fsspec.url_to_fs(output)[0] + output_sep = output_fs.sep + output = output.rstrip(output_sep) + points_path = f"{output}{output_sep}points" + + points_protocol = _get_fs_protocol(output_fs) + if points_protocol != "file": + points_path = f"{points_protocol}://{points_path}" + peaks_path = f"{points_path}{output_sep}{image_key}-peaks.parquet" if not force: - points_path = f"{root.store.path.rstrip(_get_sep(root))}{_get_sep(root)}points" - points_protocol = _get_fs_protocol(_get_fs(root)) - if points_protocol != "file": - points_path = f"{points_protocol}://{points_path}" - peaks_path = f"{points_path}{_get_sep(root)}{image_key}-peaks.parquet" if is_parquet_file(peaks_path): logger.info(f"Skipping spot detection for {image_key}") - return [] + return image = _images2fov(file_list, metadata, dask=True) image = _z_projection(image, z_index) if expected_cycles is not None: @@ -293,10 +296,9 @@ def spot_detection_pipeline( dask_delayed.append( _write_image( name=f"{image_key}-log", - root=root, + root=open_ome_zarr(output, mode="a"), image=loged, output_format=output_image_format, - file_separator=file_separator, zarr_format="zarr", compute=compute, ) @@ -308,10 +310,9 @@ def spot_detection_pipeline( dask_delayed.append( _write_image( name=f"{image_key}-std", - root=root, + root=open_ome_zarr(output, mode="a"), image=std_arr, output_format=output_image_format, - file_separator=file_separator, metadata=dict(parent=image_key), compute=compute, ) @@ -323,10 +324,9 @@ def spot_detection_pipeline( dask_delayed.append( _write_image( name=f"{image_key}-max", - root=root, + root=open_ome_zarr(output, mode="a"), image=maxed, output_format=output_image_format, - file_separator=file_separator, zarr_format="zarr", compute=compute, ) @@ -334,14 +334,10 @@ def spot_detection_pipeline( else: del maxed if "peaks" in save_keys: - points_path = f"{root.store.path.rstrip(_get_sep(root))}{_get_sep(root)}points" - protocol = _get_fs_protocol(_get_fs(root)) - if protocol != "file": - points_path = f"{protocol}://{points_path}" - _get_fs(root).makedirs(points_path, exist_ok=True) - peaks_path = f"{points_path}{_get_sep(root)}{image_key}-peaks.parquet" - if _get_fs(root).exists(peaks_path): - _get_fs(root).rm(peaks_path, recursive=True) + output_fs.makedirs(points_path, exist_ok=True) + + if output_fs.exists(peaks_path): + output_fs.rm(peaks_path, recursive=True) dask_delayed.append( _to_parquet( @@ -353,7 +349,6 @@ def spot_detection_pipeline( ) if not compute and len(dask_delayed) > 0: dask.compute(*dask_delayed) - return [] def _fix_cycles(sbs_cycles): @@ -803,19 +798,17 @@ def spot_detect_main(arguments: argparse.Namespace): chunks = (chunks, chunks) output = _add_suffix(output, ".zarr") - root = open_ome_zarr(output, mode="a") + exp_gen = _set_up_experiment(images, image_pattern, group_by, subset=subset) with ( _create_default_dask_config(), _create_dask_client(dask_scheduler_url, **dask_cluster_parameters), ): - delayed_results = [] for img in exp_gen: - delayed_results += spot_detection_pipeline( + spot_detection_pipeline( img, iss_channels=channels, - file_separator=None, - root=root, + output=output, z_index=z_index, output_image_format="zarr", max_filter_width=max_filter_width, @@ -833,8 +826,6 @@ def spot_detect_main(arguments: argparse.Namespace): spot_detection_n_cycles=spot_detection_n_cycles, expected_cycles=expected_cycles, ) - if len(delayed_results) > 0: - dask.compute(*delayed_results) def reads_pipeline( @@ -911,7 +902,7 @@ def reads_pipeline( logger.info(f"Running reads for {image_key}") spots_sep = _get_sep(spots_root) - points_path = f"{spots_root.store.path.rstrip(spots_sep)}{spots_sep}points" + points_path = f"{_get_store_path(spots_root).rstrip(spots_sep)}{spots_sep}points" spots_protocol = _get_fs_protocol(_get_fs(spots_root)) if spots_protocol != "file": points_path = f"{spots_protocol}://{points_path}" @@ -1229,8 +1220,8 @@ def reads_main(arguments: argparse.Namespace): for key in image_keys: reads_pipeline( key, - spots_root=zarr.open(spots, "r"), - labels_root=zarr.open(labels + labels_fs.sep + "labels", "r"), + spots_root=zarr.open(spots, mode="r"), + labels_root=zarr.open(labels + labels_fs.sep + "labels", mode="r"), barcodes_file=barcodes_file, file_separator=output_fs.sep, threshold_peaks=threshold_peaks, diff --git a/scallops/cli/register.py b/scallops/cli/register.py index ed081d9..bf9e89f 100644 --- a/scallops/cli/register.py +++ b/scallops/cli/register.py @@ -464,7 +464,7 @@ def get_matching_names( results = [] for path in paths: name = os.path.basename(path) - if not name.startswith(".") and is_ome_zarr_array(zarr.open(path, "r")): + if not name.startswith(".") and is_ome_zarr_array(zarr.open(path, mode="r")): results.append(path) return results diff --git a/scallops/cli/util.py b/scallops/cli/util.py index 0d2da08..6782e78 100644 --- a/scallops/cli/util.py +++ b/scallops/cli/util.py @@ -199,7 +199,7 @@ def _write_image( root: zarr.Group | str, image: np.ndarray | xr.DataArray | da.Array, output_format: str, - file_separator: str, + file_separator: str = "/", metadata: dict | None = None, compute: bool = True, **kwargs, diff --git a/scallops/features/generate.py b/scallops/features/generate.py index 1dd2e4e..1c389b1 100644 --- a/scallops/features/generate.py +++ b/scallops/features/generate.py @@ -158,7 +158,7 @@ def label_features( if isinstance(intensity_image, da.Array): # y,x,c assert intensity_image.shape[:-1] == label_shape, ( - f"{intensity_image.shape} != {label_shape}" + f"{intensity_image.shape[:-1]} != {label_shape}" ) label_image = label_image.rechunk(intensity_image.chunksize[:-1]) diff --git a/scallops/features/image_quality.py b/scallops/features/image_quality.py index 43a9d69..5548d9f 100644 --- a/scallops/features/image_quality.py +++ b/scallops/features/image_quality.py @@ -4,7 +4,7 @@ from scipy.ndimage import sum as nd_sum -# updated for numpy2, see https://github.com/CellProfiler/centrosome/pull/135 +# copied from centrosome.radial_power_spectrum but use np.ptp instead of img.ptp for numpy 2 def rps(img): assert img.ndim == 2 radii2 = (np.arange(img.shape[0]).reshape((img.shape[0], 1)) ** 2) + ( @@ -25,6 +25,7 @@ def rps(img): magsum = nd_sum(mag, radii, labels) powersum = nd_sum(power, radii, labels) return np.array(labels), np.array(magsum), np.array(powersum) + return [2], [0], [0] diff --git a/scallops/io.py b/scallops/io.py index 72c69d3..8767edc 100644 --- a/scallops/io.py +++ b/scallops/io.py @@ -29,6 +29,7 @@ import anndata import bioio +import bioio_ome_zarr import bioio_tifffile import dask import dask.array as da @@ -54,12 +55,11 @@ from xarray.core.utils import equivalent from zarr.storage import StoreLike -from scallops._bioio_zarr_reader import ScallopsZarrReader from scallops.experiment.elements import Experiment, _LazyLoadData from scallops.externals.tifffile2014 import imsave from scallops.utils import forceTCZYX, mlcs from scallops.xr import _crop -from scallops.zarr_io import _read_zarr_experiment, read_ome_zarr_array +from scallops.zarr_io import _get_store_path, _read_zarr_experiment, read_ome_zarr_array logger = logging.getLogger("scallops") @@ -234,7 +234,7 @@ def _create_image(path: str, **kwargs) -> bioio.BioImage: base_path_lc, ext = os.path.splitext(path_lc) if "reader" not in img_args: if ext in ["", ".zarr", "/", ".zarr/"]: - img_args["reader"] = ScallopsZarrReader + img_args["reader"] = bioio_ome_zarr.Reader elif ext in [".tiff", ".tif"] and os.path.splitext(base_path_lc)[1] != ".ome": img_args["reader"] = bioio_tifffile.Reader return bioio.BioImage(path, **img_args) @@ -1358,7 +1358,7 @@ def _images2fov( name = ( os.path.basename(file_list[i]) if not isinstance(file_list[i], zarr.Group) - else file_list[i].store.path + else _get_store_path(file_list[i]) ) src_metadata.append(dict(attrs=image_attrs[i], name=name)) @@ -1599,6 +1599,7 @@ def _get_image_key_func(group_by): lambda: [] ) # key is tuple -> value is tuple of group, dict maxdepth = None + for image_path in image_paths: if isinstance(image_path, Path): # IF URI DO NOT PROVIDE AS PATH @@ -1611,6 +1612,7 @@ def _get_image_key_func(group_by): pass else: root = image_path + if root is not None: if "0" not in root: # format: "path.zarr/images/" if "images" in root: @@ -1664,6 +1666,7 @@ def _get_image_key_func(group_by): if image_path in [".", "./"] and _get_fs_protocol(fs) == "file": image_path = fs.info(image_path)["name"].rstrip(".") image_prefix = None + if fs.isdir(image_path): image_path = image_path.rstrip(fs.sep) if maxdepth is None: @@ -1681,6 +1684,7 @@ def _get_image_key_func(group_by): withdirs=True, ) ) + paths = [p for p in all_paths if p.lower().endswith(extension)] if len(paths) == 0: # try with no maxdepth @@ -1718,7 +1722,7 @@ def _get_image_key_func(group_by): group_to_matches[group].append((x, d)) if len(group_to_matches) == 0: - message = [f"No files found matching pattern: {file_regex.pattern}"] + message = [f"No files found matching pattern: {files_pattern}"] if subset_ is not None: message.append(f", subset: {', '.join([str(s) for s in subset_])}") if len(group_by) > 0: @@ -1784,7 +1788,9 @@ def file_sort_key(x): src=file_list, common_src=mlcs( [ - Path(x).stem if not isinstance(x, zarr.Group) else x.store.path + Path(x).stem + if not isinstance(x, zarr.Group) + else _get_store_path(x) for x in file_list ] ), diff --git a/scallops/registration/itk.py b/scallops/registration/itk.py index 205c15b..5dbbef1 100644 --- a/scallops/registration/itk.py +++ b/scallops/registration/itk.py @@ -33,7 +33,12 @@ from scallops.registration.landmarks import _get_translation, find_landmarks from scallops.utils import _dask_from_array_no_copy from scallops.xr import _get_dims -from scallops.zarr_io import open_ome_zarr, write_zarr +from scallops.zarr_io import ( + default_zarr_format, + get_zarr_array_kwargs, + open_ome_zarr, + write_zarr, +) logger = logging.getLogger("scallops") @@ -328,15 +333,18 @@ def _init_callback(init_params: dict[str, Any]) -> dict[str, Any]: group = None if image_root is not None: images_group = image_root.require_group("images", overwrite=False) + fmt = default_zarr_format() group = images_group.create_group( image_name.replace("/", "-"), overwrite=True ) - zarr_dataset = group.create_dataset( + + zarr_dataset = group.create_array( "0", shape=shape, chunks=(1,) * (len(shape) - 2) + chunk_size, dtype=dtype, overwrite=True, + **get_zarr_array_kwargs(fmt), ) return { @@ -1164,12 +1172,15 @@ def _itk_transform_image_zarr( image_name.replace("/", "-"), overwrite=True ) chunks = (1,) * len(transform_dims) + (chunksize or (1024, 1024)) - data = group.create_dataset( + fmt = default_zarr_format() + + data = group.create_array( "0", shape=dim_sizes + output_size, chunks=chunks, dtype=image.dtype, overwrite=True, + **get_zarr_array_kwargs(fmt), ) _itk_transform_image( diff --git a/scallops/stitch/_stitch.py b/scallops/stitch/_stitch.py index 212c820..cee63c0 100644 --- a/scallops/stitch/_stitch.py +++ b/scallops/stitch/_stitch.py @@ -6,7 +6,6 @@ from collections.abc import Sequence from typing import Literal -import dask.array as da import fsspec import numpy as np import pandas as pd @@ -14,7 +13,6 @@ import pyarrow.parquet as pq import zarr from sklearn.cluster import AgglomerativeClustering -from zarr.errors import PathNotFoundError from scallops.cli.util import _get_cli_logger, cli_metadata from scallops.io import is_parquet_file, read_image @@ -32,7 +30,7 @@ tile_source_labels, ) from scallops.utils import _dask_from_array_no_copy -from scallops.zarr_io import is_ome_zarr_array +from scallops.zarr_io import is_ome_zarr_array, write_zarr logger = _get_cli_logger() @@ -82,14 +80,14 @@ def _single_stitch( if is_ome_zarr_array(image_output_root.get(f"images/{image_key}")): logger.info(f"Skipping stitching for {image_key}.") return - except PathNotFoundError: + except: # noqa: E722 pass elif not no_save_labels: try: if is_ome_zarr_array(image_output_root.get(f"labels/{image_key}-mask")): logger.info(f"Skipping stitching for {image_key}.") return - except PathNotFoundError: + except: # noqa: E722 pass elif is_parquet_file(f"{other_output_path}{image_key}-positions.parquet"): logger.info(f"Skipping stitching for {image_key}.") @@ -344,10 +342,10 @@ def _single_stitch( tile_shape_no_crop[0] - fuse_crop_width[0] * 2, tile_shape_no_crop[1] - fuse_crop_width[1] * 2, ) - fused_y_size = ( + fused_y_size = int( np.round(stitch_positions_df["y"].max()).astype(int) + fused_tile_shape[0] ) - fused_x_size = ( + fused_x_size = int( np.round(stitch_positions_df["x"].max()).astype(int) + fused_tile_shape[1] ) @@ -361,8 +359,6 @@ def _single_stitch( blend, image_output_root, image_key, - fused_y_size, - fused_x_size, fused_tile_shape, chunk_size, image_spacing, @@ -387,8 +383,6 @@ def _write_arrays( blend, image_output_root, image_key, - fused_y_size, - fused_x_size, fused_tile_shape, chunk_size, image_spacing, @@ -408,17 +402,8 @@ def _write_arrays( labels_group = image_output_root.require_group("labels") group = labels_group.create_group(image_key + "-mask", overwrite=True) - array = group.create_dataset( - name="0", - shape=(fused_y_size, fused_x_size), - chunks=chunk_size, - dtype=np.uint8, - dimension_separator="/", - overwrite=True, - ) - - da.to_zarr( - arr=_dask_from_array_no_copy( + write_zarr( + data=_dask_from_array_no_copy( tile_overlap_mask( stitch_positions_df, fill=blend != "none", @@ -426,39 +411,39 @@ def _write_arrays( ), chunks=chunk_size, ), - url=array, + grp=group, + image_attrs=None, + coords=None, + dims=None, + scaler=None, compute=True, - dimension_separator="/", ) group.attrs.update( _create_label_ome_metadata(image_spacing, image_key + "-mask") ) if blend == "none": group = labels_group.create_group(image_key + "-tile", overwrite=True) - array = group.create_dataset( - name="0", - shape=(fused_y_size, fused_x_size), - chunks=chunk_size, - dtype=np.uint16, - dimension_separator="/", - overwrite=True, - ) - - da.to_zarr( - arr=_dask_from_array_no_copy( + write_zarr( + data=_dask_from_array_no_copy( tile_source_labels(stitch_positions_df, fused_tile_shape), chunks=chunk_size, ), - url=array, + grp=group, + image_attrs=None, + coords=None, + dims=None, + scaler=None, compute=True, - dimension_separator="/", ) label_metadata = _create_label_ome_metadata( image_spacing, image_key + "-tile" ) - label_metadata["multiscales"][0]["metadata"] = { - "source": f"../../images/{image_key}" - } + label_multiscales = ( + label_metadata["ome"]["multiscales"] + if "ome" in label_metadata + else label_metadata["multiscales"] + ) + label_multiscales[0]["metadata"] = {"source": f"../../images/{image_key}"} group.attrs.update(label_metadata) cleanup_paths = [] if not no_save_image: diff --git a/scallops/stitch/fuse.py b/scallops/stitch/fuse.py index 640c27f..f66e02f 100644 --- a/scallops/stitch/fuse.py +++ b/scallops/stitch/fuse.py @@ -22,12 +22,14 @@ from scallops.stitch._radial import radial_correct from scallops.stitch.utils import _crop_image, dtype_convert from scallops.utils import _cpu_count, _dask_from_array_no_copy +from scallops.zarr_io import default_zarr_format, get_zarr_array_kwargs logger = logging.getLogger("scallops") def _create_label_ome_metadata(image_spacing: tuple[float, float], label_name: str): - return { + fmt = default_zarr_format() + d = { "multiscales": [ { "axes": [ @@ -38,10 +40,10 @@ def _create_label_ome_metadata(image_spacing: tuple[float, float], label_name: s { "coordinateTransformations": [ { - "scale": [ + "scale": ( float(image_spacing[0]), float(image_spacing[1]), - ], + ), "type": "scale", } ], @@ -49,10 +51,14 @@ def _create_label_ome_metadata(image_spacing: tuple[float, float], label_name: s } ], "name": f"/labels/{label_name}", - "version": "0.4", + "version": fmt.version, } ] } + if fmt.version in ("0.1", "0.2", "0.3", "0.4"): + return d + + return {"ome": d} def _create_ome_metadata( @@ -64,9 +70,10 @@ def _create_ome_metadata( metadata = {} metadata.update(**kwargs) metadata["stitch_coords"] = dict() + fmt = default_zarr_format() for c in stitch_coords: # convert to dict metadata["stitch_coords"][c] = stitch_coords[c].to_list() - return { + d = { "multiscales": [ { "metadata": metadata, @@ -79,11 +86,11 @@ def _create_ome_metadata( { "coordinateTransformations": [ { - "scale": [ + "scale": ( 1.0, float(image_spacing[0]), float(image_spacing[1]), - ], + ), "type": "scale", } ], @@ -91,10 +98,13 @@ def _create_ome_metadata( } ], "name": f"/images/{image_key}", - "version": "0.4", + "version": fmt.version, } ] } + if fmt.version in ("0.1", "0.2", "0.3", "0.4"): + return d + return {"ome": d} def _fuse( @@ -173,8 +183,8 @@ def _fuse( df["x"] = df["x"].round().values.astype(int) df["y"] = df["y"].round().values.astype(int) - fused_y_size = (df["y"] + ysize).max() - fused_x_size = (df["x"] + xsize).max() + fused_y_size = int((df["y"] + ysize).max()) + fused_x_size = int((df["x"] + xsize).max()) if channels_per_batch is None: if blend == "none": @@ -221,18 +231,16 @@ def _fuse( locks.append(threading.Lock()) locks = np.array(locks) partition_tree = shapely.STRtree(partition_boxes) + output_shape = (len(output_channels), fused_y_size, fused_x_size) + fmt = default_zarr_format() - result = group.create_dataset( - shape=( - len(output_channels), # c - fused_y_size, - fused_x_size, - ), + result = group.create_array( + shape=output_shape, dtype=target_dtype, chunks=(1,) + chunk_size, name="0", - dimension_separator="/", overwrite=True, + **get_zarr_array_kwargs(fmt), ) _fuse_image_delayed = delayed(_fuse_image) @@ -372,7 +380,6 @@ def _fuse( url=result, region=(slice(channel_batch, channel_batch + channels_per_batch),), compute=True, - dimension_separator="/", ) diff --git a/scallops/tests/test_features.py b/scallops/tests/test_features.py index ec71ed9..41ab6a7 100644 --- a/scallops/tests/test_features.py +++ b/scallops/tests/test_features.py @@ -60,12 +60,14 @@ def test_to_label_crops(tmp_path, array_A1_102_cells, array_A1_102_alnpheno): assert len(result_df) == 1 and result_df.index.values[0] == 2603 group = zarr.group() - intensity_image_zarr = group.create_dataset( - name="image", shape=intensity_image.shape + intensity_image_zarr = group.create_array( + name="image", shape=intensity_image.shape, dtype=intensity_image.dtype ) intensity_image_zarr[:] = intensity_image.compute() - label_image_zarr = group.create_dataset(name="label", shape=label_image.shape) + label_image_zarr = group.create_array( + name="label", shape=label_image.shape, dtype=label_image.dtype + ) label_image_zarr[:] = label_image.compute() to_label_crops( diff --git a/scallops/tests/test_illumination_correction.py b/scallops/tests/test_illumination_correction.py index d997bc4..3c6c12f 100644 --- a/scallops/tests/test_illumination_correction.py +++ b/scallops/tests/test_illumination_correction.py @@ -28,8 +28,8 @@ def test_illumination_correction_cli(tmp_path): ] subprocess.check_call(args) - store = zarr.ZipStore("scallops/tests/data/ops-illum-corr.zip", mode="r") - root = zarr.group(store=store) + store = zarr.storage.ZipStore("scallops/tests/data/ops-illum-corr.zip", mode="r") + root = zarr.open(store=store, mode="r") np.testing.assert_equal( root["data"][...], read_image(os.path.join(tmp_path, "images", "A1")).values.squeeze(), diff --git a/scallops/tests/test_io.py b/scallops/tests/test_io.py index 05f0985..91c910a 100644 --- a/scallops/tests/test_io.py +++ b/scallops/tests/test_io.py @@ -217,12 +217,14 @@ def test_write_non_ome_zarr_image(tmp_path, dask): image.attrs["physical_pixel_sizes"] = (1, 1, 1) image.attrs["physical_pixel_units"] = ("mm", "mm", "mm") zarr_path = str(tmp_path / "test.zarr") - _write_zarr_image("foo", open_ome_zarr(zarr_path), image, zarr_format="zarr") - _write_zarr_image("foo2", open_ome_zarr(zarr_path), image) + _write_zarr_image("img_zarr", open_ome_zarr(zarr_path), image, zarr_format="zarr") + _write_zarr_image("img_ome_zarr", open_ome_zarr(zarr_path), image) + + data_zarr = read_image(f"{zarr_path}/images/img_zarr", dask=False) + data_ome_zarr = read_image(f"{zarr_path}/images/img_ome_zarr", dask=False) - data_zarr = read_image(f"{zarr_path}/images/foo", dask=False) - data_ome_zarr = read_image(f"{zarr_path}/images/foo2", dask=False) xr.testing.assert_equal(data_zarr, data_ome_zarr) + xr.testing.assert_equal(image, data_ome_zarr) @pytest.mark.io @@ -344,7 +346,7 @@ def test_read_write_labels(tmp_path, array_A1_102_nuclei): _write_zarr_labels( name="test", root=open_ome_zarr(str(tmp_path), "w"), labels=nuclei ) - test = read_ome_zarr_array(zarr.open(str(tmp_path / "labels" / "test"), "r")) + test = read_ome_zarr_array(zarr.open(str(tmp_path / "labels" / "test"), mode="r")) np.testing.assert_equal(nuclei, test.data) diff --git a/scallops/zarr_io.py b/scallops/zarr_io.py index 035786c..197ac0c 100644 --- a/scallops/zarr_io.py +++ b/scallops/zarr_io.py @@ -24,7 +24,7 @@ from dask.delayed import Delayed from dask.graph_manipulation import bind from ome_zarr.axes import KNOWN_AXES -from ome_zarr.format import CurrentFormat +from ome_zarr.format import FormatV04 from ome_zarr.io import parse_url from ome_zarr.scale import Scaler from ome_zarr.types import JSONDict @@ -38,6 +38,18 @@ logger = logging.getLogger("scallops") +def default_zarr_format(): + return FormatV04() + + +def get_zarr_array_kwargs(fmt): + return ( + {"dimension_separator": "/"} + if fmt.version == 2 + else {"chunk_key_encoding": fmt.chunk_key_encoding} + ) + + def is_anndata_zarr(store: StoreLike) -> bool: """Determines whether store is an AnnData Zarr . @@ -76,13 +88,21 @@ def is_ome_zarr_array(node: zarr.Group) -> bool: result = is_ome_zarr_array(root) print(result) # Output: True """ - return node is not None and "multiscales" in node.attrs + return node is not None and ("ome" in node.attrs or "multiscales" in node.attrs) def _get_fs(group: zarr.Group): if hasattr(group.store, "fs"): return group.store.fs - return fsspec.url_to_fs(group.store.path)[0] + return fsspec.url_to_fs(_get_store_path(group))[0] + + +def _get_store_path(group: zarr.Group): + if hasattr(group.store, "root"): + return str(group.store.root) + if hasattr(group.store, "path"): + return group.store.path + return "" def _get_sep(group: zarr.Group) -> str: @@ -134,7 +154,7 @@ def _create_omero_metadata( # Napari requires that colors are specified if channel names are specified channels = ( [ - dict(label=channel_names[i], color=colors[i % len(colors)]) + dict(label=str(channel_names[i]), color=colors[i % len(colors)]) for i in range(len(channel_names)) ] if not np.isscalar(channel_names) @@ -181,7 +201,7 @@ def _fix_attrs(d: dict) -> None: elif isinstance(value, ome_types.OME): # Hack to prevent OverflowError: # Overlong 4 byte UTF-8 sequence detected when encoding string - d[key] = d[key].dict() + d[key] = d[key].model_dump(mode="json") elif isinstance(value, zarr.Group): d[key] = str(value) elif isinstance(value, list): @@ -189,7 +209,7 @@ def _fix_attrs(d: dict) -> None: if isinstance(value[i], dict): _fix_attrs(value[i]) elif isinstance(value[i], ome_types.OME): - value[i] = value[i].dict() + value[i] = value[i].model_dump(mode="json") elif isinstance(value[i], zarr.Group): value[i] = str(value) @@ -210,33 +230,8 @@ def _attrs_axes_coordinates( - Updated image attributes dictionary. - List of axes dictionaries. - List of coordinate transformations dictionaries or None. - - :example: - - .. code-block:: python - - import xarray as xr - import numpy as np - from scallops.zarr_io import _attrs_axes_coordinates - - data = np.random.rand(5, 10, 512, 512) - dims = ("c", "z", "y", "x") - coords = {"c": ["DAPI", "FITC", "TRITC", "Cy5", "Cy7"]} - array = xr.DataArray(data, dims=dims, coords=coords) - image_attrs = { - "physical_pixel_sizes": [0.1, 0.1, 0.5], - "physical_pixel_units": ["um", "um", "um"], - } - - # Prepare attributes, axes, and coordinate transformations - updated_attrs, axes, coord_transformations = _attrs_axes_coordinates( - image_attrs, array.coords, array.dims - ) - print(updated_attrs) - print(axes) - print(coord_transformations) """ - image_attrs = _fix_json(image_attrs) + omero = _create_omero_metadata(coords, dims) if omero is not None: image_attrs["omero"] = omero @@ -269,7 +264,9 @@ def _attrs_axes_coordinates( axis["unit"] = physical_pixel_units[space_index] space_index = space_index + 1 axes.append(axis) - + image_attrs = image_attrs.copy() + _fix_attrs(image_attrs) + image_attrs = _fix_json(image_attrs) return image_attrs, axes, coordinate_transformations @@ -404,49 +401,73 @@ def write_zarr( if image_attrs is not None: # Metadata can't be numpy arrays or python classes so do a round trip # conversion to convert to JSON serializable - _fix_attrs(image_attrs) + if metadata is not None: image_attrs.update(metadata) + image_attrs, axes, coordinate_transformations = _attrs_axes_coordinates( image_attrs, coords, dims ) + dask_delayed = [] + fmt = default_zarr_format() if zarr_format == "zarr": # No axis validation + zarr_array_kwargs = get_zarr_array_kwargs(fmt) if isinstance(data, da.Array): d = da.to_zarr( arr=data, url=grp.store, component=str(Path(grp.path, "0")), compute=compute, - dimension_separator=grp._store._dimension_separator, + zarr_array_kwargs=zarr_array_kwargs, ) if not compute: dask_delayed.append(d) elif not isinstance(data, zarr.Array): - grp.create_dataset("0", data=data, overwrite=True) - + grp.create_array("0", data=data, overwrite=True, **zarr_array_kwargs) + # v3 + # ome/omero for channel metadata + # ome/multiscales[0]/metadata for other metadata + + # v2: + # omero for channel metadata + # multiscales[0]/metadata for other metadata datasets = [{"path": "0"}] if coordinate_transformations is not None: datasets[0]["coordinateTransformations"] = coordinate_transformations - multiscales = [ - dict(version=CurrentFormat().version, datasets=datasets, name=grp.name) - ] - d = {"multiscales": multiscales} + multiscales = [dict(version=fmt.version, datasets=datasets, name=grp.name)] + zarr_attrs = ( + {"multiscales": multiscales} + if fmt.zarr_format == 2 + else {"ome": {"multiscales": multiscales}} + ) + if axes is not None: multiscales[0]["axes"] = axes if image_attrs is not None: - multiscales[0]["metadata"] = image_attrs if "omero" in image_attrs: - d["omero"] = image_attrs["omero"] + if fmt.zarr_format == 2: + omero = zarr_attrs.get("omero", {}) + omero.update(image_attrs.pop("omero")) + zarr_attrs["omero"] = omero + else: + omero = zarr_attrs["ome"].get("omero", {}) + omero.update(image_attrs.pop("omero")) + zarr_attrs["ome"]["omero"] = omero + + multiscales[0]["metadata"] = image_attrs + if len(dask_delayed) > 0: @dask.delayed def _write_metadata_delayed(grp, d): grp.attrs.update(d) - return dask_delayed + [bind(_write_metadata_delayed, dask_delayed)(grp, d)] + return dask_delayed + [ + bind(_write_metadata_delayed, dask_delayed)(grp, zarr_attrs) + ] else: - grp.attrs.update(d) + grp.attrs.update(zarr_attrs) return dask_delayed else: return write_image( @@ -454,8 +475,9 @@ def _write_metadata_delayed(grp, d): group=grp, scaler=scaler, axes=axes, + fmt=fmt, compute=compute, - metadata=image_attrs, + metadata=image_attrs if image_attrs is not None else {}, coordinate_transformations=( [coordinate_transformations] if coordinate_transformations is not None @@ -554,60 +576,56 @@ def _write_zarr_labels( isinstance(labels, xr.DataArray) and isinstance(labels.data, da.Array) ): labels = rechunk(labels) + fmt = default_zarr_format() return write_image( labels, grp, scaler=scaler, axes=label_axes, + fmt=fmt, metadata=metadata, compute=compute, + coordinate_transformations=None, storage_options=storage_options, ) -def _read_zarr_attrs(multiscale0: zarr.Group) -> tuple[dict, dict, list[str]]: - """Read attributes from a Zarr multiscale group. +def _read_zarr_attrs(attrs) -> tuple[dict, dict, list[str]]: + """Read attributes from Zarr. This function reads and processes the attributes, coordinates, and dimensions from the first multiscale dataset in a Zarr group. It also handles physical pixel sizes and units if available. - :param multiscale0: The Zarr group containing the multiscale dataset. + :param attrs: Zarr attributes. :return: A tuple containing: - coords: Dictionary of coordinates. - attrs: Dictionary of attributes. - dims: List of dimension names. - - :example: - - .. code-block:: python - - import zarr - from scallops.zarr_io import _read_zarr_attrs - - # Create a Zarr group with multiscale attributes - store = zarr.DirectoryStore("example.zarr") - root = zarr.group(store=store) - multiscale0 = root.create_group("multiscales") - multiscale0.attrs["axes"] = [{"name": "x"}, {"name": "y"}, {"name": "z"}] - multiscale0.attrs["datasets"] = [ - {"coordinateTransformations": [{"scale": [1.0, 0.5, 0.5]}]} - ] - - # Read attributes from the multiscale group - coords, attrs, dims = _read_zarr_attrs(multiscale0) - print(coords) - print(attrs) - print(dims) """ - attrs = multiscale0.get("metadata") - if attrs is None: - attrs = {} + # v3 + # ome/omero for channel metadata + # ome/multiscales[0]/metadata for other metadata + + # v2: + # omero for channel metadata + # multiscales[0]/metadata for other metadata + + if "ome" in attrs: + attrs = attrs["ome"] + multiscales = attrs["multiscales"] + if len(multiscales) > 0: + multiscale0 = multiscales[0] + else: + return None, None, None + axes = multiscale0["axes"] dims = [axis["name"] for axis in axes] - - coords = {d: attrs[d] for d in dims if d in attrs and d != "c"} - if "omero" in attrs and "c" in dims: + metadata = multiscale0.get("metadata") + if metadata is None: + metadata = {} + coords = {d: metadata[d] for d in dims if d in metadata and d != "c"} + if "c" in dims and "omero" in attrs: channel_names = attrs["omero"].get("channels") if channel_names is not None: coords["c"] = [c["label"] for c in channel_names] @@ -624,9 +642,9 @@ def _read_zarr_attrs(multiscale0: zarr.Group) -> tuple[dict, dict, list[str]]: if len(space_indices_with_units) > 0: scale = multiscale0["datasets"][0]["coordinateTransformations"][0]["scale"] physical_pixel_sizes = tuple([scale[d] for d in space_indices_with_units]) - attrs["physical_pixel_sizes"] = physical_pixel_sizes - attrs["physical_pixel_units"] = tuple(units) - return coords, attrs, dims + metadata["physical_pixel_sizes"] = physical_pixel_sizes + metadata["physical_pixel_units"] = tuple(units) + return coords, metadata, dims def _read_ome_zarr_array( @@ -643,16 +661,19 @@ def _read_ome_zarr_array( node = zarr.open(node, mode="r") if node is None: raise ValueError(f"{_node} not found") - if "multiscales" in node.attrs: - dims = None - coords = {} - attrs = {} - multiscales = node.attrs["multiscales"] + # For zarr v3, everything is under the "ome" namespace + if "ome" in node.attrs or "multiscales" in node.attrs: + coords, attrs, dims = _read_zarr_attrs(node.attrs) + + multiscales = ( + node.attrs["multiscales"] + if "multiscales" in node.attrs + else node.attrs["ome"]["multiscales"] + ) key = "0" if len(multiscales) > 0: multiscale0 = multiscales[0] - coords, attrs, dims = _read_zarr_attrs(multiscale0) if "datasets" in multiscale0: tmp = multiscale0["datasets"] if len(tmp) > 0: @@ -667,7 +688,7 @@ def _read_ome_zarr_array( image_keys = list(images.keys()) if len(image_keys) == 1: return _read_ome_zarr_array(images[image_keys[0]]) - logger.warning("multiscales not found in attrs") + logger.warning(f"multiscales not found in attrs for {node} ") def read_ome_zarr_array( @@ -713,10 +734,11 @@ def open_ome_zarr(url: Path | str, mode: str = "a") -> zarr.Group | None: """ try: - loc = parse_url(url, mode=mode) + fmt = default_zarr_format() + loc = parse_url(url, mode=mode, fmt=fmt) if loc is None: return None - return zarr.open(loc.store, mode=mode) + return zarr.open(loc.store, mode=mode, zarr_format=fmt.zarr_format) except Exception as e: logger.error(f"Failed to open OME-Zarr store: {url}") raise e