Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@

# FIXME: workaround for https://github.com/Project-MONAI/MONAI/issues/5291
# from .video_dataset import CameraDataset, VideoDataset, VideoFileDataset
from .wsi_datasets import MaskedPatchWSIDataset, PatchWSIDataset, SlidingPatchWSIDataset
from .wsi_datasets import AnnotationPatchWSIDataset, MaskedPatchWSIDataset, PatchWSIDataset, SlidingPatchWSIDataset
from .wsi_reader import BaseWSIReader, CuCIMWSIReader, OpenSlideWSIReader, TiffFileWSIReader, WSIReader

with contextlib.suppress(Exception):
Expand Down
165 changes: 164 additions & 1 deletion monai/data/wsi_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from monai.utils import convert_to_dst_type, ensure_tuple, ensure_tuple_rep
from monai.utils.enums import CommonKeys, ProbMapKeys, WSIPatchKeys

__all__ = ["PatchWSIDataset", "SlidingPatchWSIDataset", "MaskedPatchWSIDataset"]
__all__ = ["PatchWSIDataset", "SlidingPatchWSIDataset", "MaskedPatchWSIDataset", "AnnotationPatchWSIDataset"]


class PatchWSIDataset(Dataset):
Expand Down Expand Up @@ -414,3 +414,166 @@ def _evaluate_patch_locations(self, sample):
{**sample, WSIPatchKeys.LOCATION.value: np.array(loc), ProbMapKeys.LOCATION.value: mask_loc}
for loc, mask_loc in zip(patch_locations, mask_locations)
]


class AnnotationPatchWSIDataset(Randomizable, PatchWSIDataset):
"""
This dataset extracts patches from whole slide images at locations sampled from user-provided
annotation masks with class-based weighted sampling.

Unlike `MaskedPatchWSIDataset` which uses automatic foreground detection, this dataset accepts
user-provided annotation masks where each pixel value represents a class label. Patches are
sampled based on configurable class weights, enabling balanced or custom sampling strategies.

Args:
data: the list of input samples including image and mask paths (see the note below for more details).
patch_size: the size of patch to be extracted from the whole slide image.
patch_level: the level at which the patches to be extracted (default to 0).
mask_level: the resolution level at which the annotation mask is provided.
num_patches_per_image: number of patches to sample per image.
sampling_weights: a dictionary mapping class labels (int) to sampling probabilities (float).
If None, uniform sampling across all annotated classes is used.
transform: transforms to be executed on input data.
include_label: whether to load and include labels in the output
center_location: whether the input location information is the position of the center of the patch
additional_meta_keys: the list of keys for items to be copied to the output metadata from the input data
reader: the module to be used for loading whole slide imaging. Defaults to cuCIM. If `reader` is

- a string, it defines the backend of `monai.data.WSIReader`.
- a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader,
- an instance of a class inherited from `BaseWSIReader`, it is set as the wsi_reader.

seed: random seed for reproducibility. Defaults to 0.
kwargs: additional arguments to pass to `WSIReader` or provided whole slide reader class

Note:
The input data has the following form as an example:

.. code-block:: python

[
{"image": "path/to/image1.tiff", "mask": "path/to/mask1.npy"},
{"image": "path/to/image2.tiff", "mask": "path/to/mask2.npy"},
]

The mask should be a numpy array (or loadable file) where each pixel value is an integer
class label. Background (class 0) is excluded from sampling by default.

"""

def __init__(
self,
data: Sequence,
patch_size: int | tuple[int, int] | None = None,
patch_level: int | None = None,
mask_level: int = 0,
num_patches_per_image: int = 100,
sampling_weights: dict[int, float] | None = None,
transform: Callable | None = None,
include_label: bool = True,
center_location: bool = False,
additional_meta_keys: Sequence[str] = (ProbMapKeys.LOCATION, ProbMapKeys.NAME),
reader="cuCIM",
seed: int = 0,
**kwargs,
):
super().__init__(
data=[],
patch_size=patch_size,
patch_level=patch_level,
transform=transform,
include_label=include_label,
center_location=center_location,
additional_meta_keys=additional_meta_keys,
reader=reader,
**kwargs,
)

