From 8b7a07e63b9b87cdf292e1a4114bf71c3cd4e754 Mon Sep 17 00:00:00 2001 From: Fanqi Cheng Date: Sun, 31 May 2026 18:53:01 -0700 Subject: [PATCH 1/2] Add AnnotationPatchWSIDataset with class-weighted mask sampling Add a new dataset class that samples patches from whole slide images based on user-provided annotation masks with configurable class weights. Unlike MaskedPatchWSIDataset which only detects foreground tissue, this class supports multi-class annotations and weighted sampling. Fixes #7402 Signed-off-by: Fanqi Cheng --- monai/data/__init__.py | 2 +- monai/data/wsi_datasets.py | 165 ++++++++++++++++- .../data/test_annotation_patch_wsi_dataset.py | 171 ++++++++++++++++++ 3 files changed, 336 insertions(+), 2 deletions(-) create mode 100644 tests/data/test_annotation_patch_wsi_dataset.py diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 971d5121f7..2ba014331e 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -115,7 +115,7 @@ # FIXME: workaround for https://github.com/Project-MONAI/MONAI/issues/5291 # from .video_dataset import CameraDataset, VideoDataset, VideoFileDataset -from .wsi_datasets import MaskedPatchWSIDataset, PatchWSIDataset, SlidingPatchWSIDataset +from .wsi_datasets import AnnotationPatchWSIDataset, MaskedPatchWSIDataset, PatchWSIDataset, SlidingPatchWSIDataset from .wsi_reader import BaseWSIReader, CuCIMWSIReader, OpenSlideWSIReader, TiffFileWSIReader, WSIReader with contextlib.suppress(Exception): diff --git a/monai/data/wsi_datasets.py b/monai/data/wsi_datasets.py index 2ee8c9d363..83851c4b84 100644 --- a/monai/data/wsi_datasets.py +++ b/monai/data/wsi_datasets.py @@ -26,7 +26,7 @@ from monai.utils import convert_to_dst_type, ensure_tuple, ensure_tuple_rep from monai.utils.enums import CommonKeys, ProbMapKeys, WSIPatchKeys -__all__ = ["PatchWSIDataset", "SlidingPatchWSIDataset", "MaskedPatchWSIDataset"] +__all__ = ["PatchWSIDataset", "SlidingPatchWSIDataset", "MaskedPatchWSIDataset", "AnnotationPatchWSIDataset"] 
class AnnotationPatchWSIDataset(Randomizable, PatchWSIDataset):
    """
    This dataset extracts patches from whole slide images at locations sampled from user-provided
    annotation masks with class-based weighted sampling.

    Unlike `MaskedPatchWSIDataset`, which uses automatic foreground detection, this dataset accepts
    user-provided annotation masks where each pixel value represents a class label. For every patch a
    class is drawn according to `sampling_weights`, then a pixel of that class is drawn uniformly and
    used as the patch center.

    Args:
        data: the list of input samples including image and mask paths (see the note below for more details).
        patch_size: the size of patch to be extracted from the whole slide image.
        patch_level: the level at which the patches are extracted (default to 0).
        mask_level: the resolution level at which the annotation mask is provided.
        num_patches_per_image: number of patches to sample per image. Must be non-negative.
        sampling_weights: a dictionary mapping class labels (int) to non-negative sampling weights (float).
            Weights are normalized over the classes present in each mask; classes missing from the
            dictionary get weight 0. If None, uniform sampling across all annotated classes is used.
        transform: transforms to be executed on input data.
        include_label: whether to load and include labels in the output. The label of each patch is the
            class that was sampled for it; it overrides any label provided in the input sample.
        center_location: whether the input location information is the position of the center of the patch
        additional_meta_keys: the list of keys for items to be copied to the output metadata from the input data
        reader: the module to be used for loading whole slide imaging. Defaults to cuCIM. If `reader` is

            - a string, it defines the backend of `monai.data.WSIReader`.
            - a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader,
            - an instance of a class inherited from `BaseWSIReader`, it is set as the wsi_reader.

        seed: random seed for reproducibility. Defaults to 0.
        kwargs: additional arguments to pass to `WSIReader` or provided whole slide reader class

    Raises:
        ValueError: when `num_patches_per_image` is negative, when a sampling weight is negative or
            not finite, or when a mask is not a 2D numpy array.

    Note:
        The input data has the following form as an example:

        .. code-block:: python

            [
                {"image": "path/to/image1.tiff", "mask": "path/to/mask1.npy"},
                {"image": "path/to/image2.tiff", "mask": "path/to/mask2.npy"},
            ]

        The mask should be a 2D numpy array (or a ``.npy`` file containing one) where each pixel value is
        an integer class label. Background (class 0) is always excluded from sampling. An image whose mask
        has no annotated pixel, or whose annotated classes all have zero weight, contributes no patches.

    """

    def __init__(
        self,
        data: Sequence,
        patch_size: int | tuple[int, int] | None = None,
        patch_level: int | None = None,
        mask_level: int = 0,
        num_patches_per_image: int = 100,
        sampling_weights: dict[int, float] | None = None,
        transform: Callable | None = None,
        include_label: bool = True,
        center_location: bool = False,
        additional_meta_keys: Sequence[str] = (ProbMapKeys.LOCATION, ProbMapKeys.NAME),
        reader="cuCIM",
        seed: int = 0,
        **kwargs,
    ):
        # Fail fast on invalid sampling configuration, before any reader is created.
        if num_patches_per_image < 0:
            raise ValueError(f"num_patches_per_image must be non-negative, got {num_patches_per_image}.")
        self._validate_sampling_weights(sampling_weights)

        # The parent is initialized with no data; patch-level samples are generated below.
        super().__init__(
            data=[],
            patch_size=patch_size,
            patch_level=patch_level,
            transform=transform,
            include_label=include_label,
            center_location=center_location,
            additional_meta_keys=additional_meta_keys,
            reader=reader,
            **kwargs,
        )

        self.mask_level = mask_level
        self.num_patches_per_image = num_patches_per_image
        self.sampling_weights = sampling_weights
        self.set_random_state(seed)

        self.data: list
        self.image_data = list(data)
        for sample in self.image_data:
            self.data.extend(self._sample_patch_locations(sample))

    @staticmethod
    def _validate_sampling_weights(sampling_weights: dict[int, float] | None) -> None:
        """Raise `ValueError` if any sampling weight is negative or not finite."""
        if sampling_weights is None:
            return
        for label, weight in sampling_weights.items():
            if not np.isfinite(weight) or weight < 0:
                raise ValueError(
                    f"sampling weight for class {label} must be a finite non-negative number, got {weight}."
                )

    def _load_mask(self, sample: dict) -> np.ndarray:
        """
        Return the annotation mask of `sample` as a 2D numpy array.

        `sample["mask"]` is either the array itself or something `np.load` can read into an array.

        Raises:
            ValueError: when the mask is not a numpy array or is not two-dimensional.
        """
        mask = sample["mask"]
        if not isinstance(mask, np.ndarray):
            mask = np.load(mask)
        if not isinstance(mask, np.ndarray):
            # e.g. an `.npz` archive: release the file handle before reporting the problem
            close = getattr(mask, "close", None)
            if callable(close):
                close()
            raise ValueError(f"annotation mask must be a numpy array, got {type(mask).__name__}.")
        if mask.ndim != 2:
            raise ValueError(f"annotation mask must be 2D (height, width), got shape {mask.shape}.")
        return mask

    def _sample_patch_locations(self, sample: dict) -> list[dict]:
        """
        Sample `num_patches_per_image` patch locations for one image from its annotation mask.

        Returns:
            A list of patch-level samples. Each one is a copy of `sample` extended with the patch
            location (level 0), size and level, the mask location, the image name and the sampled
            class as label. The list is empty when nothing can be sampled from the mask.
        """
        patch_size = self._get_size(sample)
        patch_level = self._get_level(sample)
        wsi_obj = self._get_wsi_object(sample)
        mask = self._load_mask(sample)

        # One pass over the mask collects all annotated (non-background) pixels in row-major order.
        rows, cols = np.nonzero(mask)
        if rows.size == 0:
            return []
        fg_labels = mask[rows, cols]
        fg_locs = np.stack((rows, cols), axis=1)
        classes = np.unique(fg_labels)

        if self.sampling_weights is None:
            weights = np.ones(len(classes), dtype=float)
        else:
            weights = np.array([self.sampling_weights.get(int(c), 0.0) for c in classes], dtype=float)
        if not weights.sum() > 0:
            return []

        # Classes that can never be drawn are dropped so their pixel locations are not gathered.
        selectable = weights > 0
        classes, weights = classes[selectable], weights[selectable]
        probabilities = weights / weights.sum()
        class_locations = [fg_locs[fg_labels == c] for c in classes]

        # Loop-invariant quantities for converting mask coordinates to level-0 image coordinates.
        mask_ratio = float(self.wsi_reader.get_downsample_ratio(wsi_obj, self.mask_level))
        patch_ratio = self.wsi_reader.get_downsample_ratio(wsi_obj, patch_level)
        half_patch_0 = np.array([p * patch_ratio for p in patch_size]) // 2
        image_name = os.path.basename(sample[CommonKeys.IMAGE])

        patch_samples = []
        for _ in range(self.num_patches_per_image):
            class_idx = self.R.choice(len(classes), p=probabilities)
            locs = class_locations[class_idx]
            # copy so the stored location does not keep the whole per-class array alive
            mask_loc = locs[self.R.randint(len(locs))].copy()

            # center of the mask pixel at level 0, shifted to the top-left corner of the patch
            patch_loc = np.round((mask_loc + 0.5) * mask_ratio - half_patch_0).astype(int)

            patch_samples.append(
                {
                    **sample,
                    WSIPatchKeys.LOCATION.value: patch_loc,
                    WSIPatchKeys.SIZE.value: patch_size,
                    WSIPatchKeys.LEVEL.value: patch_level,
                    ProbMapKeys.LOCATION.value: mask_loc,
                    ProbMapKeys.NAME.value: image_name,
                    CommonKeys.LABEL: int(classes[class_idx]),
                }
            )

        return patch_samples
