From 59418aa4c71c3de10b7ab4b91d2662ba59cba84b Mon Sep 17 00:00:00 2001 From: Tom Birdsong Date: Mon, 25 Apr 2022 12:53:06 -0400 Subject: [PATCH] ENH: Add VkMultiResolutionPyramidImageFilter Adds image filter for image pyramid generation making selective use of VkFFT acceleration for FFT-based blurring. Also adds factory override class so that `VkMultiResolutionPyramidImageFilter` can be used as the default implementation of `MultiResolutionPyramidImageFilter` per the ITK object factory. --- .../itkVkMultiResolutionPyramidImageFilter.h | 195 +++++++++++++++++ ...itkVkMultiResolutionPyramidImageFilter.hxx | 205 ++++++++++++++++++ ...MultiResolutionPyramidImageFilterFactory.h | 114 ++++++++++ itk-module.cmake | 7 +- ...solutionPyramidImageFilterTest0.mha.sha512 | 1 + ...solutionPyramidImageFilterTest1.mha.sha512 | 1 + ...solutionPyramidImageFilterTest2.mha.sha512 | 1 + ...solutionPyramidImageFilterTest3.mha.sha512 | 1 + ...solutionPyramidImageFilterTest4.mha.sha512 | 1 + test/CMakeLists.txt | 34 +++ ...esolutionPyramidImageFilterFactoryTest.cxx | 56 +++++ ...kMultiResolutionPyramidImageFilterTest.cxx | 173 +++++++++++++++ ...tkVkMultiResolutionPyramidImageFilter.wrap | 3 + ...tiResolutionPyramidImageFilterFactory.wrap | 1 + 14 files changed, 790 insertions(+), 3 deletions(-) create mode 100644 include/itkVkMultiResolutionPyramidImageFilter.h create mode 100644 include/itkVkMultiResolutionPyramidImageFilter.hxx create mode 100644 include/itkVkMultiResolutionPyramidImageFilterFactory.h create mode 100644 test/Baseline/itkVkMultiResolutionPyramidImageFilterTest0.mha.sha512 create mode 100644 test/Baseline/itkVkMultiResolutionPyramidImageFilterTest1.mha.sha512 create mode 100644 test/Baseline/itkVkMultiResolutionPyramidImageFilterTest2.mha.sha512 create mode 100644 test/Baseline/itkVkMultiResolutionPyramidImageFilterTest3.mha.sha512 create mode 100644 test/Baseline/itkVkMultiResolutionPyramidImageFilterTest4.mha.sha512 create mode 100644 test/itkVkMultiResolutionPyramidImageFilterFactoryTest.cxx create mode 100644 test/itkVkMultiResolutionPyramidImageFilterTest.cxx create mode 100644 wrapping/itkVkMultiResolutionPyramidImageFilter.wrap create mode 100644 wrapping/itkVkMultiResolutionPyramidImageFilterFactory.wrap diff --git a/include/itkVkMultiResolutionPyramidImageFilter.h b/include/itkVkMultiResolutionPyramidImageFilter.h new file mode 100644 index 00000000..63eea5b6 --- /dev/null +++ b/include/itkVkMultiResolutionPyramidImageFilter.h @@ -0,0 +1,195 @@ +/*========================================================================= + * + * Copyright NumFOCUS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + *=========================================================================*/ +#ifndef itkVkMultiResolutionPyramidImageFilter_h +#define itkVkMultiResolutionPyramidImageFilter_h + +#include "itkMultiResolutionPyramidImageFilter.h" + +#include "itkDiscreteGaussianImageFilter.h" +#include "itkFFTDiscreteGaussianImageFilter.h" +#include "itkVector.h" +#include "itkMacro.h" +#include "VkFFTBackendExport.h" + +#include + +namespace itk +{ + +/** \class VkMultiResolutionPyramidImageFilter + * \brief Creates a multi-resolution pyramid with FFT acceleration + * + * VkMultiResolutionPyramidImageFilter re-implements a framework + * for creating an image pyramid as laid out in + * MultiResolutionPyramidImageFilter. Conditional logic is added + * to preemptively select the optimal image smoothing pipeline + * that is expected to give the best performance for different + * pyramid levels. + * + * Separable spatial convolution with DiscreteGaussianImageFilter + * runs quickly for small kernel sizes but scales poorly with + * increasing kernel size. By contrast ITK FFT convolution accelerated + * with a VkFFT GPU backend scales slowly with increasing kernel size + * but is typically outperformed by spatial convolution filters + * for small kernel sizes. + * + * VkMultiResolutionPyramidImageFilter allows the user to fix the + * metric threshold at which a performance tradeoff is expected + * between spatial and FFT convolution. The exact threshold depends + * on user hardware and can be estimated through benchmarking with + * scripts in the ITKVkFFTBackend repository. + * + * By mitigating blurring times on levels with large kernel sizes + * VkMultiResolutionPyramidImageFilter has been observed to run in + * as little as 50% of the time of its base class. + * + * See documentation of MultiResolutionPyramidImageFilter + * for information on how to specify a multi-resolution schedule. + * + * \sa MultiResolutionPyramidImageFilter + * \sa DiscreteGaussianImageFilter + * \sa FFTDiscreteGaussianImageFilter + * \sa ShrinkImageFilter + * + * \ingroup VkFFTBackend + * \ingroup PyramidImageFilter + * \ingroup ITKRegistrationCommon + */ +template +class ITK_TEMPLATE_EXPORT VkMultiResolutionPyramidImageFilter + : public MultiResolutionPyramidImageFilter +{ +public: + ITK_DISALLOW_COPY_AND_MOVE(VkMultiResolutionPyramidImageFilter); + + /** Standard class type aliases. */ + using Self = VkMultiResolutionPyramidImageFilter; + using Superclass = MultiResolutionPyramidImageFilter; + using Pointer = SmartPointer; + using ConstPointer = SmartPointer; + + /** Method for creation through the object factory. */ + itkNewMacro(Self); + + /** Run-time type information (and related methods). */ + itkTypeMacro(VkMultiResolutionPyramidImageFilter, MultiResolutionPyramidImageFilter); + + /** ImageDimension enumeration. */ + static constexpr unsigned int ImageDimension = TInputImage::ImageDimension; + + /** Inherit types from Superclass. */ + using typename Superclass::InputImageType; + using typename Superclass::OutputImageType; + using typename Superclass::InputImagePointer; + using typename Superclass::OutputImagePointer; + using typename Superclass::InputImageConstPointer; + using InputSizeType = typename InputImageType::SizeType; + using OutputPixelType = typename OutputImageType::PixelType; + using OutputSizeType = typename OutputImageType::SizeType; + using typename Superclass::ScheduleType; + + using VarianceType = itk::Vector; + using KernelSizeType = OutputSizeType; + + /** Types for acceleration. + * Assumes and does not verify that FFT backend is accelerated. */ + using BaseSmootherType = DiscreteGaussianImageFilter; + using SpatialSmootherType = DiscreteGaussianImageFilter; + using FFTSmootherType = FFTDiscreteGaussianImageFilter; + + /** Set the metric threshold to decide between + * accelerated methods such as CPU-based separable smoothing + * versus GPU-based FFT smoothing. + * We can predictively compare spatial and FFT smoothing + * performance using the following metric: + * + * f(i,j,k,x,y,z) = log((i + j + k) * x * y * z) + * + * where i,j,k are the dimensions of the kernel for a given + * pyramid level and x,y,z are the dimensions of the + * output image region. + * + * The equation above approximates the difference in runtime complexity + * between separable spatial Gaussian smoothing and FFT Gaussian smoothing. + * Under separable smoothing each pixel [xi,yi,zi] is used in computation + * approximately (i + j + k) times. FFT smoothing meanwhile has significant + * overhead in setup but scales much more slowly with kernel and image sizes. + * As a result there is an approximate threshold where GPU-accelerated + * smoothing outperforms spatial smoothing for a given pyramid level. + * + * The default threshold value 8.0 has been empirically determined as + * a reasonable approximation such that f(...) < 8.0 indicates that + * spatial convolution will run faster while f(...) > 8.0 indicates that + * FFT convolution will run faster. The threshold value is not universal + * and may need to be adjusted to better match benchmarking results for + * particular hardware and expected image sizes so that nuances such as + * multithreading and GPU performance may be taken into account. + */ + itkSetMacro(MetricThreshold, float); + itkGetMacro(MetricThreshold, float); + + /** Set the metric threshold from a certain parameter set describing the input size + * and kernel radius threshold that is expected to be equally fast with separable + * spatial smoothing and FFT smoothing */ + void + SetMetricThreshold(const InputSizeType & inputSize, const KernelSizeType & kernelRadius) + { + this->SetMetricThreshold(ComputeMetricValue(inputSize, kernelRadius)); + } + + float + ComputeMetricValue(const InputSizeType & inputSize, const KernelSizeType & kernelRadius) const; + + /** Estimate the kernel radius from ilevel settings */ + KernelSizeType + GetKernelRadius(unsigned int ilevel) const; + + /** Get the kernel variance for the given pyramid level + * based on the current schedule */ + VarianceType + GetVariance(unsigned int ilevel) const; + + /** Get whether FFT smoothing will be used for the given + * pyramid level */ + bool + GetUseFFT(const KernelSizeType & kernelRadius) const; + +protected: + VkMultiResolutionPyramidImageFilter() = default; + ~VkMultiResolutionPyramidImageFilter() override = default; + + /** Generate the output data. */ + void + GenerateData() override; + + void + PrintSelf(std::ostream & os, Indent indent) const override; + +private: + float m_MetricThreshold = 8.0f; + typename SpatialSmootherType::Pointer spatialSmoother = SpatialSmootherType::New(); + typename FFTSmootherType::Pointer fftSmoother = FFTSmootherType::New(); + +}; +} // namespace itk + +#ifndef ITK_MANUAL_INSTANTIATION +# include "itkVkMultiResolutionPyramidImageFilter.hxx" +#endif + +#endif diff --git a/include/itkVkMultiResolutionPyramidImageFilter.hxx b/include/itkVkMultiResolutionPyramidImageFilter.hxx new file mode 100644 index 00000000..3cd5de1a --- /dev/null +++ b/include/itkVkMultiResolutionPyramidImageFilter.hxx @@ -0,0 +1,205 @@ +/*========================================================================= + * + * Copyright NumFOCUS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + *=========================================================================*/ +#ifndef itkVkMultiResolutionPyramidImageFilter_hxx +#define itkVkMultiResolutionPyramidImageFilter_hxx + +#include "itkCastImageFilter.h" +#include "itkDiscreteGaussianImageFilter.h" +#include "itkGaussianOperator.h" +#include "itkMacro.h" +#include "itkResampleImageFilter.h" +#include "itkShrinkImageFilter.h" +#include "itkIdentityTransform.h" + +#include "itkMath.h" + +namespace itk +{ +template +void +VkMultiResolutionPyramidImageFilter::GenerateData() +{ + // Mostly reimplements MultiResolutionPyramidImageFilter::GenerateData + + // Get the input and output pointers + InputImageConstPointer inputPtr = this->GetInput(); + + // Create caster, smoother and resampleShrinker filters + using CasterType = CastImageFilter; + + using ImageToImageType = ImageToImageFilter; + using ResampleShrinkerType = ResampleImageFilter; + using ShrinkerType = ShrinkImageFilter; + + auto caster = CasterType::New(); + typename BaseSmootherType::Pointer smoother; + + typename ImageToImageType::Pointer shrinkerFilter; + // + // only one of these pointers is going to be valid, depending on the + // value of UseShrinkImageFilter flag + typename ResampleShrinkerType::Pointer resampleShrinker; + typename ShrinkerType::Pointer shrinker; + + if (this->GetUseShrinkImageFilter()) + { + shrinker = ShrinkerType::New(); + shrinkerFilter = shrinker.GetPointer(); + } + else + { + resampleShrinker = ResampleShrinkerType::New(); + using LinearInterpolatorType = itk::LinearInterpolateImageFunction; + auto interpolator = LinearInterpolatorType::New(); + resampleShrinker->SetInterpolator(interpolator); + resampleShrinker->SetDefaultPixelValue(0); + shrinkerFilter = resampleShrinker.GetPointer(); + } + // Setup the filters + caster->SetInput(inputPtr); + + unsigned int ilevel, idim; + unsigned int factors[ImageDimension]; + VarianceType variance; + + for (ilevel = 0; ilevel < this->m_NumberOfLevels; ++ilevel) + { + this->UpdateProgress(static_cast(ilevel) / static_cast(this->m_NumberOfLevels)); + + // Allocate memory for each output + OutputImagePointer outputPtr = this->GetOutput(ilevel); + outputPtr->SetBufferedRegion(outputPtr->GetRequestedRegion()); + outputPtr->Allocate(); + + // compute shrink factors + for (idim = 0; idim < ImageDimension; ++idim) + { + factors[idim] = this->m_Schedule[ilevel][idim]; + } + + if (!this->GetUseShrinkImageFilter()) + { + using IdentityTransformType = itk::IdentityTransform; + auto identityTransform = IdentityTransformType::New(); + resampleShrinker->SetOutputParametersFromImage(outputPtr); + resampleShrinker->SetTransform(identityTransform); + } + else + { + shrinker->SetShrinkFactors(factors); + } + + // select spatial or FFT smoothing based on user threshold settings + // to maximize anticipated performance + if (GetUseFFT(this->GetKernelRadius(ilevel))) + { + smoother = static_cast(fftSmoother); + } + else + { + smoother = static_cast(spatialSmoother); + } + + // Set up smoothing filter + smoother->SetUseImageSpacing(false); + smoother->SetInput(caster->GetOutput()); + smoother->SetMaximumError(this->m_MaximumError); + variance = this->GetVariance(ilevel); + smoother->SetVariance(variance); + shrinkerFilter->SetInput(smoother->GetOutput()); + + shrinkerFilter->GraftOutput(outputPtr); + + // force to always update in case shrink factors are the same + shrinkerFilter->Modified(); + shrinkerFilter->UpdateLargestPossibleRegion(); + this->GraftNthOutput(ilevel, shrinkerFilter->GetOutput()); + } +} + +template +float +VkMultiResolutionPyramidImageFilter::ComputeMetricValue( + const InputSizeType & inputSize, + const KernelSizeType & kernelRadius) const +{ + unsigned int totalKernelSize = 0; + float metricValue = 1.0f; + for (unsigned int dim = 0; dim < ImageDimension; ++dim) + { + totalKernelSize += kernelRadius[dim] * 2 + 1; + metricValue *= inputSize[dim]; + } + metricValue *= totalKernelSize; + metricValue = std::log10(metricValue); + return metricValue; +} + +template +bool +VkMultiResolutionPyramidImageFilter::GetUseFFT(const KernelSizeType & kernelRadius) const +{ + auto requestedSize = this->GetInput()->GetRequestedRegion().GetSize(); + auto metricValue = this->ComputeMetricValue(requestedSize, kernelRadius); + return metricValue > m_MetricThreshold; +} + +template +typename VkMultiResolutionPyramidImageFilter::KernelSizeType +VkMultiResolutionPyramidImageFilter::GetKernelRadius(unsigned int ilevel) const +{ + using OperatorType = itk::GaussianOperator; + auto * oper = new OperatorType; + KernelSizeType radius; + for (unsigned int dim = 0; dim < ImageDimension; ++dim) + { + oper->SetDirection(dim); + oper->SetMaximumError(this->m_MaximumError); + oper->SetVariance(this->GetVariance(ilevel)[dim]); + oper->CreateDirectional(); + radius[dim] = oper->GetRadius()[dim]; + } + return radius; +} + +template +typename VkMultiResolutionPyramidImageFilter::VarianceType +VkMultiResolutionPyramidImageFilter::GetVariance(unsigned int ilevel) const +{ + VarianceType variance; + for (unsigned int dim = 0; dim < ImageDimension; ++dim) + { + variance[dim] = itk::Math::sqr(0.5 * static_cast(this->m_Schedule[ilevel][dim])); + } + return variance; +} + +/** + * PrintSelf method + */ +template +void +VkMultiResolutionPyramidImageFilter::PrintSelf(std::ostream & os, Indent indent) const +{ + Superclass::PrintSelf(os, indent); + + os << indent << "Kernel/image size metric threshold: " << m_MetricThreshold << std::endl; +} +} // namespace itk + +#endif // itkVkMultiResolutionPyramidImageFilter_hxx diff --git a/include/itkVkMultiResolutionPyramidImageFilterFactory.h b/include/itkVkMultiResolutionPyramidImageFilterFactory.h new file mode 100644 index 00000000..a43708bc --- /dev/null +++ b/include/itkVkMultiResolutionPyramidImageFilterFactory.h @@ -0,0 +1,114 @@ +/*========================================================================= + * + * Copyright NumFOCUS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + *=========================================================================*/ + +#ifndef itkVkMultiResolutionPyramidImageFilterFactory_h +#define itkVkMultiResolutionPyramidImageFilterFactory_h +#include "VkFFTBackendExport.h" + +#include "itkVkMultiResolutionPyramidImageFilter.h" +#include "itkImage.h" +#include "itkObjectFactoryBase.h" +#include "itkVersion.h" + +namespace itk +{ +/** \class VkMultiResolutionPyramidImageFilterFactory + * + * \brief Object Factory implementation for overriding + * MultiResolutionPyramidImageFilterFactory with VkMultiResolutionPyramidImageFilterFactory + * + * \sa ObjectFactoryBase + * \sa MultiResolutionPyramidImageFilter + * \sa VkMultiResolutionPyramidImageFilter + * + * \ingroup VkFFTBackend + * \ingroup ITKRegistration + * \ingroup FourierTransform + * \ingroup ITKFFT + */ +class VkMultiResolutionPyramidImageFilterFactory : public itk::ObjectFactoryBase +{ +public: + ITK_DISALLOW_COPY_AND_MOVE(VkMultiResolutionPyramidImageFilterFactory); + + using Self = VkMultiResolutionPyramidImageFilterFactory; + using Superclass = ObjectFactoryBase; + using Pointer = SmartPointer; + using ConstPointer = SmartPointer; + + /** Class methods used to interface with the registered factories. */ + const char * + GetITKSourceVersion() const override + { + return ITK_SOURCE_VERSION; + } + const char * + GetDescription() const override + { + return "An VkMultiResolutionPyramidImageFilterFactory factory"; + } + + /** Method for class instantiation. */ + itkFactorylessNewMacro(Self); + + /** Run-time type information (and related methods). */ + itkTypeMacro(VkMultiResolutionPyramidImageFilterFactory, itk::ObjectFactoryBase); + + /** Register one factory of this type */ + static void + RegisterOneFactory() + { + VkMultiResolutionPyramidImageFilterFactory::Pointer factory = VkMultiResolutionPyramidImageFilterFactory::New(); + + ObjectFactoryBase::RegisterFactoryInternal(factory); + } + +protected: + /** Override base MultiResolutionPyramidImageFilter constructor at runtime to return + * an upcast VkMultiResolutionPyramidImageFilter instance through the object factory + */ + template + void + OverrideSuperclassType(const std::integer_sequence &) + { + using InputImageType = Image; + using OutputImageType = Image; + this->RegisterOverride( + typeid(typename VkMultiResolutionPyramidImageFilter::Superclass).name(), + typeid(VkMultiResolutionPyramidImageFilter).name(), + "VkMultiResolutionPyramidImageFilter Override", + true, + CreateObjectFunction>::New()); + OverrideSuperclassType(std::integer_sequence{}); + } + template + void + OverrideSuperclassType(const std::integer_sequence &) + {} + + VkMultiResolutionPyramidImageFilterFactory() + { + OverrideSuperclassType(std::integer_sequence{}); + + OverrideSuperclassType(std::integer_sequence{}); + } +}; + +} // namespace itk + +#endif // itkVkMultiResolutionPyramidImageFilterFactory_h diff --git a/itk-module.cmake b/itk-module.cmake index a676056c..e339478a 100644 --- a/itk-module.cmake +++ b/itk-module.cmake @@ -5,8 +5,6 @@ file(READ "${MY_CURRENT_DIR}/README.rst" DOCUMENTATION) # itk_module() defines the module dependencies in VkFFTBackend # VkFFTBackend depends on ITKCommon -# The testing module in VkFFTBackend depends on ITKTestKernel -# and ITKMetaIO(besides VkFFTBackend and ITKCore) # By convention those modules outside of ITK are not prefixed with # ITK. @@ -16,11 +14,14 @@ itk_module(VkFFTBackend ITKCommon ITKStatistics ITKFFT + ITKRegistrationCommon + ITKConvolution COMPILE_DEPENDS ITKImageSources + ITKSmoothing TEST_DEPENDS ITKTestKernel - ITKMetaIO + ITKIOImageBase ITKImageCompose ITKImageIntensity DESCRIPTION diff --git a/test/Baseline/itkVkMultiResolutionPyramidImageFilterTest0.mha.sha512 b/test/Baseline/itkVkMultiResolutionPyramidImageFilterTest0.mha.sha512 new file mode 100644 index 00000000..e20523e5 --- /dev/null +++ b/test/Baseline/itkVkMultiResolutionPyramidImageFilterTest0.mha.sha512 @@ -0,0 +1 @@ +5295f0968db523dce4a2732922d93de936a64e686613451cdd3a2192fb435246a712da42eb3e7454ffc4b04ec3139f14592b704f18aa52f60777855da5a54715 diff --git a/test/Baseline/itkVkMultiResolutionPyramidImageFilterTest1.mha.sha512 b/test/Baseline/itkVkMultiResolutionPyramidImageFilterTest1.mha.sha512 new file mode 100644 index 00000000..bece0a14 --- /dev/null +++ b/test/Baseline/itkVkMultiResolutionPyramidImageFilterTest1.mha.sha512 @@ -0,0 +1 @@ +ec80de62f6dccbd81168095034a17902fd6a2050b512a4274f09c7797de0833d08dad39c8cba2ce6ba9a88aa14e9c6895a9566c53be96d1757d4562e9f112ceb diff --git a/test/Baseline/itkVkMultiResolutionPyramidImageFilterTest2.mha.sha512 b/test/Baseline/itkVkMultiResolutionPyramidImageFilterTest2.mha.sha512 new file mode 100644 index 00000000..3afa2fb1 --- /dev/null +++ b/test/Baseline/itkVkMultiResolutionPyramidImageFilterTest2.mha.sha512 @@ -0,0 +1 @@ +3c8ef474f8e359c6a24d84892aec7a332763d46669bdf6b2084da2d202522f1109de0699954260d9251b29608e8d2ebd69fc23cb1aecbd88dbc380a54005466b diff --git a/test/Baseline/itkVkMultiResolutionPyramidImageFilterTest3.mha.sha512 b/test/Baseline/itkVkMultiResolutionPyramidImageFilterTest3.mha.sha512 new file mode 100644 index 00000000..315df77d --- /dev/null +++ b/test/Baseline/itkVkMultiResolutionPyramidImageFilterTest3.mha.sha512 @@ -0,0 +1 @@ +55b81639dd2bfc5ffc143f04ffa397a6098081ee3d0f1c0c2449cc7e0f3eafd6177a854185cde1a9dc1e841637de3282e697d86778433ff6d2a7fb59ccf227dd diff --git a/test/Baseline/itkVkMultiResolutionPyramidImageFilterTest4.mha.sha512 b/test/Baseline/itkVkMultiResolutionPyramidImageFilterTest4.mha.sha512 new file mode 100644 index 00000000..460903f8 --- /dev/null +++ b/test/Baseline/itkVkMultiResolutionPyramidImageFilterTest4.mha.sha512 @@ -0,0 +1 @@ +05e178cecc133a3a4175a26049b4402c52328cd849543930c991d471c4bdf46b05202af7e3b309a8a62ede7f12896ab423ef1d1de8df9e41b92bf4a71176c980 diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index c86544b0..faac88c1 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -11,6 +11,8 @@ set(VkFFTBackendTests itkVkGlobalConfigurationTest.cxx itkVkHalfHermitianFFTImageFilterTest.cxx itkVkInverse1DFFTImageFilterBaselineTest.cxx + itkVkMultiResolutionPyramidImageFilterTest.cxx + itkVkMultiResolutionPyramidImageFilterFactoryTest.cxx ) include_directories(${VkFFTBackend_INCLUDE_DIRS}) @@ -95,3 +97,35 @@ itk_add_test(NAME itkVkFFTImageFilterFactoryTest itk_add_test(NAME itkVkGlobalConfigurationTest COMMAND VkFFTBackendTestDriver itkVkGlobalConfigurationTest) + +itk_add_test(NAME itkVkMultiResolutionPyramidImageFilterTest + COMMAND VkFFTBackendTestDriver + --compare + DATA{Baseline/itkVkMultiResolutionPyramidImageFilterTest0.mha} + ${ITK_TEST_OUTPUT_DIR}/itkVkMultiResolutionPyramidImageFilterTest0.mha + --compare + DATA{Baseline/itkVkMultiResolutionPyramidImageFilterTest1.mha} + ${ITK_TEST_OUTPUT_DIR}/itkVkMultiResolutionPyramidImageFilterTest1.mha + --compare + DATA{Baseline/itkVkMultiResolutionPyramidImageFilterTest2.mha} + ${ITK_TEST_OUTPUT_DIR}/itkVkMultiResolutionPyramidImageFilterTest2.mha + --compare + DATA{Baseline/itkVkMultiResolutionPyramidImageFilterTest3.mha} + ${ITK_TEST_OUTPUT_DIR}/itkVkMultiResolutionPyramidImageFilterTest3.mha + --compare + DATA{Baseline/itkVkMultiResolutionPyramidImageFilterTest4.mha} + ${ITK_TEST_OUTPUT_DIR}/itkVkMultiResolutionPyramidImageFilterTest4.mha + itkVkMultiResolutionPyramidImageFilterTest + DATA{Input/TreeBarkTexture.png} + ${ITK_TEST_OUTPUT_DIR}/itkVkMultiResolutionPyramidImageFilterTest + 10 # kernelRadiusThreshold dim 0 + 12 # kernelRadiusThreshold dim 1 + 1 # threshold dimension + 0 # useShrinkFilter + 5 # numLevels +) + +itk_add_test(NAME itkVkMultiResolutionPyramidImageFilterFactoryTest + COMMAND VkFFTBackendTestDriver + itkVkMultiResolutionPyramidImageFilterFactoryTest + ) diff --git a/test/itkVkMultiResolutionPyramidImageFilterFactoryTest.cxx b/test/itkVkMultiResolutionPyramidImageFilterFactoryTest.cxx new file mode 100644 index 00000000..d9407aea --- /dev/null +++ b/test/itkVkMultiResolutionPyramidImageFilterFactoryTest.cxx @@ -0,0 +1,56 @@ +/*========================================================================= + * + * Copyright NumFOCUS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + *=========================================================================*/ + +#include +#include + +#include "itkMultiResolutionPyramidImageFilter.h" +#include "itkVkMultiResolutionPyramidImageFilter.h" + +#include "itkVkMultiResolutionPyramidImageFilterFactory.h" +#include "itkTestingMacros.h" + +// Verify MultiResolutionPyramidImageFilter can be overriden +// with spatial+FFT implementation through object factory override + +int +itkVkMultiResolutionPyramidImageFilterFactoryTest(int, char *[]) +{ + using PixelType = double; + constexpr unsigned int Dimension{ 2 }; + using ImageType = itk::Image; + using BaseFilterType = itk::MultiResolutionPyramidImageFilter; + using VkSubclassType = itk::VkMultiResolutionPyramidImageFilter; + + // Verify default is non-accelerated implementation + typename BaseFilterType::Pointer baseFilter = BaseFilterType::New(); + VkSubclassType * derivedFilter = dynamic_cast(baseFilter.GetPointer()); + ITK_TEST_EXPECT_TRUE(derivedFilter == nullptr); + ITK_EXERCISE_BASIC_OBJECT_METHODS(baseFilter, MultiResolutionPyramidImageFilter, ImageToImageFilter); + + // Register factory and verify override + itk::VkMultiResolutionPyramidImageFilterFactory::RegisterOneFactory(); + + baseFilter = BaseFilterType::New(); + derivedFilter = dynamic_cast(baseFilter.GetPointer()); + ITK_TEST_EXPECT_TRUE(derivedFilter != nullptr); + ITK_EXERCISE_BASIC_OBJECT_METHODS( + derivedFilter, VkMultiResolutionPyramidImageFilter, MultiResolutionPyramidImageFilter); + + return EXIT_SUCCESS; +} diff --git a/test/itkVkMultiResolutionPyramidImageFilterTest.cxx b/test/itkVkMultiResolutionPyramidImageFilterTest.cxx new file mode 100644 index 00000000..d1657c1f --- /dev/null +++ b/test/itkVkMultiResolutionPyramidImageFilterTest.cxx @@ -0,0 +1,173 @@ +/*========================================================================= + * + * Copyright NumFOCUS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + *=========================================================================*/ + +#include +#include + +#include "itkVkMultiResolutionPyramidImageFilter.h" +#include "itkImageFileReader.h" +#include "itkImageFileWriter.h" +#include "itkMath.h" +#include "itkTestingMacros.h" + +namespace +{ +// The following three classes are used to support callbacks +// on the filter in the pipeline that follows later +class ShowProgressObject +{ +public: + ShowProgressObject(itk::ProcessObject * o) { m_Process = o; } + void + ShowProgress() + { + std::cout << "Progress " << m_Process->GetProgress() << std::endl; + } + itk::ProcessObject::Pointer m_Process; +}; +} // namespace + +int +itkVkMultiResolutionPyramidImageFilterTest(int argc, char * argv[]) +{ + if (argc < 3) + { + std::cerr << "Missing Parameters." << std::endl; + std::cerr << "Usage: " << itkNameOfTestExecutableMacro(argv); + std::cerr << " inputImage outputImage [threshold1] [kernelThresholdDimension] [useShrinkFilter] " + "[numLevels] [expectedFFTLevelCount]" + << std::endl; + std::cerr << std::flush; + return EXIT_FAILURE; + } + + constexpr unsigned int ImageDimension = 2; + using InputPixelType = float; + using ImageType = itk::Image; + + auto inputImage = itk::ReadImage(argv[1]); + + using PyramidType = itk::VkMultiResolutionPyramidImageFilter; + using ScheduleType = typename PyramidType::ScheduleType; + using KernelSizeType = typename PyramidType::KernelSizeType; + + KernelSizeType kernelRadiusThreshold; + if (argc == 4) + { + kernelRadiusThreshold.Fill(std::atoi(argv[3])); + } + else if (argc > 4) + { + kernelRadiusThreshold[0] = std::atoi(argv[3]); + kernelRadiusThreshold[1] = std::atoi(argv[4]); + } + else + { + kernelRadiusThreshold.Fill(10); + } + + auto kernelThresholdDimension = (argc > 5 ? std::atoi(argv[5]) : 1); + bool useShrinkFilter = (argc > 6 && std::atoi(argv[6]) == 1); + unsigned int numLevels = (argc > 7 ? std::atoi(argv[7]) : 3); + int expectedFFTCount = (argc > 8 ? std::atoi(argv[8]) : -1); // only test if specified + + // Set up multi-resolution pyramid + auto pyramidFilter = PyramidType::New(); + pyramidFilter->SetInput(inputImage); + + pyramidFilter->SetUseShrinkImageFilter(useShrinkFilter); + ITK_TEST_SET_GET_VALUE(pyramidFilter->GetUseShrinkImageFilter(), useShrinkFilter); + + // Verify metric threshold + // Tune performance based on expected image sizes, underlying acceleration hardware, etc + // If a desired threshold value is known then allow the user to set it directly. + // This value can be obtained from user benchmarking applied to their specific use case, + // i.e. specific hardware, expected image sizes, pyramid kernel sizes, etc + pyramidFilter->SetMetricThreshold(5.2f); + ITK_TEST_SET_GET_VALUE(pyramidFilter->GetMetricThreshold(), 5.2f); + + // Tune performance to an expected data profile + // Here we specify that, for any input that is the size of our input image, if a pyramid level + // uses a Gaussian kernel of radius greater than 6 then FFT smoothing should be used, + // otherwise spatial smoothing should be used + pyramidFilter->SetMetricThreshold(inputImage->GetLargestPossibleRegion().GetSize(), { 6, 6 }); + ITK_TEST_EXPECT_TRUE(itk::Math::FloatAlmostEqual(pyramidFilter->GetMetricThreshold(), 5.41497f, 4, 1e-5f)); + + // Use default schedule for testing + pyramidFilter->SetNumberOfLevels(numLevels); + + // Verify kernel variance and radius match expectations for default schedule + KernelSizeType radius, prevRadius; + unsigned int fftCount = 0; + float sizeMetric; + for (unsigned int level = 0; level < numLevels; ++level) + { + auto schedule = pyramidFilter->GetSchedule(); + auto variance = pyramidFilter->GetVariance(level); + radius = pyramidFilter->GetKernelRadius(level); + sizeMetric = pyramidFilter->ComputeMetricValue(inputImage->GetLargestPossibleRegion().GetSize(), radius); + auto useFFT = pyramidFilter->GetUseFFT(radius); + + std::cout << "FFT will " << (useFFT ? "" : "not ") << "be used for level " << level << " with radius " << radius + << " and metric value " << sizeMetric << std::setprecision(3) << std::endl; + if (useFFT) + ++fftCount; + + for (unsigned int dim = 0; dim < ImageDimension; ++dim) + { + // Verify variance output + ITK_TEST_EXPECT_TRUE(itk::Math::AlmostEquals(variance[dim], itk::Math::sqr(0.5 * schedule[level][dim]))); + + // Verify kernel radius output + // Full calculations for default Gaussian size are outside the scope of this test + // so just test that radius decreases with level + if (level > 0) + { + ITK_TEST_EXPECT_TRUE(radius[dim] == 1 || prevRadius[dim] == 1 || radius[dim] < prevRadius[dim]); + } + else + { + prevRadius = radius; + } + } + } + + if (expectedFFTCount != -1) + { + // Test number of levels for FFT smoothing matches expectations + ITK_TEST_EXPECT_EQUAL(fftCount, expectedFFTCount); + } + + ITK_EXERCISE_BASIC_OBJECT_METHODS( + pyramidFilter, VkMultiResolutionPyramidImageFilter, MultiResolutionPyramidImageFilter); + + // Run the filter and track progress + ShowProgressObject progressWatch(pyramidFilter); + itk::SimpleMemberCommand::Pointer command; + command = itk::SimpleMemberCommand::New(); + command->SetCallbackFunction(&progressWatch, &ShowProgressObject::ShowProgress); + pyramidFilter->AddObserver(itk::ProgressEvent(), command); + pyramidFilter->Update(); + + for (unsigned int ilevel = 0; ilevel < numLevels; ++ilevel) + { + itk::WriteImage(pyramidFilter->GetOutput(ilevel), argv[2] + std::to_string(ilevel) + ".mha"); + } + + return EXIT_SUCCESS; +} diff --git a/wrapping/itkVkMultiResolutionPyramidImageFilter.wrap b/wrapping/itkVkMultiResolutionPyramidImageFilter.wrap new file mode 100644 index 00000000..3e0cd79e --- /dev/null +++ b/wrapping/itkVkMultiResolutionPyramidImageFilter.wrap @@ -0,0 +1,3 @@ +itk_wrap_class("itk::VkMultiResolutionPyramidImageFilter" POINTER) + itk_wrap_image_filter("${WRAP_ITK_SCALAR}" 2) +itk_end_wrap_class() diff --git a/wrapping/itkVkMultiResolutionPyramidImageFilterFactory.wrap b/wrapping/itkVkMultiResolutionPyramidImageFilterFactory.wrap new file mode 100644 index 00000000..f0854df4 --- /dev/null +++ b/wrapping/itkVkMultiResolutionPyramidImageFilterFactory.wrap @@ -0,0 +1 @@ +itk_wrap_simple_class("itk::VkMultiResolutionPyramidImageFilterFactory" POINTER)