diff --git a/Modules/Registration/RegistrationMethodsv4/include/itkImageRegistrationMethodv4.h b/Modules/Registration/RegistrationMethodsv4/include/itkImageRegistrationMethodv4.h index ec44510d45f..89c6b4d8d66 100644 --- a/Modules/Registration/RegistrationMethodsv4/include/itkImageRegistrationMethodv4.h +++ b/Modules/Registration/RegistrationMethodsv4/include/itkImageRegistrationMethodv4.h @@ -593,6 +593,42 @@ class ITK_TEMPLATE_EXPORT ImageRegistrationMethodv4 : public ProcessObject { ptr = IdentityTransform::New().GetPointer(); } + + /** Initialize members according to the required number of levels. + * + * Initialize the shrink factors, smoothing sigmas, and metric sampling percentage values. If + * decreasingConsecutiveShrinkFactors is true, the shrink factors will be initialized to decreasing integer values + * starting from the number of levels minus one (e.g. if the number of level is $3$, they will be initialized to the + * set ${2,1,0}$; if false, they will be initialized to all $1$'s. An equivalent logic applies to the smoothing sigma + * values. + */ + void + InitializeToLevels(const SizeValueType numberOfLevels, + const bool decreasingConsecutiveShrinkFactors, + const bool decreasingConsecutiveSmoothingSigmas); + + /** Set the metric sampling random number generator seed to the specified value. */ + void + SetMetricSamplingSeed(int seed); + + /** Clear image masks. */ + void + ClearImageMasks(); + + /** Clear registration entities. + * + * Clears smoothing images and point sets. + */ + void + ClearRegistrationEntities(); + + /** Clear smoothing images. */ + void + ClearSmoothingImages(); + + /** Clear point sets. */ + void + ClearPointSets(); }; } // end namespace itk diff --git a/Modules/Registration/RegistrationMethodsv4/include/itkImageRegistrationMethodv4.hxx b/Modules/Registration/RegistrationMethodsv4/include/itkImageRegistrationMethodv4.hxx index 353e6d5b243..a9a6ab1f7b1 100644 --- a/Modules/Registration/RegistrationMethodsv4/include/itkImageRegistrationMethodv4.hxx +++ b/Modules/Registration/RegistrationMethodsv4/include/itkImageRegistrationMethodv4.hxx @@ -102,25 +102,10 @@ ImageRegistrationMethodv4SetNumberOfLevels(3); - this->m_ShrinkFactorsPerLevel.resize(this->m_NumberOfLevels); - ShrinkFactorsPerDimensionContainerType shrinkFactors; - shrinkFactors.Fill(2); - this->m_ShrinkFactorsPerLevel[0] = shrinkFactors; - shrinkFactors.Fill(1); - this->m_ShrinkFactorsPerLevel[1] = shrinkFactors; - shrinkFactors.Fill(1); - this->m_ShrinkFactorsPerLevel[2] = shrinkFactors; - - this->m_SmoothingSigmasPerLevel.SetSize(this->m_NumberOfLevels); - this->m_SmoothingSigmasPerLevel[0] = 2; - this->m_SmoothingSigmasPerLevel[1] = 1; - this->m_SmoothingSigmasPerLevel[2] = 0; - this->m_SmoothingSigmasAreSpecifiedInPhysicalUnits = true; this->m_ReseedIterator = false; - this->m_RandomSeed = Statistics::MersenneTwisterRandomVariateGenerator::GetNextSeed(); - this->m_CurrentRandomSeed = this->m_RandomSeed; + this->SetMetricSamplingSeed(Statistics::MersenneTwisterRandomVariateGenerator::GetNextSeed()); this->m_MetricSamplingStrategy = MetricSamplingStrategyEnum::NONE; this->m_MetricSamplingPercentagePerLevel.SetSize(this->m_NumberOfLevels); @@ -425,10 +410,7 @@ ImageRegistrationMethodv4m_VirtualDomainImage->Allocate(); } - this->m_FixedImageMasks.clear(); - this->m_FixedImageMasks.resize(this->m_NumberOfMetrics); - this->m_MovingImageMasks.clear(); - this->m_MovingImageMasks.resize(this->m_NumberOfMetrics); + this->ClearImageMasks(); for (SizeValueType n = 0; n < this->m_NumberOfMetrics; ++n) { @@ -595,14 +577,7 @@ ImageRegistrationMethodv4m_FixedSmoothImages.clear(); - this->m_FixedSmoothImages.resize(this->m_NumberOfMetrics); - this->m_MovingSmoothImages.clear(); - this->m_MovingSmoothImages.resize(this->m_NumberOfMetrics); - this->m_FixedPointSets.clear(); - this->m_FixedPointSets.resize(this->m_NumberOfMetrics); - this->m_MovingPointSets.clear(); - this->m_MovingPointSets.resize(this->m_NumberOfMetrics); + this->ClearRegistrationEntities(); for (SizeValueType n = 0; n < this->m_NumberOfMetrics; ++n) { @@ -746,6 +721,43 @@ ImageRegistrationMethodv4 +void +ImageRegistrationMethodv4::ClearImageMasks() +{ + this->m_FixedImageMasks.clear(); + this->m_FixedImageMasks.resize(this->m_NumberOfMetrics); + this->m_MovingImageMasks.clear(); + this->m_MovingImageMasks.resize(this->m_NumberOfMetrics); +} + +template +void +ImageRegistrationMethodv4::ClearRegistrationEntities() +{ + this->ClearPointSets(); + this->ClearSmoothingImages(); +} + +template +void +ImageRegistrationMethodv4::ClearSmoothingImages() +{ + this->m_FixedSmoothImages.clear(); + this->m_FixedSmoothImages.resize(this->m_NumberOfMetrics); + this->m_MovingSmoothImages.clear(); + this->m_MovingSmoothImages.resize(this->m_NumberOfMetrics); +} + +template +void +ImageRegistrationMethodv4::ClearPointSets() +{ + this->m_FixedPointSets.clear(); + this->m_FixedPointSets.resize(this->m_NumberOfMetrics); + this->m_MovingPointSets.clear(); + this->m_MovingPointSets.resize(this->m_NumberOfMetrics); +} template void @@ -907,6 +919,91 @@ ImageRegistrationMethodv4Modified(); } + + this->InitializeToLevels(m_NumberOfLevels, false, false); + this->InitializeToLevels(m_NumberOfLevels, true, true); +} + +template +void +ImageRegistrationMethodv4::InitializeToLevels( + const SizeValueType numberOfLevels, + const bool decreasingConsecutiveShrinkFactors, + const bool decreasingConsecutiveSmoothingSigmas) +{ + this->m_NumberOfLevels = numberOfLevels; + + // Set default transform adaptors which don't do anything to the input transform + // Similarly, fill in some default values for the shrink factors, smoothing sigmas, + // and learning rates. + + + // Check why this is not initialized in the constructor + this->m_TransformParametersAdaptorsPerLevel.clear(); + for (SizeValueType level = 0; level < this->m_NumberOfLevels; ++level) + { + this->m_TransformParametersAdaptorsPerLevel.push_back(nullptr); + } + + this->m_ShrinkFactorsPerLevel.resize(this->m_NumberOfLevels); + if (!decreasingConsecutiveShrinkFactors) + { + auto shrinkFactors = ShrinkFactorsArrayType(this->m_NumberOfLevels); + shrinkFactors.Fill(1); + this->SetShrinkFactorsPerLevel(shrinkFactors); + + // Check against the below + for (SizeValueType level = 0; level < this->m_NumberOfLevels; ++level) + { + ShrinkFactorsPerDimensionContainerType shrinkFactors1; + shrinkFactors1.Fill(1); + this->SetShrinkFactorsPerDimension(level, shrinkFactors1); + } + } + else + { + ShrinkFactorsPerDimensionContainerType shrinkFactors; + shrinkFactors.Fill(m_NumberOfLevels - 1); + this->SetShrinkFactorsPerDimension(0, shrinkFactors); + + auto shrinkFactors2 = ShrinkFactorsArrayType(this->m_NumberOfLevels); + for (SizeValueType level = 1; level < this->m_NumberOfLevels; ++level) + { + shrinkFactors.Fill(1); + this->SetShrinkFactorsPerDimension(level, shrinkFactors); + } + + // Check against the below + ShrinkFactorsPerDimensionContainerType shrinkFactors1; + shrinkFactors1.Fill(2); + this->m_ShrinkFactorsPerLevel[0] = shrinkFactors1; + shrinkFactors1.Fill(1); + this->m_ShrinkFactorsPerLevel[1] = shrinkFactors1; + shrinkFactors1.Fill(1); + this->m_ShrinkFactorsPerLevel[2] = shrinkFactors1; + } + + + this->m_SmoothingSigmasPerLevel.SetSize(this->m_NumberOfLevels); + if (!decreasingConsecutiveSmoothingSigmas) + { + this->m_SmoothingSigmasPerLevel.Fill(1.0); + } + else + { + for (SizeValueType level = 0; level < this->m_NumberOfLevels; ++level) + { + this->m_SmoothingSigmasPerLevel[level] = m_NumberOfLevels - level - 1; + } + + // Check against the below + // this->m_SmoothingSigmasPerLevel[0] = 2; + // this->m_SmoothingSigmasPerLevel[1] = 1; + // this->m_SmoothingSigmasPerLevel[2] = 0; + } + + this->m_MetricSamplingPercentagePerLevel.SetSize(this->m_NumberOfLevels); + this->m_MetricSamplingPercentagePerLevel.Fill(1.0); } /** @@ -1288,12 +1385,19 @@ ImageRegistrationMethodv4SetMetricSamplingSeed(seed); this->Modified(); } } +template +void +ImageRegistrationMethodv4::SetMetricSamplingSeed( + int seed) +{ + m_RandomSeed = seed; + m_CurrentRandomSeed = seed; +} template void