+ +from __future__ import annotations + +import os +import tempfile +import unittest +from pathlib import Path +from unittest import skipUnless + +import numpy as np +from parameterized import parameterized + +from monai.data import AnnotationPatchWSIDataset +from monai.utils import CommonKeys, ProbMapKeys, WSIPatchKeys, optional_import, set_determinism +from tests.test_utils import download_url_or_skip_test, testing_data_config + +set_determinism(0) + +cucim, has_cucim = optional_import("cucim") +has_cucim = has_cucim and hasattr(cucim, "CuImage") +_, has_osl = optional_import("openslide") +_, has_tiff = optional_import("tifffile", name="imwrite") +_, has_codec = optional_import("imagecodecs") +has_tiff = has_tiff and has_codec + +FILE_KEY = "wsi_generic_tiff" +FILE_URL = testing_data_config("images", FILE_KEY, "url") +TESTS_PATH = Path(__file__).parents[1] +FILE_PATH = os.path.join(TESTS_PATH, "testing_data", f"temp_{FILE_KEY}.tiff") + + +@skipUnless(has_cucim or has_osl or has_tiff, "Requires cucim, openslide, or tifffile!") +def setUpModule(): + hash_type = testing_data_config("images", FILE_KEY, "hash_type") + hash_val = testing_data_config("images", FILE_KEY, "hash_val") + download_url_or_skip_test(FILE_URL, FILE_PATH, hash_type=hash_type, hash_val=hash_val) + + +class AnnotationPatchWSIDatasetTests: + class Tests(unittest.TestCase): + backend = None + + def setUp(self): + # Create a temporary annotation mask with 3 classes + self.mask = np.zeros((128, 179), dtype=np.int32) + self.mask[10:50, 10:80] = 1 # class 1: tumor + self.mask[60:100, 90:170] = 2 # class 2: stroma + self.mask[100:120, 20:60] = 3 # class 3: necrosis + + self.mask_file = tempfile.NamedTemporaryFile(suffix=".npy", delete=False) + np.save(self.mask_file.name, self.mask) + + def tearDown(self): + os.unlink(self.mask_file.name) + + def test_uniform_sampling(self): + """Test that patches are sampled from all classes with uniform weights.""" + data = [{"image": FILE_PATH, "mask": 
self.mask_file.name}] + dataset = AnnotationPatchWSIDataset( + data=data, + patch_size=(2, 2), + patch_level=8, + mask_level=8, + num_patches_per_image=50, + reader=self.backend, + seed=42, + ) + self.assertEqual(len(dataset), 50) + + # Check that labels are from the expected classes + labels = set() + for i in range(len(dataset)): + sample = dataset[i] + self.assertIn(CommonKeys.IMAGE, sample) + self.assertIn(CommonKeys.LABEL, sample) + labels.add(int(sample[CommonKeys.LABEL].item())) + # With 50 samples and uniform weights, we should see all 3 classes + self.assertEqual(labels, {1, 2, 3}) + + def test_weighted_sampling(self): + """Test that sampling weights control class distribution.""" + data = [{"image": FILE_PATH, "mask": self.mask_file.name}] + # Only sample from class 1 + dataset = AnnotationPatchWSIDataset( + data=data, + patch_size=(2, 2), + patch_level=8, + mask_level=8, + num_patches_per_image=20, + sampling_weights={1: 1.0, 2: 0.0, 3: 0.0}, + reader=self.backend, + seed=42, + ) + self.assertEqual(len(dataset), 20) + for i in range(len(dataset)): + sample = dataset[i] + self.assertEqual(int(sample[CommonKeys.LABEL].item()), 1) + + def test_mask_as_array(self): + """Test that mask can be passed directly as a numpy array.""" + data = [{"image": FILE_PATH, "mask": self.mask}] + dataset = AnnotationPatchWSIDataset( + data=data, + patch_size=(2, 2), + patch_level=8, + mask_level=8, + num_patches_per_image=10, + reader=self.backend, + seed=42, + ) + self.assertEqual(len(dataset), 10) + + def test_empty_mask(self): + """Test that an all-zero mask produces no patches.""" + empty_mask = np.zeros((128, 179), dtype=np.int32) + data = [{"image": FILE_PATH, "mask": empty_mask}] + dataset = AnnotationPatchWSIDataset( + data=data, + patch_size=(2, 2), + patch_level=8, + mask_level=8, + num_patches_per_image=10, + reader=self.backend, + seed=42, + ) + self.assertEqual(len(dataset), 0) + + def test_metadata_keys(self): + """Test that output contains expected 
metadata.""" + data = [{"image": FILE_PATH, "mask": self.mask_file.name}] + dataset = AnnotationPatchWSIDataset( + data=data, + patch_size=(2, 2), + patch_level=8, + mask_level=8, + num_patches_per_image=5, + reader=self.backend, + seed=42, + ) + sample = dataset[0] + self.assertIn(CommonKeys.IMAGE, sample) + self.assertIn(CommonKeys.LABEL, sample) + + +@skipUnless(has_cucim, "Requires cucim") +class AnnotationPatchWSIDatasetCuCIMTests(AnnotationPatchWSIDatasetTests.Tests): + backend = "cucim" + + +@skipUnless(has_osl, "Requires openslide") +class AnnotationPatchWSIDatasetOpenSlideTests(AnnotationPatchWSIDatasetTests.Tests): + backend = "openslide" + + +@skipUnless(has_tiff, "Requires tifffile") +class AnnotationPatchWSIDatasetTiffFileTests(AnnotationPatchWSIDatasetTests.Tests): + backend = "tifffile" + + +if __name__ == "__main__": + unittest.main() From e8f4b60b4b89689f23e3427389547b174f66089a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 1 Jun 2026 02:04:49 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/data/test_annotation_patch_wsi_dataset.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/data/test_annotation_patch_wsi_dataset.py b/tests/data/test_annotation_patch_wsi_dataset.py index 182e23addc..f07899923a 100644 --- a/tests/data/test_annotation_patch_wsi_dataset.py +++ b/tests/data/test_annotation_patch_wsi_dataset.py @@ -18,10 +18,9 @@ from unittest import skipUnless import numpy as np -from parameterized import parameterized from monai.data import AnnotationPatchWSIDataset -from monai.utils import CommonKeys, ProbMapKeys, WSIPatchKeys, optional_import, set_determinism +from monai.utils import CommonKeys, optional_import, set_determinism from tests.test_utils import download_url_or_skip_test, testing_data_config set_determinism(0)