self.mask_level = mask_level
self.num_patches_per_image = num_patches_per_image
self.sampling_weights = sampling_weights
self.set_random_state(seed)

self.data: list
self.image_data = list(data)
for sample in self.image_data:
patch_samples = self._sample_patch_locations(sample)
self.data.extend(patch_samples)

def _load_mask(self, sample: dict) -> np.ndarray:
"""Load the annotation mask from the sample."""
mask_path = sample["mask"]
if isinstance(mask_path, np.ndarray):
return mask_path
return np.load(mask_path)

def _sample_patch_locations(self, sample: dict) -> list[dict]:
"""Sample patch locations from the annotation mask based on class weights."""
patch_size = self._get_size(sample)
patch_level = self._get_level(sample)
wsi_obj = self._get_wsi_object(sample)

# Load the annotation mask
mask = self._load_mask(sample)

# Get unique classes (exclude background=0)
classes = np.unique(mask)
classes = classes[classes != 0]

if len(classes) == 0:
return []

# Determine sampling weights
if self.sampling_weights is not None:
weights = np.array([self.sampling_weights.get(int(c), 0.0) for c in classes])
else:
# Uniform sampling across classes
weights = np.ones(len(classes))

# Normalize weights
weight_sum = weights.sum()
if weight_sum == 0:
return []
weights = weights / weight_sum

# Pre-compute locations for each class
class_locations: dict[int, np.ndarray] = {}
for c in classes:
locs = np.vstack(np.where(mask == c)).T
if len(locs) > 0:
class_locations[int(c)] = locs

# Sample patches
mask_ratio = self.wsi_reader.get_downsample_ratio(wsi_obj, self.mask_level)
patch_ratio = self.wsi_reader.get_downsample_ratio(wsi_obj, patch_level)
patch_size_0 = np.array([p * patch_ratio for p in patch_size])

patch_samples = []
for _ in range(self.num_patches_per_image):
# Sample a class
class_idx = self.R.choice(len(classes), p=weights)
chosen_class = int(classes[class_idx])

if chosen_class not in class_locations:
continue

# Sample a location from that class
locs = class_locations[chosen_class]
loc_idx = self.R.randint(len(locs))
mask_loc = locs[loc_idx]

