diff --git a/.gitignore b/.gitignore index 3da001d0ce..bd117cc321 100644 --- a/.gitignore +++ b/.gitignore @@ -135,6 +135,8 @@ tests/testing_data/endo.mp4 tests/testing_data/ultrasound.avi tests/testing_data/train_data_stats.yaml tests/testing_data/eval_data_stats.yaml +tests/testing_data/CT_2D_head_fixed.mha +tests/testing_data/CT_2D_head_moving.mha # clang format tool .clang-format-bin/ diff --git a/docs/source/data.rst b/docs/source/data.rst index cd4daab889..69e694a37b 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -252,6 +252,11 @@ N-Dim Fourier Transform .. autofunction:: monai.data.fft_utils.fftn_centered .. autofunction:: monai.data.fft_utils.ifftn_centered +ITK Torch Bridge +~~~~~~~~~~~~~~~~ +.. automodule:: monai.data.itk_torch_bridge + :members: + Meta Object ----------- diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 9e91397331..3678270232 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -61,6 +61,14 @@ resolve_writer, ) from .iterable_dataset import CSVIterableDataset, IterableDataset, ShuffleBuffer +from .itk_torch_bridge import ( + get_itk_image_center, + itk_image_to_metatensor, + itk_to_monai_affine, + metatensor_to_itk_image, + monai_to_itk_affine, + monai_to_itk_ddf, +) from .meta_obj import MetaObj, get_track_meta, set_track_meta from .meta_tensor import MetaTensor from .samplers import DistributedSampler, DistributedWeightedRandomSampler diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 736baff538..138689a8b3 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -204,7 +204,7 @@ class ITKReader(ImageReader): def __init__( self, - channel_dim: int | None = None, + channel_dim: str | int | None = None, series_name: str = "", reverse_indexing: bool = False, series_meta: bool = False, @@ -366,7 +366,7 @@ def _get_spatial_shape(self, img): sr = itk.array_from_matrix(img.GetDirection()).shape[0] sr = max(min(sr, 3), 1) _size = list(itk.size(img)) - if self.channel_dim is not None: + if isinstance(self.channel_dim, int): _size.pop(self.channel_dim) return np.asarray(_size[:sr]) diff --git a/monai/data/itk_torch_bridge.py b/monai/data/itk_torch_bridge.py new file mode 100644 index 0000000000..3dc25ad0bd --- /dev/null +++ b/monai/data/itk_torch_bridge.py @@ -0,0 +1,338 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + +import numpy as np +import torch + +from monai.config.type_definitions import DtypeLike +from monai.data import ITKReader, ITKWriter +from monai.data.meta_tensor import MetaTensor +from monai.transforms import EnsureChannelFirst +from monai.utils import convert_to_dst_type, optional_import + +if TYPE_CHECKING: + import itk + + has_itk = True +else: + itk, has_itk = optional_import("itk") + +__all__ = [ + "itk_image_to_metatensor", + "metatensor_to_itk_image", + "itk_to_monai_affine", + "monai_to_itk_affine", + "get_itk_image_center", + "monai_to_itk_ddf", +] + + +def itk_image_to_metatensor( + image, channel_dim: str | int | None = None, dtype: DtypeLike | torch.dtype = float +) -> MetaTensor: + """ + Converts an ITK image to a MetaTensor object. + + Args: + image: The ITK image to be converted. + channel_dim: the channel dimension of the input image, default is None. + This is used to set original_channel_dim in the metadata, EnsureChannelFirst reads this field. + If None, the channel_dim is inferred automatically. + If the input array doesn't have a channel dim, this value should be ``'no_channel'``. + dtype: output dtype, defaults to the Python built-in `float`. + + Returns: + A MetaTensor object containing the array data and metadata in ChannelFirst format. + """ + reader = ITKReader(affine_lps_to_ras=False, channel_dim=channel_dim) + image_array, meta_data = reader.get_data(image) + image_array = convert_to_dst_type(image_array, dst=image_array, dtype=dtype)[0] + metatensor = MetaTensor.ensure_torch_and_prune_meta(image_array, meta_data) + metatensor = EnsureChannelFirst(channel_dim=channel_dim)(metatensor) + + return cast(MetaTensor, metatensor) + + +def metatensor_to_itk_image( + meta_tensor: MetaTensor, channel_dim: int | None = 0, dtype: DtypeLike = np.float32, **kwargs +): + """ + Converts a MetaTensor object to an ITK image. Expects the MetaTensor to be in ChannelFirst format. + + Args: + meta_tensor: The MetaTensor to be converted. + channel_dim: channel dimension of the data array, defaults to ``0`` (Channel-first). + ``None`` indicates no channel dimension. This is used to create a Vector Image if it is not ``None``. + dtype: output data type, defaults to `np.float32`. + kwargs: additional keyword arguments. Currently `itk.GetImageFromArray` will get ``ttype`` from this dictionary. + + Returns: + The ITK image. + + See also: :py:func:`ITKWriter.create_backend_obj` + """ + writer = ITKWriter(output_dtype=dtype, affine_lps_to_ras=False) + writer.set_data_array(data_array=meta_tensor.data, channel_dim=channel_dim, squeeze_end_dims=True) + return writer.create_backend_obj( + writer.data_obj, + channel_dim=writer.channel_dim, + affine=meta_tensor.affine, + affine_lps_to_ras=False, # False if the affine is in itk convention + dtype=writer.output_dtype, + kwargs=kwargs, + ) + + +def itk_to_monai_affine(image, matrix, translation, center_of_rotation=None, reference_image=None) -> torch.Tensor: + """ + Converts an ITK affine matrix (2x2 for 2D or 3x3 for 3D matrix and translation vector) to a MONAI affine matrix. + + Args: + image: The ITK image object. This is used to extract the spacing and direction information. + matrix: The 2x2 or 3x3 ITK affine matrix. + translation: The 2-element or 3-element ITK affine translation vector. + center_of_rotation: The center of rotation. If provided, the affine + matrix will be adjusted to account for the difference + between the center of the image and the center of rotation. + reference_image: The coordinate space that matrix and translation were defined + in respect to. If not supplied, the coordinate space of image + is used. + + Returns: + A 4x4 MONAI affine matrix. + """ + + _assert_itk_regions_match_array(image) + ndim = image.ndim + # If there is a reference image, compute an affine matrix that maps the image space to the + # reference image space. + if reference_image: + reference_affine_matrix = _compute_reference_space_affine_matrix(image, reference_image) + else: + reference_affine_matrix = torch.eye(ndim + 1, dtype=torch.float64) + + # Create affine matrix that includes translation + affine_matrix = torch.eye(ndim + 1, dtype=torch.float64) + affine_matrix[:ndim, :ndim] = torch.tensor(matrix, dtype=torch.float64) + affine_matrix[:ndim, ndim] = torch.tensor(translation, dtype=torch.float64) + + # Adjust offset when center of rotation is different from center of the image + if center_of_rotation: + offset_matrix, inverse_offset_matrix = _compute_offset_matrix(image, center_of_rotation) + affine_matrix = inverse_offset_matrix @ affine_matrix @ offset_matrix + + # Adjust direction + direction_matrix, inverse_direction_matrix = _compute_direction_matrix(image) + affine_matrix = inverse_direction_matrix @ affine_matrix @ direction_matrix + + # Adjust based on spacing. It is required because MONAI does not update the + # pixel array according to the spacing after a transformation. For example, + # a rotation of 90deg for an image with different spacing along the two axis + # will just rotate the image array by 90deg without also scaling accordingly. + spacing_matrix, inverse_spacing_matrix = _compute_spacing_matrix(image) + affine_matrix = inverse_spacing_matrix @ affine_matrix @ spacing_matrix + + return affine_matrix @ reference_affine_matrix + + +def monai_to_itk_affine(image, affine_matrix, center_of_rotation=None): + """ + Converts a MONAI affine matrix to an ITK affine matrix (2x2 for 2D or 3x3 for + 3D matrix and translation vector). See also 'itk_to_monai_affine'. + + Args: + image: The ITK image object. This is used to extract the spacing and direction information. + affine_matrix: The 3x3 for 2D or 4x4 for 3D MONAI affine matrix. + center_of_rotation: The center of rotation. If provided, the affine + matrix will be adjusted to account for the difference + between the center of the image and the center of rotation. + + Returns: + The ITK matrix and the translation vector. + """ + _assert_itk_regions_match_array(image) + + # Adjust spacing + spacing_matrix, inverse_spacing_matrix = _compute_spacing_matrix(image) + affine_matrix = spacing_matrix @ affine_matrix @ inverse_spacing_matrix + + # Adjust direction + direction_matrix, inverse_direction_matrix = _compute_direction_matrix(image) + affine_matrix = direction_matrix @ affine_matrix @ inverse_direction_matrix + + # Adjust offset when center of rotation is different from center of the image + if center_of_rotation: + offset_matrix, inverse_offset_matrix = _compute_offset_matrix(image, center_of_rotation) + affine_matrix = offset_matrix @ affine_matrix @ inverse_offset_matrix + + ndim = image.ndim + matrix = affine_matrix[:ndim, :ndim].numpy() + translation = affine_matrix[:ndim, ndim].tolist() + + return matrix, translation + + +def get_itk_image_center(image): + """ + Calculates the center of the ITK image based on its origin, size, and spacing. + This center is equivalent to the implicit image center that MONAI uses. + + Args: + image: The ITK image. + + Returns: + The center of the image as a list of coordinates. + """ + image_size = np.asarray(image.GetLargestPossibleRegion().GetSize(), np.float32) + spacing = np.asarray(image.GetSpacing()) + origin = np.asarray(image.GetOrigin()) + center = image.GetDirection() @ ((image_size / 2 - 0.5) * spacing) + origin + + return center.tolist() + + +def _assert_itk_regions_match_array(image): + # Note: Make it more compact? Also, are there redundant checks? + largest_region = image.GetLargestPossibleRegion() + buffered_region = image.GetBufferedRegion() + requested_region = image.GetRequestedRegion() + + largest_region_size = np.array(largest_region.GetSize()) + buffered_region_size = np.array(buffered_region.GetSize()) + requested_region_size = np.array(requested_region.GetSize()) + array_size = np.array(image.shape)[::-1] + + largest_region_index = np.array(largest_region.GetIndex()) + buffered_region_index = np.array(buffered_region.GetIndex()) + requested_region_index = np.array(requested_region.GetIndex()) + + indices_are_zeros = ( + np.all(largest_region_index == 0) and np.all(buffered_region_index == 0) and np.all(requested_region_index == 0) + ) + + sizes_match = ( + np.array_equal(array_size, largest_region_size) + and np.array_equal(largest_region_size, buffered_region_size) + and np.array_equal(buffered_region_size, requested_region_size) + ) + + if not indices_are_zeros: + raise AssertionError("ITK-MONAI bridge: non-zero ITK region indices encountered") + if not sizes_match: + raise AssertionError("ITK-MONAI bridge: ITK regions should be of the same shape") + + +def _compute_offset_matrix(image, center_of_rotation) -> tuple[torch.Tensor, torch.Tensor]: + ndim = image.ndim + offset = np.asarray(get_itk_image_center(image)) - np.asarray(center_of_rotation) + offset_matrix = torch.eye(ndim + 1, dtype=torch.float64) + offset_matrix[:ndim, ndim] = torch.tensor(offset, dtype=torch.float64) + inverse_offset_matrix = torch.eye(ndim + 1, dtype=torch.float64) + inverse_offset_matrix[:ndim, ndim] = -torch.tensor(offset, dtype=torch.float64) + + return offset_matrix, inverse_offset_matrix + + +def _compute_spacing_matrix(image) -> tuple[torch.Tensor, torch.Tensor]: + ndim = image.ndim + spacing = np.asarray(image.GetSpacing(), dtype=np.float64) + spacing_matrix = torch.eye(ndim + 1, dtype=torch.float64) + inverse_spacing_matrix = torch.eye(ndim + 1, dtype=torch.float64) + for i, e in enumerate(spacing): + spacing_matrix[i, i] = e + inverse_spacing_matrix[i, i] = 1 / e + + return spacing_matrix, inverse_spacing_matrix + + +def _compute_direction_matrix(image) -> tuple[torch.Tensor, torch.Tensor]: + ndim = image.ndim + direction = itk.array_from_matrix(image.GetDirection()) + direction_matrix = torch.eye(ndim + 1, dtype=torch.float64) + direction_matrix[:ndim, :ndim] = torch.tensor(direction, dtype=torch.float64) + inverse_direction = itk.array_from_matrix(image.GetInverseDirection()) + inverse_direction_matrix = torch.eye(ndim + 1, dtype=torch.float64) + inverse_direction_matrix[:ndim, :ndim] = torch.tensor(inverse_direction, dtype=torch.float64) + + return direction_matrix, inverse_direction_matrix + + +def _compute_reference_space_affine_matrix(image, ref_image) -> torch.Tensor: + ndim = ref_image.ndim + + # Spacing and direction as matrices + spacing_matrix, inv_spacing_matrix = (m[:ndim, :ndim].numpy() for m in _compute_spacing_matrix(image)) + ref_spacing_matrix, ref_inv_spacing_matrix = (m[:ndim, :ndim].numpy() for m in _compute_spacing_matrix(ref_image)) + + direction_matrix, inv_direction_matrix = (m[:ndim, :ndim].numpy() for m in _compute_direction_matrix(image)) + ref_direction_matrix, ref_inv_direction_matrix = ( + m[:ndim, :ndim].numpy() for m in _compute_direction_matrix(ref_image) + ) + + # Matrix calculation + matrix = ref_direction_matrix @ ref_spacing_matrix @ inv_spacing_matrix @ inv_direction_matrix + + # Offset calculation + pixel_offset = -1 + image_size = np.asarray(ref_image.GetLargestPossibleRegion().GetSize(), np.float32) + translation = ( + (ref_direction_matrix @ ref_spacing_matrix - direction_matrix @ spacing_matrix) + @ (image_size + pixel_offset) + / 2 + ) + translation += np.asarray(ref_image.GetOrigin()) - np.asarray(image.GetOrigin()) + + # Convert matrix ITK matrix and translation to MONAI affine matrix + ref_affine_matrix = itk_to_monai_affine(image, matrix=matrix, translation=translation) + + return ref_affine_matrix + + +def monai_to_itk_ddf(image, ddf): + """ + converting the dense displacement field from the MONAI space to the ITK + Args: + image: itk image of array shape 2D: (H, W) or 3D: (D, H, W) + ddf: numpy array of shape 2D: (2, H, W) or 3D: (3, D, H, W) + Returns: + displacement_field: itk image of the corresponding displacement field + + """ + # 3, D, H, W -> D, H, W, 3 + ndim = image.ndim + ddf = ddf.transpose(tuple(list(range(1, ndim + 1)) + [0])) + # x, y, z -> z, x, y + ddf = ddf[..., ::-1] + + # Correct for spacing + spacing = np.asarray(image.GetSpacing(), dtype=np.float64) + ddf *= np.array(spacing, ndmin=ndim + 1) + + # Correct for direction + direction = np.asarray(image.GetDirection(), dtype=np.float64) + ddf = np.einsum("ij,...j->...i", direction, ddf, dtype=np.float64).astype(np.float32) + + # initialise displacement field - + vector_component_type = itk.F + vector_pixel_type = itk.Vector[vector_component_type, ndim] + displacement_field_type = itk.Image[vector_pixel_type, ndim] + displacement_field = itk.GetImageFromArray(ddf, ttype=displacement_field_type) + + # Set image metadata + displacement_field.SetSpacing(image.GetSpacing()) + displacement_field.SetOrigin(image.GetOrigin()) + displacement_field.SetDirection(image.GetDirection()) + + return displacement_field diff --git a/tests/test_itk_torch_bridge.py b/tests/test_itk_torch_bridge.py new file mode 100644 index 0000000000..c08db89198 --- /dev/null +++ b/tests/test_itk_torch_bridge.py @@ -0,0 +1,486 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.apps import download_url +from monai.data import ITKReader +from monai.data.itk_torch_bridge import ( + get_itk_image_center, + itk_image_to_metatensor, + itk_to_monai_affine, + metatensor_to_itk_image, + monai_to_itk_affine, + monai_to_itk_ddf, +) +from monai.networks.blocks import Warp +from monai.transforms import Affine +from monai.utils import optional_import, set_determinism +from tests.utils import skip_if_downloading_fails, skip_if_quick, test_is_quick, testing_data_config + +itk, has_itk = optional_import("itk") + +TESTS = ["CT_2D_head_fixed.mha", "CT_2D_head_moving.mha"] +if not test_is_quick(): + TESTS += ["copd1_highres_INSP_STD_COPD_img.nii.gz", "copd1_highres_EXP_STD_COPD_img.nii.gz"] + + +@unittest.skipUnless(has_itk, "Requires `itk` package.") +class TestITKTorchAffineMatrixBridge(unittest.TestCase): + def setUp(self): + set_determinism(seed=0) + self.data_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") + self.reader = ITKReader(pixel_type=itk.F) + + for file_name in TESTS: + path = os.path.join(self.data_dir, file_name) + if not os.path.exists(path): + with skip_if_downloading_fails(): + data_spec = testing_data_config("images", f"{file_name.split('.', 1)[0]}") + download_url( + data_spec["url"], path, hash_val=data_spec["hash_val"], hash_type=data_spec["hash_type"] + ) + + def tearDown(self): + set_determinism(seed=None) + + def create_itk_affine_from_parameters( + self, image, translation=None, rotation=None, scale=None, shear=None, center_of_rotation=None + ): + """ + Creates an affine transformation for an ITK image based on the provided parameters. + + Args: + image: The ITK image. + translation: The translation (shift) to apply to the image. + rotation: The rotation to apply to the image, specified as angles in radians around the x, y, and z axes. + scale: The scaling factor to apply to the image. + shear: The shear to apply to the image. + center_of_rotation: The center of rotation for the image. If not specified, + the center of the image is used. + + Returns: + A tuple containing the affine transformation matrix and the translation vector. + """ + itk_transform = itk.AffineTransform[itk.D, image.ndim].New() + + # Set center + if center_of_rotation: + itk_transform.SetCenter(center_of_rotation) + else: + itk_transform.SetCenter(get_itk_image_center(image)) + + # Set parameters + if rotation: + if image.ndim == 2: + itk_transform.Rotate2D(rotation[0]) + else: + for i, angle_in_rads in enumerate(rotation): + if angle_in_rads != 0: + axis = [0, 0, 0] + axis[i] = 1 + itk_transform.Rotate3D(axis, angle_in_rads) + + if scale: + itk_transform.Scale(scale) + + if shear: + itk_transform.Shear(*shear) + + if translation: + itk_transform.Translate(translation) + + matrix = np.asarray(itk_transform.GetMatrix(), dtype=np.float64) + + return matrix, translation + + def itk_affine_resample(self, image, matrix, translation, center_of_rotation=None, reference_image=None): + # Translation transform + itk_transform = itk.AffineTransform[itk.D, image.ndim].New() + + # Set center + if center_of_rotation: + itk_transform.SetCenter(center_of_rotation) + else: + itk_transform.SetCenter(get_itk_image_center(image)) + + # Set matrix and translation + itk_transform.SetMatrix(itk.matrix_from_array(matrix)) + itk_transform.Translate(translation) + + # Interpolator + image = image.astype(itk.D) + interpolator = itk.LinearInterpolateImageFunction.New(image) + + if not reference_image: + reference_image = image + + # Resample with ITK + output_image = itk.resample_image_filter( + image, interpolator=interpolator, transform=itk_transform, output_parameters_from_image=reference_image + ) + + return np.asarray(output_image, dtype=np.float32) + + def monai_affine_resample(self, metatensor, affine_matrix): + affine = Affine( + affine=affine_matrix, padding_mode="zeros", mode="bilinear", dtype=torch.float64, image_only=True + ) + output_tensor = affine(metatensor) + + return output_tensor.squeeze().permute(*torch.arange(output_tensor.ndim - 2, -1, -1)).array + + def remove_border(self, image): + """ + MONAI seems to have different behavior in the borders of the image than ITK. + This helper function sets the border of the ITK image as 0 (padding but keeping + the same image size) in order to allow numerical comparison between the + result from resampling with ITK/Elastix and resampling with MONAI. + To use: image[:] = remove_border(image) + Args: + image: The ITK image to be padded. + + Returns: + The padded array of data. + """ + return np.pad(image[1:-1, 1:-1, 1:-1] if image.ndim == 3 else image[1:-1, 1:-1], pad_width=1) + + def itk_warp(self, image, ddf): + """ + Warping with python itk + Args: + image: itk image of array shape 2D: (H, W) or 3D: (D, H, W) + ddf: numpy array of shape 2D: (2, H, W) or 3D: (3, D, H, W) + Returns: + warped_image: numpy array of shape (H, W) or (D, H, W) + """ + # MONAI -> ITK ddf + displacement_field = monai_to_itk_ddf(image, ddf) + + # Resample using the ddf + interpolator = itk.LinearInterpolateImageFunction.New(image) + warped_image = itk.warp_image_filter( + image, interpolator=interpolator, displacement_field=displacement_field, output_parameters_from_image=image + ) + + return np.asarray(warped_image) + + def monai_warp(self, image_tensor, ddf_tensor): + """ + Warping with MONAI + Args: + image_tensor: torch tensor of shape 2D: (1, 1, H, W) and 3D: (1, 1, D, H, W) + ddf_tensor: torch tensor of shape 2D: (1, 2, H, W) and 3D: (1, 3, D, H, W) + Returns: + warped_image: numpy array of shape (H, W) or (D, H, W) + """ + warp = Warp(mode="bilinear", padding_mode="zeros") + warped_image = warp(image_tensor.to(torch.float64), ddf_tensor.to(torch.float64)) + + return warped_image.to(torch.float32).squeeze().numpy() + + @parameterized.expand(TESTS) + def test_setting_affine_parameters(self, filepath): + # Read image + image = self.reader.read(os.path.join(self.data_dir, filepath)) + image[:] = self.remove_border(image) + ndim = image.ndim + + # Affine parameters + translation = [65.2, -50.2, 33.9][:ndim] + rotation = [0.78539816339, 1.0, -0.66][:ndim] + scale = [2.0, 1.5, 3.2][:ndim] + shear = [0, 1, 1.6] # axis1, axis2, coeff + + # Spacing + spacing = np.array([1.2, 1.5, 2.0])[:ndim] + image.SetSpacing(spacing) + + # ITK + matrix, translation = self.create_itk_affine_from_parameters(image, translation, rotation, scale, shear) + output_array_itk = self.itk_affine_resample(image, matrix=matrix, translation=translation) + + # MONAI + metatensor = itk_image_to_metatensor(image) + affine_matrix_for_monai = itk_to_monai_affine(image, matrix, translation) + output_array_monai = self.monai_affine_resample(metatensor, affine_matrix=affine_matrix_for_monai) + + # Make sure that the array conversion of the inputs is the same + input_array_monai = metatensor.squeeze().permute(*torch.arange(metatensor.ndim - 2, -1, -1)).array + np.testing.assert_array_equal(input_array_monai, np.asarray(image)) + + # Compare outputs + percentage = ( + 100 * np.isclose(output_array_monai, output_array_itk).sum(dtype=np.float64) / output_array_itk.size + ) + self.assertGreaterEqual(percentage, 99.0) + + @parameterized.expand(TESTS) + def test_arbitary_center_of_rotation(self, filepath): + # Read image + image = self.reader.read(os.path.join(self.data_dir, filepath)) + image[:] = self.remove_border(image) + ndim = image.ndim + + # ITK matrix (3x3 affine matrix) + matrix = np.array( + [ + [0.55915995, 0.50344867, 0.43208387], + [0.01133669, 0.82088571, 0.86841365], + [0.30478496, 0.94998986, 0.32742505], + ] + )[:ndim, :ndim] + translation = [54.0, 2.7, -11.9][:ndim] + + # Spatial properties + center_of_rotation = [-32.3, 125.1, 0.7][:ndim] + origin = [1.6, 0.5, 2.0][:ndim] + spacing = np.array([1.2, 1.5, 0.6])[:ndim] + + image.SetSpacing(spacing) + image.SetOrigin(origin) + + # ITK + output_array_itk = self.itk_affine_resample(image, matrix, translation, center_of_rotation) + + # MONAI + metatensor = itk_image_to_metatensor(image) + affine_matrix_for_monai = itk_to_monai_affine(image, matrix, translation, center_of_rotation) + output_array_monai = self.monai_affine_resample(metatensor, affine_matrix=affine_matrix_for_monai) + + # Make sure that the array conversion of the inputs is the same + input_array_monai = metatensor.squeeze().permute(*torch.arange(metatensor.ndim - 2, -1, -1)).array + np.testing.assert_array_equal(input_array_monai, np.asarray(image)) + + # Compare outputs + percentage = ( + 100 * np.isclose(output_array_monai, output_array_itk).sum(dtype=np.float64) / output_array_itk.size + ) + self.assertGreaterEqual(percentage, 99.0) + + @parameterized.expand(TESTS) + def test_monai_to_itk(self, filepath): + # Read image + image = self.reader.read(os.path.join(self.data_dir, filepath)) + image[:] = self.remove_border(image) + ndim = image.ndim + + # MONAI affine matrix + affine_matrix = torch.eye(ndim + 1, dtype=torch.float64) + affine_matrix[:ndim, :ndim] = torch.tensor( + [ + [0.55915995, 0.50344867, 0.43208387], + [0.01133669, 0.82088571, 0.86841365], + [0.30478496, 0.94998986, 0.32742505], + ], + dtype=torch.float64, + )[:ndim, :ndim] + + affine_matrix[:ndim, ndim] = torch.tensor([54.0, 2.7, -11.9], dtype=torch.float64)[:ndim] + + # Spatial properties + center_of_rotation = [-32.3, 125.1, 0.7][:ndim] + origin = [1.6, 0.5, 2.0][:ndim] + spacing = np.array([1.2, 1.5, 0.6])[:ndim] + + image.SetSpacing(spacing) + image.SetOrigin(origin) + + # ITK + matrix, translation = monai_to_itk_affine(image, affine_matrix, center_of_rotation) + output_array_itk = self.itk_affine_resample(image, matrix, translation, center_of_rotation) + + # MONAI + metatensor = itk_image_to_metatensor(image) + output_array_monai = self.monai_affine_resample(metatensor, affine_matrix) + + # Make sure that the array conversion of the inputs is the same + input_array_monai = metatensor.squeeze().permute(*torch.arange(metatensor.ndim - 2, -1, -1)).array + np.testing.assert_array_equal(input_array_monai, np.asarray(image)) + + # Compare outputs + percentage = ( + 100 * np.isclose(output_array_monai, output_array_itk).sum(dtype=np.float64) / output_array_itk.size + ) + self.assertGreaterEqual(percentage, 99.0) + + @parameterized.expand(TESTS) + def test_cyclic_conversion(self, filepath): + image = self.reader.read(os.path.join(self.data_dir, filepath)) + image[:] = self.remove_border(image) + ndim = image.ndim + + # ITK matrix (3x3 affine matrix) + matrix = np.array( + [ + [2.90971094, 1.18297296, 2.60008784], + [0.29416137, 0.10294283, 2.82302616], + [1.70578374, 1.39706003, 2.54652029], + ] + )[:ndim, :ndim] + + translation = [-29.05463245, 35.27116398, 48.58759597][:ndim] + + # Spatial properties + center_of_rotation = [-27.84789587, -60.7871084, 42.73501932][:ndim] + origin = [8.10416794, 5.4831944, 0.49211025][:ndim] + spacing = np.array([0.7, 3.2, 1.3])[:ndim] + + direction = np.array( + [ + [1.02895588, 0.22791448, 0.02429561], + [0.21927512, 1.28632268, -0.14932226], + [0.47455613, 0.38534345, 0.98505633], + ], + dtype=np.float64, + ) + image.SetDirection(direction[:ndim, :ndim]) + + image.SetSpacing(spacing) + image.SetOrigin(origin) + + affine_matrix = itk_to_monai_affine(image, matrix, translation, center_of_rotation) + matrix_result, translation_result = monai_to_itk_affine(image, affine_matrix, center_of_rotation) + + meta_tensor = itk_image_to_metatensor(image) + image_result = metatensor_to_itk_image(meta_tensor) + + np.testing.assert_allclose(matrix, matrix_result) + np.testing.assert_allclose(translation, translation_result) + np.testing.assert_array_equal(image.shape, image_result.shape) + np.testing.assert_array_equal(image, image_result) + + @parameterized.expand([(2,), (3,)]) + def test_random_array(self, ndim): + # Create image/array with random size and pixel intensities + s = torch.randint(low=2, high=20, size=(ndim,)) + img = 100 * torch.rand((1, 1, *s.tolist()), dtype=torch.float32) + + # Pad at the edges because ITK and MONAI have different behavior there + # during resampling + img = torch.nn.functional.pad(img, pad=ndim * (1, 1)) + ddf = 5 * torch.rand((1, ndim, *img.shape[-ndim:]), dtype=torch.float32) - 2.5 + + # Warp with MONAI + img_resampled = self.monai_warp(img, ddf) + + # Create ITK image + itk_img = itk.GetImageFromArray(img.squeeze().numpy()) + + # Set random spacing + spacing = 3 * np.random.rand(ndim) + itk_img.SetSpacing(spacing) + + # Set random direction + direction = 5 * np.random.rand(ndim, ndim) - 5 + direction = itk.matrix_from_array(direction) + itk_img.SetDirection(direction) + + # Set random origin + origin = 100 * np.random.rand(ndim) - 100 + itk_img.SetOrigin(origin) + + # Warp with ITK + itk_img_resampled = self.itk_warp(itk_img, ddf.squeeze().numpy()) + + # Compare + np.testing.assert_allclose(img_resampled, itk_img_resampled, rtol=1e-2, atol=1e-2) + + @parameterized.expand(TESTS) + @skip_if_quick + def test_real_data(self, filepath): + # Read image + image = self.reader.read(os.path.join(self.data_dir, filepath)) + image[:] = self.remove_border(image) + ndim = image.ndim + + # Random ddf + ddf = 10 * torch.rand((1, ndim, *image.shape), dtype=torch.float32) - 10 + + # Warp with MONAI + image_tensor = torch.tensor(itk.GetArrayFromImage(image), dtype=torch.float32).unsqueeze(0).unsqueeze(0) + img_resampled = self.monai_warp(image_tensor, ddf) + + # Warp with ITK + itk_img_resampled = self.itk_warp(image, ddf.squeeze().numpy()) + + # Compare + np.testing.assert_allclose(img_resampled, itk_img_resampled, rtol=1e-3, atol=1e-3) + + @parameterized.expand(zip(TESTS[::2], TESTS[1::2])) + @skip_if_quick + def test_use_reference_space(self, ref_filepath, filepath): + # Read the images + image = self.reader.read(os.path.join(self.data_dir, filepath)) + image[:] = self.remove_border(image) + ndim = image.ndim + + ref_image = self.reader.read(os.path.join(self.data_dir, ref_filepath)) + + # Set arbitary origin, spacing, direction for both of the images + image.SetSpacing([1.2, 2.0, 1.7][:ndim]) + ref_image.SetSpacing([1.9, 1.5, 1.3][:ndim]) + + direction = np.array( + [ + [1.02895588, 0.22791448, 0.02429561], + [0.21927512, 1.28632268, -0.14932226], + [0.47455613, 0.38534345, 0.98505633], + ], + dtype=np.float64, + ) + image.SetDirection(direction[:ndim, :ndim]) + + ref_direction = np.array( + [ + [1.26032417, -0.19243174, 0.54877414], + [0.31958275, 0.9543068, 0.2720827], + [-0.24106769, -0.22344502, 0.9143302], + ], + dtype=np.float64, + ) + ref_image.SetDirection(ref_direction[:ndim, :ndim]) + + image.SetOrigin([57.3, 102.0, -20.9][:ndim]) + ref_image.SetOrigin([23.3, -0.5, 23.7][:ndim]) + + # Set affine parameters + matrix = np.array( + [ + [0.55915995, 0.50344867, 0.43208387], + [0.01133669, 0.82088571, 0.86841365], + [0.30478496, 0.94998986, 0.32742505], + ] + )[:ndim, :ndim] + translation = [54.0, 2.7, -11.9][:ndim] + center_of_rotation = [-32.3, 125.1, 0.7][:ndim] + + # Resample using ITK + output_array_itk = self.itk_affine_resample(image, matrix, translation, center_of_rotation, ref_image) + + # MONAI + metatensor = itk_image_to_metatensor(image) + affine_matrix_for_monai = itk_to_monai_affine(image, matrix, translation, center_of_rotation, ref_image) + output_array_monai = self.monai_affine_resample(metatensor, affine_matrix_for_monai) + + # Compare outputs + np.testing.assert_allclose(output_array_monai, output_array_itk, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/testing_data/data_config.json b/tests/testing_data/data_config.json index 788d664439..c2d2ba9635 100644 --- a/tests/testing_data/data_config.json +++ b/tests/testing_data/data_config.json @@ -54,6 +54,26 @@ "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MNI152_T1_2mm_strucseg.nii.gz", "hash_type": "sha256", "hash_val": "eb4f1e596ca85aadaefc359d409fb9a3e27d733e6def04b996953b7c54bc26d4" + }, + "copd1_highres_INSP_STD_COPD_img": { + "url": "https://data.kitware.com/api/v1/file/62a0f067bddec9d0c4175c5a/download", + "hash_type": "sha512", + "hash_val": "60193cd6ef0cf055c623046446b74f969a2be838444801bd32ad5bedc8a7eeecb343e8a1208769c9c7a711e101c806a3133eccdda7790c551a69a64b9b3701e9" + }, + "copd1_highres_EXP_STD_COPD_img": { + "url": "https://data.kitware.com/api/v1/item/62a0f045bddec9d0c4175c44/download", + "hash_type": "sha512", + "hash_val": "841ef303958541474e66c2d1ccdc8b7ed17ba2f2681101307766b979a07979f2ec818ddf13791c3f1ac5a8ec3258d6ea45b692b4b4a838de9188602618972b6d" + }, + "CT_2D_head_fixed": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/CT_2D_head_fixed.mha", + "hash_type": "sha256", + "hash_val": "06f2ce6fbf6a59f0874c735555fcf71717f631156b1b0697c1752442f7fc1cc5" + }, + "CT_2D_head_moving": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/CT_2D_head_moving.mha", + "hash_type": "sha256", + "hash_val": "a37c5fe388c38b3f4ac564f456277d09d3982eda58c4da05ead8ee2332360f47" } }, "videos": {