diff --git a/TPTBox/core/nii_wrapper.py b/TPTBox/core/nii_wrapper.py index 19b4df1..e93e7d0 100755 --- a/TPTBox/core/nii_wrapper.py +++ b/TPTBox/core/nii_wrapper.py @@ -1274,10 +1274,11 @@ def n4_bias_field_correction( import ants.ops.bias_correction as bc # install antspyx not ants! except ModuleNotFoundError: import ants.utils.bias_correction as bc # install antspyx not ants! - from ants.utils.convert_nibabel import from_nibabel from scipy.ndimage import binary_dilation, generate_binary_structure + + from TPTBox.core.internal import ants_to_nifti, nifti_to_ants dtype = self.dtype - input_ants:ants.ANTsImage = from_nibabel(nib.nifti1.Nifti1Image(self.get_array(),self.affine)) + input_ants:ants.ANTsImage = self.to_ants() if threshold != 0: mask_arr = self.get_array() mask_arr[mask_arr < threshold] = 0 @@ -1286,7 +1287,7 @@ def n4_bias_field_correction( struct = generate_binary_structure(3, 3) mask_arr = binary_dilation(mask_arr.copy(), structure=struct, iterations=3) mask_arr = mask_arr.astype(np.uint8) - mask:ants.ANTsImage = from_nibabel(nib.nifti1.Nifti1Image(mask_arr,self.affine))#self.set_array(mask,verbose=False).nii + mask:ants.ANTsImage = self.set_array(mask_arr,verbose=False).to_ants() mask = mask.set_spacing(input_ants.spacing) # type: ignore out = bc.n4_bias_field_correction( input_ants, @@ -1299,11 +1300,10 @@ def n4_bias_field_correction( ) - - out_nib:Nifti1Image = out.to_nibabel() + out_nib:Nifti1Image = ants_to_nifti(out) if crop: # Crop to regions that had a normalization applied. Removes a lot of dead space - dif = NII((input_ants - out).to_nibabel()) + dif = NII(ants_to_nifti(input_ants - out)) dif_arr = dif.get_array() dif_arr[dif_arr != 0] = 1 dif.set_array_(dif_arr,verbose=verbose) @@ -1513,7 +1513,14 @@ def to_ants(self) -> Any: log.print_error() log.on_fail("run 'pip install antspyx' to install hf-deepali") raise - return ants.from_nibabel(self.nii) + try: + from ants.utils.convert_nibabel import from_nibabel + + return from_nibabel(self.nii) + except ModuleNotFoundError: + from ants.utils.nibabel_nifti_to_ants import from_nibabel_nifti + + return from_nibabel_nifti(self.nii) def to_simpleITK(self) -> Any: """Converts this NII to a SimpleITK image. diff --git a/unit_tests/test_nrrd.py b/unit_tests/test_nrrd.py index 2d121a5..53a9653 100644 --- a/unit_tests/test_nrrd.py +++ b/unit_tests/test_nrrd.py @@ -20,7 +20,7 @@ class TestAnts(unittest.TestCase): @unittest.skipIf(not has_ants, "requires ants to be installed") - def test_segmentation_CT(self): + def test_ants_segmentation_CT(self): """Test round-trip for Segmentation.seg.nrrd.""" ct, subreg, vert = get_nii_paths_ct() from TPTBox import NII, to_nii @@ -33,6 +33,25 @@ def test_segmentation_CT(self): assert np.isclose(nii.affine, nii2.affine).all() assert np.isclose(nii.get_array(), nii.get_array()).all() + nii = to_nii(ct) + nii2 = ants_to_nifti(nii.to_ants(), nii.header) + nii2 = NII(nii2) + assert nii.orientation == nii2.orientation + assert np.isclose(nii.affine, nii2.affine).all() + assert np.isclose(nii.get_array(), nii.get_array()).all() + + @unittest.skipIf(not has_ants, "requires ants to be installed") + def test_n4_bias_field_correction(self): + """Test round-trip for Segmentation.seg.nrrd.""" + ct, subreg, vert = get_nii_paths_ct() + from TPTBox import NII, to_nii + from TPTBox.core.internal import ants_to_nifti, nifti_to_ants + + nii = to_nii(ct) + nii2 = nii.n4_bias_field_correction() + + assert nii2.shape == nii.shape + # @unittest.skipIf(not has_ants, "requires spineps to be installed") # def test_raf_ants(): # ct, subreg, vert = get_nii_paths_ct()