# Convert mask location to image location at level 0
patch_loc = np.round((mask_loc + 0.5) * float(mask_ratio) - patch_size_0 // 2).astype(int)

patch_sample = {
**sample,
WSIPatchKeys.LOCATION.value: patch_loc,
WSIPatchKeys.SIZE.value: patch_size,
WSIPatchKeys.LEVEL.value: patch_level,
ProbMapKeys.LOCATION.value: mask_loc,
ProbMapKeys.NAME.value: os.path.basename(sample[CommonKeys.IMAGE]),
CommonKeys.LABEL: chosen_class,
}
patch_samples.append(patch_sample)

return patch_samples
170 changes: 170 additions & 0 deletions tests/data/test_annotation_patch_wsi_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import os
import tempfile
import unittest
from pathlib import Path
from unittest import skipUnless

import numpy as np

from monai.data import AnnotationPatchWSIDataset
from monai.utils import CommonKeys, optional_import, set_determinism
from tests.test_utils import download_url_or_skip_test, testing_data_config

set_determinism(0)

cucim, has_cucim = optional_import("cucim")
has_cucim = has_cucim and hasattr(cucim, "CuImage")
_, has_osl = optional_import("openslide")
_, has_tiff = optional_import("tifffile", name="imwrite")
_, has_codec = optional_import("imagecodecs")
has_tiff = has_tiff and has_codec

FILE_KEY = "wsi_generic_tiff"
FILE_URL = testing_data_config("images", FILE_KEY, "url")
TESTS_PATH = Path(__file__).parents[1]
FILE_PATH = os.path.join(TESTS_PATH, "testing_data", f"temp_{FILE_KEY}.tiff")


@skipUnless(has_cucim or has_osl or has_tiff, "Requires cucim, openslide, or tifffile!")
def setUpModule():
hash_type = testing_data_config("images", FILE_KEY, "hash_type")
hash_val = testing_data_config("images", FILE_KEY, "hash_val")
download_url_or_skip_test(FILE_URL, FILE_PATH, hash_type=hash_type, hash_val=hash_val)


class AnnotationPatchWSIDatasetTests:
class Tests(unittest.TestCase):
backend = None

def setUp(self):
# Create a temporary annotation mask with 3 classes
self.mask = np.zeros((128, 179), dtype=np.int32)
self.mask[10:50, 10:80] = 1 # class 1: tumor
self.mask[60:100, 90:170] = 2 # class 2: stroma
self.mask[100:120, 20:60] = 3 # class 3: necrosis

self.mask_file = tempfile.NamedTemporaryFile(suffix=".npy", delete=False)
np.save(self.mask_file.name, self.mask)

def tearDown(self):
os.unlink(self.mask_file.name)

def test_uniform_sampling(self):
"""Test that patches are sampled from all classes with uniform weights."""
data = [{"image": FILE_PATH, "mask": self.mask_file.name}]
dataset = AnnotationPatchWSIDataset(
data=data,
patch_size=(2, 2),
patch_level=8,
mask_level=8,
num_patches_per_image=50,
reader=self.backend,
seed=42,
)
self.assertEqual(len(dataset), 50)

# Check that labels are from the expected classes
labels = set()
for i in range(len(dataset)):
sample = dataset[i]
self.assertIn(CommonKeys.IMAGE, sample)
self.assertIn(CommonKeys.LABEL, sample)
labels.add(int(sample[CommonKeys.LABEL].item()))
# With 50 samples and uniform weights, we should see all 3 classes
self.assertEqual(labels, {1, 2, 3})

def test_weighted_sampling(self):
"""Test that sampling weights control class distribution."""
data = [{"image": FILE_PATH, "mask": self.mask_file.name}]
# Only sample from class 1
dataset = AnnotationPatchWSIDataset(
data=data,
patch_size=(2, 2),
patch_level=8,
mask_level=8,
num_patches_per_image=20,
sampling_weights={1: 1.0, 2: 0.0, 3: 0.0},
reader=self.backend,
seed=42,
)
self.assertEqual(len(dataset), 20)
for i in range(len(dataset)):
sample = dataset[i]
self.assertEqual(int(sample[CommonKeys.LABEL].item()), 1)

def test_mask_as_array(self):
"""Test that mask can be passed directly as a numpy array."""
data = [{"image": FILE_PATH, "mask": self.mask}]
dataset = AnnotationPatchWSIDataset(
data=data,
patch_size=(2, 2),
patch_level=8,
mask_level=8,
num_patches_per_image=10,
reader=self.backend,
seed=42,
)
self.assertEqual(len(dataset), 10)

def test_empty_mask(self):
"""Test that an all-zero mask produces no patches."""
empty_mask = np.zeros((128, 179), dtype=np.int32)
data = [{"image": FILE_PATH, "mask": empty_mask}]
dataset = AnnotationPatchWSIDataset(
data=data,
patch_size=(2, 2),
patch_level=8,
mask_level=8,
num_patches_per_image=10,
reader=self.backend,
seed=42,
)
self.assertEqual(len(dataset), 0)

def test_metadata_keys(self):
"""Test that output contains expected metadata."""
data = [{"image": FILE_PATH, "mask": self.mask_file.name}]
dataset = AnnotationPatchWSIDataset(
data=data,
patch_size=(2, 2),
patch_level=8,
mask_level=8,
num_patches_per_image=5,
reader=self.backend,
seed=42,
)
sample = dataset[0]
self.assertIn(CommonKeys.IMAGE, sample)
self.assertIn(CommonKeys.LABEL, sample)


@skipUnless(has_cucim, "Requires cucim")
class AnnotationPatchWSIDatasetCuCIMTests(AnnotationPatchWSIDatasetTests.Tests):
backend = "cucim"


@skipUnless(has_osl, "Requires openslide")
class AnnotationPatchWSIDatasetOpenSlideTests(AnnotationPatchWSIDatasetTests.Tests):
backend = "openslide"


@skipUnless(has_tiff, "Requires tifffile")
class AnnotationPatchWSIDatasetTiffFileTests(AnnotationPatchWSIDatasetTests.Tests):
backend = "tifffile"


if __name__ == "__main__":
unittest.main()
Loading