From 97b878c4905c7a3333d7bc963c7cde94c39fa318 Mon Sep 17 00:00:00 2001 From: Samuel Gerber Date: Thu, 11 Mar 2021 14:35:07 -0500 Subject: [PATCH] ENH: Add option to access index in point set registration - Add superclass to PointSetToPointSetMetricv4 to permit index based access in GetLocalNeighborhodXXX methods. - Fix python wrapping. --- ...itkRegularStepGradientDescentOptimizerv4.h | 2 +- .../itkPointSetToPointSetMetricWithIndexv4.h | 484 ++++++++++++ ...itkPointSetToPointSetMetricWithIndexv4.hxx | 697 ++++++++++++++++++ .../include/itkPointSetToPointSetMetricv4.h | 328 ++------- .../include/itkPointSetToPointSetMetricv4.hxx | 653 ---------------- ...ideanDistancePointSetToPointSetMetric.wrap | 2 +- ...tationBasedPointSetToPointSetMetricv4.wrap | 2 +- ...rvatTsallisPointSetToPointSetMetricv4.wrap | 2 +- .../itkLabeledPointSetToPointSetMetricv4.wrap | 2 +- ...tkPointSetToPointSetMetricWithIndexv4.wrap | 12 + .../itkPointSetToPointSetMetricv4.wrap | 12 + .../include/itkImageRegistrationMethodv4.h | 4 +- 12 files changed, 1256 insertions(+), 944 deletions(-) create mode 100644 Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.h create mode 100644 Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx create mode 100644 Modules/Registration/Metricsv4/wrapping/itkPointSetToPointSetMetricWithIndexv4.wrap create mode 100644 Modules/Registration/Metricsv4/wrapping/itkPointSetToPointSetMetricv4.wrap diff --git a/Modules/Numerics/Optimizersv4/include/itkRegularStepGradientDescentOptimizerv4.h b/Modules/Numerics/Optimizersv4/include/itkRegularStepGradientDescentOptimizerv4.h index 8e846a69cb3..e4b88b26436 100644 --- a/Modules/Numerics/Optimizersv4/include/itkRegularStepGradientDescentOptimizerv4.h +++ b/Modules/Numerics/Optimizersv4/include/itkRegularStepGradientDescentOptimizerv4.h @@ -43,7 +43,7 @@ namespace itk * * \ingroup ITKOptimizersv4 */ -template +template class ITK_TEMPLATE_EXPORT RegularStepGradientDescentOptimizerv4 : public GradientDescentOptimizerv4Template { diff --git a/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.h b/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.h new file mode 100644 index 00000000000..3d4dd6f463f --- /dev/null +++ b/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.h @@ -0,0 +1,484 @@ +/*========================================================================= + * + * Copyright NumFOCUS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + *=========================================================================*/ +#ifndef itkPointSetToPointSetMetricWithIndexv4_h +#define itkPointSetToPointSetMetricWithIndexv4_h + +#include "itkObjectToObjectMetric.h" + +#include "itkFixedArray.h" +#include "itkPointsLocator.h" +#include "itkPointSet.h" + +namespace itk +{ +/** \class PointSetToPointSetMetricWithIndexv4 + * \brief Computes similarity between two point sets. + * + * This class is templated over the type of the two point-sets. It + * expects a Transform to be plugged in for each of fixed and moving + * point sets. The transforms default to IdenityTransform types. This particular + * class is the base class for a hierarchy of point-set to point-set metrics. + * + * This class computes a value that measures the similarity between the fixed + * point-set and the moving point-set in the moving domain. The fixed point set + * is transformed into the virtual domain by computing the inverse of the + * fixed transform, then transformed into the moving domain using the + * moving transform. + * + * Since the \c PointSet class permits each \c Point to be associated with a + * \c PixelType, there are potential applications which could make use of + * this additional information. For example, the derived \c LabeledPointSetToPointSetMetric + * class uses the \c PixelType as a \c LabelEnum for estimating total metric values + * and gradients from the individual label-wise point subset metric and derivatives + * + * If a virtual domain is not defined by the user, one of two things happens: + * 1) If the moving transform is a global type, then the virtual domain is + * left undefined and every point is considered to be within the virtual domain. + * 2) If the moving transform is a local-support type, then the virtual domain + * is taken during initialization from the moving transform displacement field, + * and all fixed points are verified to be within the virtual domain after + * transformation by the inverse fixed transform. Points outside the virtual + * domain are not used. See GetNumberOfValidPoints() to verify how many fixed + * points were used during evaluation. + * + * See ObjectToObjectMetric documentation for more discussion on the virtual domain. + * + * \note When used with an RegistrationParameterScalesEstimator estimator, a VirtualDomainPointSet + * must be defined and assigned to the estimator, for use in shift estimation. + * The virtual domain point set can be retrieved from the metric using the + * GetVirtualTransformedPointSet() method. + * + * \ingroup ITKMetricsv4 + */ + +template +class ITK_TEMPLATE_EXPORT PointSetToPointSetMetricWithIndexv4 + : public ObjectToObjectMetric, + TInternalComputationValueType> +{ +public: + ITK_DISALLOW_COPY_AND_MOVE(PointSetToPointSetMetricWithIndexv4); + + /** Standard class type aliases. */ + using Self = PointSetToPointSetMetricWithIndexv4; + using Superclass = ObjectToObjectMetric, + TInternalComputationValueType>; + using Pointer = SmartPointer; + using ConstPointer = SmartPointer; + + /** Run-time type information (and related methods). */ + itkTypeMacro(PointSetToPointSetMetricWithIndexv4, ObjectToObjectMetric); + + /** Type of the measure. */ + using MeasureType = typename Superclass::MeasureType; + + /** Type of the parameters. */ + using ParametersType = typename Superclass::ParametersType; + using ParametersValueType = typename Superclass::ParametersValueType; + using NumberOfParametersType = typename Superclass::NumberOfParametersType; + + /** Type of the derivative. */ + using DerivativeType = typename Superclass::DerivativeType; + + /** Transform types from Superclass*/ + using FixedTransformType = typename Superclass::FixedTransformType; + using FixedTransformPointer = typename Superclass::FixedTransformPointer; + using FixedInputPointType = typename Superclass::FixedInputPointType; + using FixedOutputPointType = typename Superclass::FixedOutputPointType; + using FixedTransformParametersType = typename Superclass::FixedTransformParametersType; + + using MovingTransformType = typename Superclass::MovingTransformType; + using MovingTransformPointer = typename Superclass::MovingTransformPointer; + using MovingInputPointType = typename Superclass::MovingInputPointType; + using MovingOutputPointType = typename Superclass::MovingOutputPointType; + using MovingTransformParametersType = typename Superclass::MovingTransformParametersType; + + using JacobianType = typename Superclass::JacobianType; + using FixedTransformJacobianType = typename Superclass::FixedTransformJacobianType; + using MovingTransformJacobianType = typename Superclass::MovingTransformJacobianType; + + using DisplacementFieldTransformType = typename Superclass::MovingDisplacementFieldTransformType; + + using ObjectType = typename Superclass::ObjectType; + + /** Dimension type */ + using DimensionType = typename Superclass::DimensionType; + + /** Type of the fixed point set. */ + using FixedPointSetType = TFixedPointSet; + using FixedPointType = typename TFixedPointSet::PointType; + using FixedPixelType = typename TFixedPointSet::PixelType; + using FixedPointsContainer = typename TFixedPointSet::PointsContainer; + + static constexpr DimensionType FixedPointDimension = Superclass::FixedDimension; + + /** Type of the moving point set. */ + using MovingPointSetType = TMovingPointSet; + using MovingPointType = typename TMovingPointSet::PointType; + using MovingPixelType = typename TMovingPointSet::PixelType; + using MovingPointsContainer = typename TMovingPointSet::PointsContainer; + + static constexpr DimensionType MovingPointDimension = Superclass::MovingDimension; + + /** + * typedefs for the data types used in the point set metric calculations. + * It is assumed that the constants of the fixed point set, such as the + * point dimension, are the same for the "common space" in which the metric + * calculation occurs. + */ + static constexpr DimensionType PointDimension = Superclass::FixedDimension; + + using PointType = FixedPointType; + using PixelType = FixedPixelType; + using CoordRepType = typename PointType::CoordRepType; + using PointsContainer = FixedPointsContainer; + using PointsConstIterator = typename PointsContainer::ConstIterator; + using PointIdentifier = typename PointsContainer::ElementIdentifier; + + /** Typedef for points locator class to speed up finding neighboring points */ + using PointsLocatorType = PointsLocator; + using NeighborsIdentifierType = typename PointsLocatorType::NeighborsIdentifierType; + + using FixedTransformedPointSetType = PointSet; + using MovingTransformedPointSetType = PointSet; + + using DerivativeValueType = typename DerivativeType::ValueType; + using LocalDerivativeType = FixedArray; + + /** Types for the virtual domain */ + using VirtualImageType = typename Superclass::VirtualImageType; + using VirtualImagePointer = typename Superclass::VirtualImagePointer; + using VirtualPixelType = typename Superclass::VirtualPixelType; + using VirtualRegionType = typename Superclass::VirtualRegionType; + using VirtualSizeType = typename Superclass::VirtualSizeType; + using VirtualSpacingType = typename Superclass::VirtualSpacingType; + using VirtualOriginType = typename Superclass::VirtualPointType; + using VirtualPointType = typename Superclass::VirtualPointType; + using VirtualDirectionType = typename Superclass::VirtualDirectionType; + using VirtualRadiusType = typename Superclass::VirtualSizeType; + using VirtualIndexType = typename Superclass::VirtualIndexType; + using VirtualPointSetType = typename Superclass::VirtualPointSetType; + using VirtualPointSetPointer = typename Superclass::VirtualPointSetPointer; + + /** Set fixed point set*/ + void + SetFixedObject(const ObjectType * object) override + { + auto * pointSet = dynamic_cast(const_cast(object)); + if (pointSet != nullptr) + { + this->SetFixedPointSet(pointSet); + } + else + { + itkExceptionMacro("Incorrect object type. Should be a point set."); + } + } + + /** Set moving point set*/ + void + SetMovingObject(const ObjectType * object) override + { + auto * pointSet = dynamic_cast(const_cast(object)); + if (pointSet != nullptr) + { + this->SetMovingPointSet(pointSet); + } + else + { + itkExceptionMacro("Incorrect object type. Should be a point set."); + } + } + + /** Get/Set the fixed pointset. */ + itkSetConstObjectMacro(FixedPointSet, FixedPointSetType); + itkGetConstObjectMacro(FixedPointSet, FixedPointSetType); + + /** Get the fixed transformed point set. */ + itkGetModifiableObjectMacro(FixedTransformedPointSet, FixedTransformedPointSetType); + + /** Get/Set the moving point set. */ + itkSetConstObjectMacro(MovingPointSet, MovingPointSetType); + itkGetConstObjectMacro(MovingPointSet, MovingPointSetType); + + /** Get the moving transformed point set. */ + itkGetModifiableObjectMacro(MovingTransformedPointSet, MovingTransformedPointSetType); + + /** + * For now return the number of points used in the value/derivative calculations. + */ + SizeValueType + GetNumberOfComponents() const; + + /** + * This method returns the value of the metric based on the current + * transformation(s). This function can be redefined in derived classes + * but many point set metrics follow the same structure---one iterates + * through the points and, for each point a metric value is calculated. + * The summation of these individual point metric values gives the total + * value of the metric. Note that this might not be applicable to all + * point set metrics. For those cases, the developer will have to redefine + * the GetValue() function. + */ + MeasureType + GetValue() const override; + + /** + * This method returns the derivative based on the current + * transformation(s). This function can be redefined in derived classes + * but many point set metrics follow the same structure---one iterates + * through the points and, for each point a derivative is calculated. + * The set of all these local derivatives constitutes the total derivative. + * Note that this might not be applicable to all point set metrics. For + * those cases, the developer will have to redefine the GetDerivative() + * function. + */ + void + GetDerivative(DerivativeType &) const override; + + /** + * This method returns the derivative and value based on the current + * transformation(s). This function can be redefined in derived classes + * but many point set metrics follow the same structure---one iterates + * through the points and, for each point a derivative and value is calculated. + * The set of all these local derivatives/values constitutes the total + * derivative and value. Note that this might not be applicable to all + * point set metrics. For those cases, the developer will have to redefine + * the GetValue() and GetDerivative() functions. + */ + void + GetValueAndDerivative(MeasureType &, DerivativeType &) const override; + + /** + * Get the virtual point set, derived from the fixed point set. + * If the virtual point set has not yet been derived, it will be + * in this call. */ + const VirtualPointSetType * + GetVirtualTransformedPointSet() const; + + /** + * Initialize the metric by making sure that all the components + * are present and plugged together correctly. + */ + void + Initialize() override; + + bool + SupportsArbitraryVirtualDomainSamples() const override + { + /* An arbitrary point in the virtual domain will not always + * correspond to a point within either point set. */ + return false; + } + + /** + * By default, the point set metric derivative for a displacement field transform + * is stored by saving the gradient for every voxel in the displacement field (see + * the function StorePointDerivative()). Since the "fixed points" will typically + * constitute a sparse set, this means that the field will have zero gradient values + * at every voxel that doesn't have a corresponding point. This might cause additional + * computation time for certain transforms (e.g. B-spline SyN). To avoid this, this + * option permits storing the point derivative only at the fixed point locations. + * If this variable is set to false, then the derivative array will be of length + * = PointDimension * m_FixedPointSet->GetNumberOfPoints(). + */ + itkSetMacro(StoreDerivativeAsSparseFieldForLocalSupportTransforms, bool); + itkGetConstMacro(StoreDerivativeAsSparseFieldForLocalSupportTransforms, bool); + itkBooleanMacro(StoreDerivativeAsSparseFieldForLocalSupportTransforms); + + /** + * + */ + itkSetMacro(CalculateValueAndDerivativeInTangentSpace, bool); + itkGetConstMacro(CalculateValueAndDerivativeInTangentSpace, bool); + itkBooleanMacro(CalculateValueAndDerivativeInTangentSpace); + +protected: + PointSetToPointSetMetricWithIndexv4(); + ~PointSetToPointSetMetricWithIndexv4() override = default; + void + PrintSelf(std::ostream & os, Indent indent) const override; + + typename FixedPointSetType::ConstPointer m_FixedPointSet; + mutable typename FixedTransformedPointSetType::Pointer m_FixedTransformedPointSet; + + mutable typename PointsLocatorType::Pointer m_FixedTransformedPointsLocator; + + typename MovingPointSetType::ConstPointer m_MovingPointSet; + mutable typename MovingTransformedPointSetType::Pointer m_MovingTransformedPointSet; + + mutable typename PointsLocatorType::Pointer m_MovingTransformedPointsLocator; + + /** Holds the fixed points after transformation into virtual domain. */ + mutable VirtualPointSetPointer m_VirtualTransformedPointSet; + + /** + * Bool set by derived classes on whether the point set data (i.e. \c PixelType) + * should be used. Default = false. + */ + bool m_UsePointSetData; + + /** + * Flag to calculate value and/or derivative at tangent space. This is needed + * for the diffeomorphic registration methods. The fixed and moving points are + * warped to the virtual domain where the metric is calculated. Derived point + * set metrics might have associated gradient information which will need to be + * warped if this flag is true. Default = false. + */ + bool m_CalculateValueAndDerivativeInTangentSpace; + + /** + * Prepare point sets for use. */ + virtual void + InitializePointSets() const; + + /** + * Initialize to prepare for a particular iteration, generally + * an iteration of optimization. Distinct from Initialize() + * which is a one-time initialization. */ + virtual void + InitializeForIteration() const; + + /** + * Determine the number of valid fixed points. A fixed point + * is valid if, when transformed into the virtual domain using + * the inverse of the FixedTransform, it is within the defined + * virtual domain bounds. */ + virtual SizeValueType + CalculateNumberOfValidFixedPoints() const; + + /** Helper method allows for code reuse while skipping the metric value + * calculation when appropriate */ + void + CalculateValueAndDerivative(MeasureType & calculatedValue, DerivativeType & derivative, bool calculateValue) const; + + /** + * Warp the fixed point set into the moving domain based on the fixed transform, + * passing through the virtual domain and storing a virtual domain set. + * Note that the warped moving point set is of type FixedPointSetType since the transform + * takes the points from the fixed to the moving domain. + */ + void + TransformFixedAndCreateVirtualPointSet() const; + + /** + * Warp the moving point set based on the moving transform. Note that the + * warped moving point set is of type FixedPointSetType since the transform + * takes the points from the moving to the fixed domain. + * FIXME: needs update. + */ + void + TransformMovingPointSet() const; + + /** + * Build point locators for the fixed and moving point sets to speed up + * derivative and value calculations. + */ + void + InitializePointsLocators() const; + + /** + * Store a derivative from a single point in a field. + * Only relevant when active transform has local support. + */ + void + StorePointDerivative(const VirtualPointType &, const DerivativeType &, DerivativeType &) const; + + using MetricCategoryType = typename Superclass::MetricCategoryType; + + /** Get metric category */ + MetricCategoryType + GetMetricCategory() const override + { + return MetricCategoryType::POINT_SET_METRIC; + } + + virtual bool + RequiresMovingPointsLocator() const + { + return true; + }; + + virtual bool + RequiresFixedPointsLocator() const + { + return true; + }; + + /** + * Function to be defined in the appropriate derived classes. Calculates + * the local metric value for a single point. The \c PixelType may or + * may not be used. See class description for further explanation. + */ + virtual MeasureType + GetLocalNeighborhoodValueWithIndex(const PointIdentifier &, const PointType &, const PixelType & pixel) const = 0; + /** + * Function to be defined in the appropriate derived classes. Calculates + * the local metric value for a single point. The \c PixelType may or + * may not be used. See class description for further explanation. + * Default implementation calls GetLocalNeighborhoodValueAndDerivative. + */ + virtual LocalDerivativeType + GetLocalNeighborhoodDerivativeWithIndex(const PointIdentifier &, const PointType &, const PixelType & pixel) const; + + /** + * Function to be defined in the appropriate derived classes. Calculates + * the local metric value for a single point. The \c PixelType may or + * may not be used. See class description for further explanation. + */ + virtual void + GetLocalNeighborhoodValueAndDerivativeWithIndex(const PointIdentifier &, + const PointType &, + MeasureType &, + LocalDerivativeType &, + const PixelType & pixel) const = 0; + +private: + mutable bool m_MovingTransformPointLocatorsNeedInitialization; + mutable bool m_FixedTransformPointLocatorsNeedInitialization; + + // Flag to keep track of whether a warning has already been issued + // regarding the number of valid points. + mutable bool m_HaveWarnedAboutNumberOfValidPoints; + + // Flag to store derivatives at fixed point locations with the rest being zero gradient + // (default = true). + bool m_StoreDerivativeAsSparseFieldForLocalSupportTransforms; + + mutable ModifiedTimeType m_MovingTransformedPointSetTime; + mutable ModifiedTimeType m_FixedTransformedPointSetTime; + + // Create ranges over the point set for multithreaded computation of value and derivatives + using PointIdentifierPair = std::pair; + using PointIdentifierRanges = std::vector; + const PointIdentifierRanges + CreateRanges() const; +}; +} // end namespace itk + +#ifndef ITK_MANUAL_INSTANTIATION +# include "itkPointSetToPointSetMetricWithIndexv4.hxx" +#endif + +#endif diff --git a/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx b/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx new file mode 100644 index 00000000000..6208ad559fa --- /dev/null +++ b/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricWithIndexv4.hxx @@ -0,0 +1,697 @@ +/*========================================================================= + * + * Copyright NumFOCUS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + *=========================================================================*/ +#ifndef itkPointSetToPointSetMetricWithIndexv4_hxx +#define itkPointSetToPointSetMetricWithIndexv4_hxx + +#include "itkPointSetToPointSetMetricWithIndexv4.h" +#include "itkIdentityTransform.h" +#include "itkCompensatedSummation.h" + +namespace itk +{ + +/** Constructor */ +template +PointSetToPointSetMetricWithIndexv4:: + PointSetToPointSetMetricWithIndexv4() +{ + this->m_FixedPointSet = nullptr; // has to be provided by the user. + this->m_MovingPointSet = nullptr; // has to be provided by the user. + + this->m_FixedTransformedPointSet = nullptr; + this->m_MovingTransformedPointSet = nullptr; + this->m_VirtualTransformedPointSet = nullptr; + + this->m_FixedTransformedPointsLocator = nullptr; + this->m_MovingTransformedPointsLocator = nullptr; + + this->m_MovingTransformPointLocatorsNeedInitialization = false; + this->m_FixedTransformPointLocatorsNeedInitialization = false; + + this->m_MovingTransformedPointSetTime = this->GetMTime(); + this->m_FixedTransformedPointSetTime = this->GetMTime(); + + // We iterate over the fixed points to calculate the value and derivative. + this->SetGradientSource(ObjectToObjectMetricBaseTemplateEnums::GradientSource::GRADIENT_SOURCE_FIXED); + + this->m_HaveWarnedAboutNumberOfValidPoints = false; + + this->m_UsePointSetData = false; + + this->m_StoreDerivativeAsSparseFieldForLocalSupportTransforms = true; + + this->m_CalculateValueAndDerivativeInTangentSpace = false; +} + +/** Initialize the metric */ +template +void +PointSetToPointSetMetricWithIndexv4::Initialize() +{ + if (!this->m_FixedPointSet) + { + itkExceptionMacro("Fixed point set is not present"); + } + + if (!this->m_MovingPointSet) + { + itkExceptionMacro("Moving point set is not present"); + } + + // We don't know how to support gradient source of type moving + if (this->GetGradientSourceIncludesMoving()) + { + itkExceptionMacro("GradientSource includes GRADIENT_SOURCE_MOVING. Not supported."); + } + + // If the PointSet is provided by a source, update the source. + if (this->m_MovingPointSet->GetSource()) + { + this->m_MovingPointSet->GetSource()->Update(); + } + + // If the point set is provided by a source, update the source. + if (this->m_FixedPointSet->GetSource()) + { + this->m_FixedPointSet->GetSource()->Update(); + } + + // Check for virtual domain if needed. + // With local-support transforms we need a virtual domain in + // order to properly store the per-point derivatives. + // This will create a virtual domain that matches the DisplacementFieldTransform. + // If the virutal domain has already been set, it will + // be verified against the transform in Superclass::Initialize. + if (this->HasLocalSupport()) + { + if (!this->m_UserHasSetVirtualDomain) + { + const typename DisplacementFieldTransformType::ConstPointer displacementTransform = + this->GetMovingDisplacementFieldTransform(); + if (displacementTransform.IsNull()) + { + itkExceptionMacro("Expected the moving transform to be of type DisplacementFieldTransform or derived, " + "or a CompositeTransform with DisplacementFieldTransform as the last to have been added."); + } + using DisplacementFieldType = typename DisplacementFieldTransformType::DisplacementFieldType; + typename DisplacementFieldType::ConstPointer field = displacementTransform->GetDisplacementField(); + this->SetVirtualDomain( + field->GetSpacing(), field->GetOrigin(), field->GetDirection(), field->GetBufferedRegion()); + } + } + + // Superclass initialization. Do after checking for virtual domain. + Superclass::Initialize(); + + // Call this now for derived classes that need + // a member to be initialized during Initialize(). + this->InitializePointSets(); +} + +template +void +PointSetToPointSetMetricWithIndexv4:: + InitializePointSets() const +{ + this->TransformMovingPointSet(); + this->TransformFixedAndCreateVirtualPointSet(); + this->InitializePointsLocators(); +} + +template +void +PointSetToPointSetMetricWithIndexv4:: + InitializeForIteration() const +{ + this->InitializePointSets(); + this->m_NumberOfValidPoints = this->CalculateNumberOfValidFixedPoints(); + if (this->m_NumberOfValidPoints < this->GetNumberOfComponents() && !this->m_HaveWarnedAboutNumberOfValidPoints) + { + itkWarningMacro("Only " << this->m_NumberOfValidPoints << " of " << this->GetNumberOfComponents() + << " points are within the virtual domain, and will be used in the evaluation."); + this->m_HaveWarnedAboutNumberOfValidPoints = true; + } +} + +template +SizeValueType +PointSetToPointSetMetricWithIndexv4:: + GetNumberOfComponents() const +{ + return this->m_FixedTransformedPointSet->GetNumberOfPoints(); +} + +template +typename PointSetToPointSetMetricWithIndexv4:: + MeasureType + PointSetToPointSetMetricWithIndexv4::GetValue() const +{ + this->InitializeForIteration(); + + + // Virtual point set will be the same size as fixed point set as long as it's + // generated from the fixed point set. + if (this->m_VirtualTransformedPointSet->GetNumberOfPoints() != this->m_FixedTransformedPointSet->GetNumberOfPoints()) + { + itkExceptionMacro("Expected FixedTransformedPointSet to be the same size as VirtualTransformedPointSet."); + } + /* + * Split pointset in nWorkUnit ranges and sum individually + * This splitting is required in order to avoid having the threads + * repeatedly write to same location causing false sharing + */ + // Use STL container to make sure no unesecarry checks are performed + using FixedTransformedVectorContainer = typename FixedPointsContainer::STLContainerType; + using VirtualPointsContainer = typename VirtualPointSetType::PointsContainer; + using VirtualVectorContainer = typename VirtualPointsContainer::STLContainerType; + const VirtualVectorContainer & virtualTransformedPointSet = + this->m_VirtualTransformedPointSet->GetPoints()->CastToSTLConstContainer(); + const FixedTransformedVectorContainer & fixedTransformedPointSet = + this->m_FixedTransformedPointSet->GetPoints()->CastToSTLConstContainer(); + + PointIdentifierRanges ranges = this->CreateRanges(); + std::vector> threadValues(ranges.size()); + std::function sumNeighborhoodValues = + [this, &threadValues, &ranges, &virtualTransformedPointSet, &fixedTransformedPointSet](SizeValueType rangeIndex) { + CompensatedSummation threadValue = 0; + PixelType pixel; + NumericTraits::SetLength(pixel, 1); + + + for (PointIdentifier index = ranges[rangeIndex].first; index < ranges[rangeIndex].second; index++) + { + if (this->IsInsideVirtualDomain(virtualTransformedPointSet[index])) + { + if (this->m_UsePointSetData) + { + bool doesPointDataExist = this->m_FixedPointSet->GetPointData(index, &pixel); + if (!doesPointDataExist) + { + itkExceptionMacro("The corresponding data for point (pointId = " << index << ") does not exist."); + } + } + threadValue += this->GetLocalNeighborhoodValueWithIndex(index, fixedTransformedPointSet[index], pixel); + } + } + threadValues[rangeIndex] = threadValue; + }; + + // Sum per thread + MultiThreaderBase::New()->ParallelizeArray( + (SizeValueType)0, (SizeValueType)ranges.size(), sumNeighborhoodValues, nullptr); + // Join sums + CompensatedSummation value = 0; + for (unsigned int i = 0; i < threadValues.size(); i++) + { + value += threadValues[i]; + } + + DerivativeType derivative; + MeasureType valueSum = value.GetSum(); + if (this->VerifyNumberOfValidPoints(valueSum, derivative)) + { + valueSum /= static_cast(this->m_NumberOfValidPoints); + } + this->m_Value = valueSum; + + return valueSum; +} + +template +void +PointSetToPointSetMetricWithIndexv4::GetDerivative( + DerivativeType & derivative) const +{ + MeasureType value = NumericTraits::ZeroValue(); + this->CalculateValueAndDerivative(value, derivative, false); +} + +template +void +PointSetToPointSetMetricWithIndexv4:: + GetValueAndDerivative(MeasureType & value, DerivativeType & derivative) const +{ + this->CalculateValueAndDerivative(value, derivative, true); +} + +template +void +PointSetToPointSetMetricWithIndexv4:: + CalculateValueAndDerivative(MeasureType & calculatedValue, DerivativeType & derivative, bool calculateValue) const +{ + this->InitializeForIteration(); + + // Virtual point set will be the same size as fixed point set as long as it's + // generated from the fixed point set. + if (this->m_VirtualTransformedPointSet->GetNumberOfPoints() != this->m_FixedTransformedPointSet->GetNumberOfPoints()) + { + itkExceptionMacro("Expected FixedTransformedPointSet to be the same size as VirtualTransformedPointSet."); + } + + derivative.SetSize(this->GetNumberOfParameters()); + if (!this->GetStoreDerivativeAsSparseFieldForLocalSupportTransforms()) + { + derivative.SetSize(PointDimension * this->m_FixedTransformedPointSet->GetNumberOfPoints()); + } + derivative.Fill(NumericTraits::ZeroValue()); + + /* + * Split pointset in nWorkUnits ranges and sum individually + * This splitting is required in order to avoid having the threads + * repeatedly write to same location causing false sharing + */ + // GetNumberOfLocalParameters is not trhead safe in itkCompositeTransform + NumberOfParametersType numberOfLocalParameters = this->GetNumberOfLocalParameters(); + PointIdentifierRanges ranges = this->CreateRanges(); + std::vector> threadValues(ranges.size()); + using CompensatedDerivative = typename std::vector>; + std::vector threadDerivatives(ranges.size()); + std::function sumNeighborhoodValues = + [this, &derivative, &threadDerivatives, &threadValues, &ranges, &calculateValue, &numberOfLocalParameters]( + SizeValueType rangeIndex) { + // Use STL container to make sure no unesecarry checks are performed + using FixedTransformedVectorContainer = typename FixedPointsContainer::STLContainerType; + using VirtualPointsContainer = typename VirtualPointSetType::PointsContainer; + using VirtualVectorContainer = typename VirtualPointsContainer::STLContainerType; + const VirtualVectorContainer & virtualTransformedPointSet = + this->m_VirtualTransformedPointSet->GetPoints()->CastToSTLConstContainer(); + const FixedTransformedVectorContainer & fixedTransformedPointSet = + this->m_FixedTransformedPointSet->GetPoints()->CastToSTLConstContainer(); + + MovingTransformJacobianType jacobian(MovingPointDimension, numberOfLocalParameters); + MovingTransformJacobianType jacobianCache; + + DerivativeType threadLocalTransformDerivative(numberOfLocalParameters); + threadLocalTransformDerivative.Fill(NumericTraits::ZeroValue()); + + CompensatedDerivative threadDerivativeSum(numberOfLocalParameters); + + CompensatedSummation threadValue; + PixelType pixel; + NumericTraits::SetLength(pixel, 1); + for (PointIdentifier index = ranges[rangeIndex].first; index < ranges[rangeIndex].second; index++) + { + MeasureType pointValue = NumericTraits::ZeroValue(); + LocalDerivativeType pointDerivative; + + /* Verify the virtual point is in the virtual domain. + * If user hasn't defined a virtual space, and the active transform is not + * a displacement field transform type, then this will always return true. */ + if (!this->IsInsideVirtualDomain(virtualTransformedPointSet[index])) + { + continue; + } + + if (this->m_UsePointSetData) + { + bool doesPointDataExist = this->m_FixedPointSet->GetPointData(index, &pixel); + if (!doesPointDataExist) + { + itkExceptionMacro("The corresponding data for point with id " << index << " does not exist."); + } + } + + if (calculateValue) + { + this->GetLocalNeighborhoodValueAndDerivativeWithIndex( + index, fixedTransformedPointSet[index], pointValue, pointDerivative, pixel); + threadValue += pointValue; + } + else + { + pointDerivative = + this->GetLocalNeighborhoodDerivativeWithIndex(index, fixedTransformedPointSet[index], pixel); + } + + // Map into parameter space + threadLocalTransformDerivative.Fill(NumericTraits::ZeroValue()); + + if (this->m_CalculateValueAndDerivativeInTangentSpace) + { + for (DimensionType d = 0; d < PointDimension; ++d) + { + threadLocalTransformDerivative[d] += pointDerivative[d]; + } + } + else + { + this->GetMovingTransform()->ComputeJacobianWithRespectToParametersCachedTemporaries( + virtualTransformedPointSet[index], jacobian, jacobianCache); + + for (NumberOfParametersType par = 0; par < numberOfLocalParameters; par++) + { + for (DimensionType d = 0; d < PointDimension; ++d) + { + threadLocalTransformDerivative[par] += jacobian(d, par) * pointDerivative[d]; + } + } + } + // For local-support transforms, store the per-point result + if (this->HasLocalSupport() || this->m_CalculateValueAndDerivativeInTangentSpace) + { + if (this->GetStoreDerivativeAsSparseFieldForLocalSupportTransforms()) + { + this->StorePointDerivative(virtualTransformedPointSet[index], threadLocalTransformDerivative, derivative); + } + else + { + for (NumberOfParametersType par = 0; par < numberOfLocalParameters; par++) + { + derivative[this->GetNumberOfLocalParameters() * index + par] = threadLocalTransformDerivative[par]; + } + } + } + for (NumberOfParametersType par = 0; par < numberOfLocalParameters; par++) + { + threadDerivativeSum[par] += threadLocalTransformDerivative[par]; + } + } + threadValues[rangeIndex] = threadValue; + threadDerivatives[rangeIndex] = threadDerivativeSum; + }; + + // Sum per thread + MultiThreaderBase::New()->ParallelizeArray( + (SizeValueType)0, (SizeValueType)ranges.size(), sumNeighborhoodValues, nullptr); + + // Sum thread results + CompensatedSummation value = 0; + for (unsigned int i = 0; i < threadValues.size(); i++) + { + value += threadValues[i]; + } + MeasureType valueSum = value.GetSum(); + + if (this->VerifyNumberOfValidPoints(valueSum, derivative)) + { + // For global-support transforms, average the accumulated derivative result + if (!this->HasLocalSupport() && !this->m_CalculateValueAndDerivativeInTangentSpace) + { + CompensatedDerivative localTransformDerivative(numberOfLocalParameters); + for (unsigned int i = 0; i < threadDerivatives.size(); i++) + { + for (NumberOfParametersType par = 0; par < numberOfLocalParameters; par++) + { + localTransformDerivative[par] += threadDerivatives[i][par]; + } + } + derivative.SetSize(numberOfLocalParameters); + for (NumberOfParametersType par = 0; par < numberOfLocalParameters; par++) + { + derivative[par] = + localTransformDerivative[par].GetSum() / static_cast(this->m_NumberOfValidPoints); + } + } + valueSum /= static_cast(this->m_NumberOfValidPoints); + } + calculatedValue = valueSum; + this->m_Value = valueSum; +} + +template +SizeValueType +PointSetToPointSetMetricWithIndexv4:: + CalculateNumberOfValidFixedPoints() const +{ + // Determine the number of valid fixed points, using + // their positions in the virtual domain. + SizeValueType numberOfValidPoints = NumericTraits::ZeroValue(); + PointsConstIterator virtualIt = this->m_VirtualTransformedPointSet->GetPoints()->Begin(); + while (virtualIt != this->m_VirtualTransformedPointSet->GetPoints()->End()) + { + if (this->IsInsideVirtualDomain(virtualIt.Value())) + { + ++numberOfValidPoints; + } + ++virtualIt; + } + return numberOfValidPoints; +} + +template +void +PointSetToPointSetMetricWithIndexv4:: + StorePointDerivative(const VirtualPointType & virtualPoint, + const DerivativeType & pointDerivative, + DerivativeType & field) const +{ + // Update derivative field at some index. + // This requires the active transform displacement field to be the + // same size as virtual domain, and that VirtualImage PixelType + // is scalar (both of which are verified during Metric initialization). + try + { + OffsetValueType offset = + this->ComputeParameterOffsetFromVirtualPoint(virtualPoint, this->GetNumberOfLocalParameters()); + for (NumberOfParametersType i = 0; i < this->GetNumberOfLocalParameters(); i++) + { + /* Be sure to *add* here and not assign. Required for proper behavior + * with multi-variate metric. */ + field[offset + i] += pointDerivative[i]; + } + } + catch (ExceptionObject & exc) + { + std::string msg("Caught exception: \n"); + msg += exc.what(); + ExceptionObject err(__FILE__, __LINE__, msg); + throw err; + } +} + +template +typename PointSetToPointSetMetricWithIndexv4:: + LocalDerivativeType + PointSetToPointSetMetricWithIndexv4:: + GetLocalNeighborhoodDerivativeWithIndex(const PointIdentifier & id, + const PointType & point, + const PixelType & pixel) const +{ + MeasureType measure; + LocalDerivativeType localDerivative; + this->GetLocalNeighborhoodValueAndDerivativeWithIndex(id, point, measure, localDerivative, pixel); + return localDerivative; +} + +template +void +PointSetToPointSetMetricWithIndexv4:: + TransformMovingPointSet() const +{ + // Transform the moving point set with the moving transform. + // We calculate the value and derivatives in the moving space. + bool update = !this->m_MovingTransformedPointSet; + update = update || this->m_MovingTransformedPointSetTime < this->GetMTime(); + update = update || (this->m_CalculateValueAndDerivativeInTangentSpace && + (this->m_MovingTransform->GetMTime() > this->m_MovingTransformedPointSetTime)); + if (update) + { + this->m_MovingTransformPointLocatorsNeedInitialization = true; + this->m_MovingTransformedPointSet = MovingTransformedPointSetType::New(); + this->m_MovingTransformedPointSet->Initialize(); + + typename MovingTransformType::InverseTransformBasePointer inverseTransform = + this->m_MovingTransform->GetInverseTransform(); + + typename MovingPointsContainer::ConstIterator It = this->m_MovingPointSet->GetPoints()->Begin(); + while (It != this->m_MovingPointSet->GetPoints()->End()) + { + if (this->m_CalculateValueAndDerivativeInTangentSpace) + { + PointType point = inverseTransform->TransformPoint(It.Value()); + this->m_MovingTransformedPointSet->SetPoint(It.Index(), point); + } + else + { + // evaluation is performed in moving space, so just copy + this->m_MovingTransformedPointSet->SetPoint(It.Index(), It.Value()); + } + ++It; + } + this->m_MovingTransformedPointSetTime = this->GetMTime(); + if (!this->m_CalculateValueAndDerivativeInTangentSpace) + { + this->m_MovingTransformedPointSetTime = + std::max(this->m_MovingTransformedPointSetTime, this->m_MovingTransform->GetMTime()); + } + } +} + +template +void +PointSetToPointSetMetricWithIndexv4:: + TransformFixedAndCreateVirtualPointSet() const +{ + // Transform the fixed point set through the virtual domain, and into the moving domain + bool update = !this->m_FixedTransformedPointSet || !this->m_VirtualTransformedPointSet; + update = update || this->m_FixedTransformedPointSetTime < this->GetMTime(); + update = update || (this->m_CalculateValueAndDerivativeInTangentSpace && + (this->m_FixedTransform->GetMTime() > this->m_FixedTransformedPointSetTime)); + update = update || (!this->m_CalculateValueAndDerivativeInTangentSpace && + ((this->m_FixedTransform->GetMTime() > this->m_FixedTransformedPointSetTime) || + (this->m_MovingTransform->GetMTime() > this->m_FixedTransformedPointSetTime))); + if (update) + { + this->m_FixedTransformPointLocatorsNeedInitialization = true; + this->m_FixedTransformedPointSet = FixedTransformedPointSetType::New(); + this->m_FixedTransformedPointSet->Initialize(); + this->m_VirtualTransformedPointSet = VirtualPointSetType::New(); + this->m_VirtualTransformedPointSet->Initialize(); + + using InverseTransformBasePointer = typename FixedTransformType::InverseTransformBasePointer; + InverseTransformBasePointer inverseTransform = this->m_FixedTransform->GetInverseTransform(); + + typename FixedPointsContainer::ConstIterator It = this->m_FixedPointSet->GetPoints()->Begin(); + while (It != this->m_FixedPointSet->GetPoints()->End()) + { + if (this->m_CalculateValueAndDerivativeInTangentSpace) + { + // txf into virtual space + PointType point = inverseTransform->TransformPoint(It.Value()); + this->m_VirtualTransformedPointSet->SetPoint(It.Index(), point); + this->m_FixedTransformedPointSet->SetPoint(It.Index(), point); + } + else + { + // txf into virtual space + PointType point = inverseTransform->TransformPoint(It.Value()); + this->m_VirtualTransformedPointSet->SetPoint(It.Index(), point); + // txf into moving space + point = this->m_MovingTransform->TransformPoint(point); + this->m_FixedTransformedPointSet->SetPoint(It.Index(), point); + } + ++It; + } + this->m_FixedTransformedPointSetTime = std::max(this->GetMTime(), this->m_FixedTransform->GetMTime()); + if (!this->m_CalculateValueAndDerivativeInTangentSpace) + { + this->m_FixedTransformedPointSetTime = + std::max(this->m_FixedTransformedPointSetTime, this->m_MovingTransform->GetMTime()); + } + } +} + +template +const typename PointSetToPointSetMetricWithIndexv4:: + VirtualPointSetType * + PointSetToPointSetMetricWithIndexv4:: + GetVirtualTransformedPointSet() const +{ + // First make sure the virtual point set is current. + this->TransformFixedAndCreateVirtualPointSet(); + return this->m_VirtualTransformedPointSet.GetPointer(); +} + +template +void +PointSetToPointSetMetricWithIndexv4:: + InitializePointsLocators() const +{ + if (this->RequiresFixedPointsLocator() && this->m_FixedTransformPointLocatorsNeedInitialization) + { + if (!this->m_FixedTransformedPointSet) + { + itkExceptionMacro("The fixed transformed point set does not exist."); + } + if (!this->m_FixedTransformedPointsLocator) + { + this->m_FixedTransformedPointsLocator = PointsLocatorType::New(); + } + this->m_FixedTransformedPointsLocator->SetPoints(this->m_FixedTransformedPointSet->GetPoints()); + this->m_FixedTransformedPointsLocator->Initialize(); + this->m_FixedTransformPointLocatorsNeedInitialization = false; + } + + if (this->RequiresMovingPointsLocator() && this->m_MovingTransformPointLocatorsNeedInitialization) + { + if (!this->m_MovingTransformedPointSet) + { + itkExceptionMacro("The moving transformed point set does not exist."); + } + if (!this->m_MovingTransformedPointsLocator) + { + this->m_MovingTransformedPointsLocator = PointsLocatorType::New(); + } + this->m_MovingTransformedPointsLocator->SetPoints(this->m_MovingTransformedPointSet->GetPoints()); + this->m_MovingTransformedPointsLocator->Initialize(); + this->m_MovingTransformPointLocatorsNeedInitialization = false; + } +} + +template +const typename PointSetToPointSetMetricWithIndexv4:: + PointIdentifierRanges + PointSetToPointSetMetricWithIndexv4::CreateRanges() + const +{ + PointIdentifier nPoints = this->m_FixedTransformedPointSet->GetNumberOfPoints(); + PointIdentifier nWorkUnits = MultiThreaderBase::New()->GetNumberOfWorkUnits(); + if (nWorkUnits > nPoints || MultiThreaderBase::New()->GetMaximumNumberOfThreads() <= 1) + { + nWorkUnits = 1; + } + PointIdentifier startRange = 0; + PointIdentifierRanges ranges; + for (PointIdentifier p = 1; p < nWorkUnits; ++p) + { + PointIdentifier endRange = (p * nPoints) / (double)nWorkUnits; + ranges.push_back(PointIdentifierPair(startRange, endRange)); + startRange = endRange; + } + ranges.push_back(PointIdentifierPair(startRange, nPoints)); + + return ranges; +} + + +/** PrintSelf */ +template +void +PointSetToPointSetMetricWithIndexv4::PrintSelf( + std::ostream & os, + Indent indent) const +{ + Superclass::PrintSelf(os, indent); + os << indent << "Fixed PointSet: " << this->m_FixedPointSet.GetPointer() << std::endl; + os << indent << "Fixed Transform: " << this->m_FixedTransform.GetPointer() << std::endl; + os << indent << "Moving PointSet: " << this->m_MovingPointSet.GetPointer() << std::endl; + os << indent << "Moving Transform: " << this->m_MovingTransform.GetPointer() << std::endl; + + os << indent << "Store derivative as sparse field = "; + if (this->m_StoreDerivativeAsSparseFieldForLocalSupportTransforms) + { + os << "true." << std::endl; + } + else + { + os << "false." << std::endl; + } + + os << indent << "Calculate in tangent space = "; + if (this->m_CalculateValueAndDerivativeInTangentSpace) + { + os << "true." << std::endl; + } + else + { + os << "false." << std::endl; + } +} +} // end namespace itk + +#endif diff --git a/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricv4.h b/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricv4.h index a462254f9b3..f7812a735a6 100644 --- a/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricv4.h +++ b/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricv4.h @@ -18,11 +18,7 @@ #ifndef itkPointSetToPointSetMetricv4_h #define itkPointSetToPointSetMetricv4_h -#include "itkObjectToObjectMetric.h" - -#include "itkFixedArray.h" -#include "itkPointsLocator.h" -#include "itkPointSet.h" +#include "itkPointSetToPointSetMetricWithIndexv4.h" namespace itk { @@ -66,27 +62,24 @@ namespace itk * \ingroup ITKMetricsv4 */ -template +template class ITK_TEMPLATE_EXPORT PointSetToPointSetMetricv4 - : public ObjectToObjectMetric, - TInternalComputationValueType> + : public PointSetToPointSetMetricWithIndexv4 { public: ITK_DISALLOW_COPY_AND_MOVE(PointSetToPointSetMetricv4); /** Standard class type aliases. */ using Self = PointSetToPointSetMetricv4; - using Superclass = ObjectToObjectMetric, - TInternalComputationValueType>; + using Superclass = + PointSetToPointSetMetricWithIndexv4; using Pointer = SmartPointer; using ConstPointer = SmartPointer; /** Run-time type information (and related methods). */ - itkTypeMacro(PointSetToPointSetMetricv4, ObjectToObjectMetric); + itkTypeMacro(PointSetToPointSetMetricv4, PointSetToPointSetMetricWithIndexv4); /** Type of the measure. */ using MeasureType = typename Superclass::MeasureType; @@ -118,8 +111,6 @@ class ITK_TEMPLATE_EXPORT PointSetToPointSetMetricv4 using DisplacementFieldTransformType = typename Superclass::MovingDisplacementFieldTransformType; - using ObjectType = typename Superclass::ObjectType; - /** Dimension type */ using DimensionType = typename Superclass::DimensionType; @@ -129,7 +120,7 @@ class ITK_TEMPLATE_EXPORT PointSetToPointSetMetricv4 using FixedPixelType = typename TFixedPointSet::PixelType; using FixedPointsContainer = typename TFixedPointSet::PointsContainer; - static constexpr DimensionType FixedPointDimension = Superclass::FixedDimension; + static constexpr DimensionType FixedPointDimension = Superclass::FixedPointDimension; /** Type of the moving point set. */ using MovingPointSetType = TMovingPointSet; @@ -137,7 +128,7 @@ class ITK_TEMPLATE_EXPORT PointSetToPointSetMetricv4 using MovingPixelType = typename TMovingPointSet::PixelType; using MovingPointsContainer = typename TMovingPointSet::PointsContainer; - static constexpr DimensionType MovingPointDimension = Superclass::MovingDimension; + static constexpr DimensionType MovingPointDimension = Superclass::MovingPointDimension; /** * typedefs for the data types used in the point set metric calculations. @@ -145,24 +136,24 @@ class ITK_TEMPLATE_EXPORT PointSetToPointSetMetricv4 * point dimension, are the same for the "common space" in which the metric * calculation occurs. */ - static constexpr DimensionType PointDimension = Superclass::FixedDimension; + static constexpr DimensionType PointDimension = Superclass::PointDimension; - using PointType = FixedPointType; - using PixelType = FixedPixelType; + using PointType = typename Superclass::PointType; + using PixelType = typename Superclass::PixelType; using CoordRepType = typename PointType::CoordRepType; using PointsContainer = FixedPointsContainer; using PointsConstIterator = typename PointsContainer::ConstIterator; - using PointIdentifier = typename PointsContainer::ElementIdentifier; + using PointIdentifier = typename Superclass::PointIdentifier; /** Typedef for points locator class to speed up finding neighboring points */ - using PointsLocatorType = PointsLocator; + using PointsLocatorType = typename Superclass::PointsLocatorType; using NeighborsIdentifierType = typename PointsLocatorType::NeighborsIdentifierType; - using FixedTransformedPointSetType = PointSet; - using MovingTransformedPointSetType = PointSet; + using FixedTransformedPointSetType = typename Superclass::FixedTransformedPointSetType; + using MovingTransformedPointSetType = typename Superclass::MovingTransformedPointSetType; - using DerivativeValueType = typename DerivativeType::ValueType; - using LocalDerivativeType = FixedArray; + using DerivativeValueType = typename Superclass::DerivativeValueType; + using LocalDerivativeType = typename Superclass::LocalDerivativeType; /** Types for the virtual domain */ using VirtualImageType = typename Superclass::VirtualImageType; @@ -179,95 +170,6 @@ class ITK_TEMPLATE_EXPORT PointSetToPointSetMetricv4 using VirtualPointSetType = typename Superclass::VirtualPointSetType; using VirtualPointSetPointer = typename Superclass::VirtualPointSetPointer; - /** Set fixed point set*/ - void - SetFixedObject(const ObjectType * object) override - { - auto * pointSet = dynamic_cast(const_cast(object)); - if (pointSet != nullptr) - { - this->SetFixedPointSet(pointSet); - } - else - { - itkExceptionMacro("Incorrect object type. Should be a point set."); - } - } - - /** Set moving point set*/ - void - SetMovingObject(const ObjectType * object) override - { - auto * pointSet = dynamic_cast(const_cast(object)); - if (pointSet != nullptr) - { - this->SetMovingPointSet(pointSet); - } - else - { - itkExceptionMacro("Incorrect object type. Should be a point set."); - } - } - - /** Get/Set the fixed pointset. */ - itkSetConstObjectMacro(FixedPointSet, FixedPointSetType); - itkGetConstObjectMacro(FixedPointSet, FixedPointSetType); - - /** Get the fixed transformed point set. */ - itkGetModifiableObjectMacro(FixedTransformedPointSet, FixedTransformedPointSetType); - - /** Get/Set the moving point set. */ - itkSetConstObjectMacro(MovingPointSet, MovingPointSetType); - itkGetConstObjectMacro(MovingPointSet, MovingPointSetType); - - /** Get the moving transformed point set. */ - itkGetModifiableObjectMacro(MovingTransformedPointSet, MovingTransformedPointSetType); - - /** - * For now return the number of points used in the value/derivative calculations. - */ - SizeValueType - GetNumberOfComponents() const; - - /** - * This method returns the value of the metric based on the current - * transformation(s). This function can be redefined in derived classes - * but many point set metrics follow the same structure---one iterates - * through the points and, for each point a metric value is calculated. - * The summation of these individual point metric values gives the total - * value of the metric. Note that this might not be applicable to all - * point set metrics. For those cases, the developer will have to redefine - * the GetValue() function. - */ - MeasureType - GetValue() const override; - - /** - * This method returns the derivative based on the current - * transformation(s). This function can be redefined in derived classes - * but many point set metrics follow the same structure---one iterates - * through the points and, for each point a derivative is calculated. - * The set of all these local derivatives constitutes the total derivative. - * Note that this might not be applicable to all point set metrics. For - * those cases, the developer will have to redefine the GetDerivative() - * function. - */ - void - GetDerivative(DerivativeType &) const override; - - /** - * This method returns the derivative and value based on the current - * transformation(s). This function can be redefined in derived classes - * but many point set metrics follow the same structure---one iterates - * through the points and, for each point a derivative and value is calculated. - * The set of all these local derivatives/values constitutes the total - * derivative and value. Note that this might not be applicable to all - * point set metrics. For those cases, the developer will have to redefine - * the GetValue() and GetDerivative() functions. - */ - void - GetValueAndDerivative(MeasureType &, DerivativeType &) const override; - /** * Function to be defined in the appropriate derived classes. Calculates * the local metric value for a single point. The \c PixelType may or @@ -277,8 +179,11 @@ class ITK_TEMPLATE_EXPORT PointSetToPointSetMetricv4 GetLocalNeighborhoodValue(const PointType &, const PixelType & pixel) const = 0; /** - * Calculates the local derivative for a single point. The \c PixelType may or - * may not be used. See class description for further explanation. + * Calculates the local derivative for a single point. + * The default implementation calls GetLocalNeighborhoodValueAndDerivative. + * The \c PixelType may or may not be used. See class + * description for further explanation. + * Default implementation calls GetLocalNeighborhoodValueAndDerivative. */ virtual LocalDerivativeType GetLocalNeighborhoodDerivative(const PointType &, const PixelType & pixel) const; @@ -293,182 +198,37 @@ class ITK_TEMPLATE_EXPORT PointSetToPointSetMetricv4 LocalDerivativeType &, const PixelType & pixel) const = 0; - /** - * Get the virtual point set, derived from the fixed point set. - * If the virtual point set has not yet been derived, it will be - * in this call. */ - const VirtualPointSetType * - GetVirtualTransformedPointSet() const; - - /** - * Initialize the metric by making sure that all the components - * are present and plugged together correctly. - */ - void - Initialize() override; - - bool - SupportsArbitraryVirtualDomainSamples() const override - { - /* An arbitrary point in the virtual domain will not always - * correspond to a point within either point set. */ - return false; - } - - /** - * By default, the point set metric derivative for a displacement field transform - * is stored by saving the gradient for every voxel in the displacement field (see - * the function StorePointDerivative()). Since the "fixed points" will typically - * constitute a sparse set, this means that the field will have zero gradient values - * at every voxel that doesn't have a corresponding point. This might cause additional - * computation time for certain transforms (e.g. B-spline SyN). To avoid this, this - * option permits storing the point derivative only at the fixed point locations. - * If this variable is set to false, then the derivative array will be of length - * = PointDimension * m_FixedPointSet->GetNumberOfPoints(). - */ - itkSetMacro(StoreDerivativeAsSparseFieldForLocalSupportTransforms, bool); - itkGetConstMacro(StoreDerivativeAsSparseFieldForLocalSupportTransforms, bool); - itkBooleanMacro(StoreDerivativeAsSparseFieldForLocalSupportTransforms); - - /** - * - */ - itkSetMacro(CalculateValueAndDerivativeInTangentSpace, bool); - itkGetConstMacro(CalculateValueAndDerivativeInTangentSpace, bool); - itkBooleanMacro(CalculateValueAndDerivativeInTangentSpace); protected: - PointSetToPointSetMetricv4(); + PointSetToPointSetMetricv4() = default; ~PointSetToPointSetMetricv4() override = default; - void - PrintSelf(std::ostream & os, Indent indent) const override; - - typename FixedPointSetType::ConstPointer m_FixedPointSet; - mutable typename FixedTransformedPointSetType::Pointer m_FixedTransformedPointSet; - - mutable typename PointsLocatorType::Pointer m_FixedTransformedPointsLocator; - - typename MovingPointSetType::ConstPointer m_MovingPointSet; - mutable typename MovingTransformedPointSetType::Pointer m_MovingTransformedPointSet; - - mutable typename PointsLocatorType::Pointer m_MovingTransformedPointsLocator; - - /** Holds the fixed points after transformation into virtual domain. */ - mutable VirtualPointSetPointer m_VirtualTransformedPointSet; - /** - * Bool set by derived classes on whether the point set data (i.e. \c PixelType) - * should be used. Default = false. - */ - bool m_UsePointSetData; - - /** - * Flag to calculate value and/or derivative at tangent space. This is needed - * for the diffeomorphic registration methods. The fixed and moving points are - * warped to the virtual domain where the metric is calculated. Derived point - * set metrics might have associated gradient information which will need to be - * warped if this flag is true. Default = false. - */ - bool m_CalculateValueAndDerivativeInTangentSpace; - - /** - * Prepare point sets for use. */ - virtual void - InitializePointSets() const; - - /** - * Initialize to prepare for a particular iteration, generally - * an iteration of optimization. Distinct from Initialize() - * which is a one-time initialization. */ - virtual void - InitializeForIteration() const; - - /** - * Determine the number of valid fixed points. A fixed point - * is valid if, when transformed into the virtual domain using - * the inverse of the FixedTransform, it is within the defined - * virtual domain bounds. */ - virtual SizeValueType - CalculateNumberOfValidFixedPoints() const; - - /** Helper method allows for code reuse while skipping the metric value - * calculation when appropriate */ - void - CalculateValueAndDerivative(MeasureType & calculatedValue, DerivativeType & derivative, bool calculateValue) const; - - /** - * Warp the fixed point set into the moving domain based on the fixed transform, - * passing through the virtual domain and storing a virtual domain set. - * Note that the warped moving point set is of type FixedPointSetType since the transform - * takes the points from the fixed to the moving domain. - */ - void - TransformFixedAndCreateVirtualPointSet() const; - - /** - * Warp the moving point set based on the moving transform. Note that the - * warped moving point set is of type FixedPointSetType since the transform - * takes the points from the moving to the fixed domain. - * FIXME: needs update. - */ - void - TransformMovingPointSet() const; - - /** - * Build point locators for the fixed and moving point sets to speed up - * derivative and value calculations. - */ - void - InitializePointsLocators() const; - - /** - * Store a derivative from a single point in a field. - * Only relevant when active transform has local support. - */ - void - StorePointDerivative(const VirtualPointType &, const DerivativeType &, DerivativeType &) const; - - using MetricCategoryType = typename Superclass::MetricCategoryType; - - /** Get metric category */ - MetricCategoryType - GetMetricCategory() const override +private: + MeasureType + GetLocalNeighborhoodValueWithIndex(const PointIdentifier &, + const PointType & point, + const PixelType & pixel) const override { - return MetricCategoryType::POINT_SET_METRIC; - } + return this->GetLocalNeighborhoodValue(point, pixel); + }; - virtual bool - RequiresMovingPointsLocator() const + LocalDerivativeType + GetLocalNeighborhoodDerivativeWithIndex(const PointIdentifier &, + const PointType & point, + const PixelType & pixel) const override { - return true; + return this->GetLocalNeighborhoodDerivative(point, pixel); }; - virtual bool - RequiresFixedPointsLocator() const + void + GetLocalNeighborhoodValueAndDerivativeWithIndex(const PointIdentifier &, + const PointType & point, + MeasureType & measure, + LocalDerivativeType & derivative, + const PixelType & pixel) const override { - return true; + this->GetLocalNeighborhoodValueAndDerivative(point, measure, derivative, pixel); }; - -private: - mutable bool m_MovingTransformPointLocatorsNeedInitialization; - mutable bool m_FixedTransformPointLocatorsNeedInitialization; - - // Flag to keep track of whether a warning has already been issued - // regarding the number of valid points. - mutable bool m_HaveWarnedAboutNumberOfValidPoints; - - // Flag to store derivatives at fixed point locations with the rest being zero gradient - // (default = true). - bool m_StoreDerivativeAsSparseFieldForLocalSupportTransforms; - - mutable ModifiedTimeType m_MovingTransformedPointSetTime; - mutable ModifiedTimeType m_FixedTransformedPointSetTime; - - // Create ranges over the point set for multithreaded computation of value and derivatives - using PointIdentifierPair = std::pair; - using PointIdentifierRanges = std::vector; - const PointIdentifierRanges - CreateRanges() const; }; } // end namespace itk diff --git a/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricv4.hxx b/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricv4.hxx index bf677d1b30e..bb832664f8b 100644 --- a/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricv4.hxx +++ b/Modules/Registration/Metricsv4/include/itkPointSetToPointSetMetricv4.hxx @@ -19,460 +19,10 @@ #define itkPointSetToPointSetMetricv4_hxx #include "itkPointSetToPointSetMetricv4.h" -#include "itkIdentityTransform.h" -#include "itkCompensatedSummation.h" namespace itk { -/** Constructor */ -template -PointSetToPointSetMetricv4::PointSetToPointSetMetricv4() -{ - this->m_FixedPointSet = nullptr; // has to be provided by the user. - this->m_MovingPointSet = nullptr; // has to be provided by the user. - - this->m_FixedTransformedPointSet = nullptr; - this->m_MovingTransformedPointSet = nullptr; - this->m_VirtualTransformedPointSet = nullptr; - - this->m_FixedTransformedPointsLocator = nullptr; - this->m_MovingTransformedPointsLocator = nullptr; - - this->m_MovingTransformPointLocatorsNeedInitialization = false; - this->m_FixedTransformPointLocatorsNeedInitialization = false; - - this->m_MovingTransformedPointSetTime = this->GetMTime(); - this->m_FixedTransformedPointSetTime = this->GetMTime(); - - // We iterate over the fixed points to calculate the value and derivative. - this->SetGradientSource(ObjectToObjectMetricBaseTemplateEnums::GradientSource::GRADIENT_SOURCE_FIXED); - - this->m_HaveWarnedAboutNumberOfValidPoints = false; - - this->m_UsePointSetData = false; - - this->m_StoreDerivativeAsSparseFieldForLocalSupportTransforms = true; - - this->m_CalculateValueAndDerivativeInTangentSpace = false; -} - -/** Initialize the metric */ -template -void -PointSetToPointSetMetricv4::Initialize() -{ - if (!this->m_FixedPointSet) - { - itkExceptionMacro("Fixed point set is not present"); - } - - if (!this->m_MovingPointSet) - { - itkExceptionMacro("Moving point set is not present"); - } - - // We don't know how to support gradient source of type moving - if (this->GetGradientSourceIncludesMoving()) - { - itkExceptionMacro("GradientSource includes GRADIENT_SOURCE_MOVING. Not supported."); - } - - // If the PointSet is provided by a source, update the source. - if (this->m_MovingPointSet->GetSource()) - { - this->m_MovingPointSet->GetSource()->Update(); - } - - // If the point set is provided by a source, update the source. - if (this->m_FixedPointSet->GetSource()) - { - this->m_FixedPointSet->GetSource()->Update(); - } - - // Check for virtual domain if needed. - // With local-support transforms we need a virtual domain in - // order to properly store the per-point derivatives. - // This will create a virtual domain that matches the DisplacementFieldTransform. - // If the virutal domain has already been set, it will - // be verified against the transform in Superclass::Initialize. - if (this->HasLocalSupport()) - { - if (!this->m_UserHasSetVirtualDomain) - { - const typename DisplacementFieldTransformType::ConstPointer displacementTransform = - this->GetMovingDisplacementFieldTransform(); - if (displacementTransform.IsNull()) - { - itkExceptionMacro("Expected the moving transform to be of type DisplacementFieldTransform or derived, " - "or a CompositeTransform with DisplacementFieldTransform as the last to have been added."); - } - using DisplacementFieldType = typename DisplacementFieldTransformType::DisplacementFieldType; - typename DisplacementFieldType::ConstPointer field = displacementTransform->GetDisplacementField(); - this->SetVirtualDomain( - field->GetSpacing(), field->GetOrigin(), field->GetDirection(), field->GetBufferedRegion()); - } - } - - // Superclass initialization. Do after checking for virtual domain. - Superclass::Initialize(); - - // Call this now for derived classes that need - // a member to be initialized during Initialize(). - this->InitializePointSets(); -} - -template -void -PointSetToPointSetMetricv4::InitializePointSets() const -{ - this->TransformMovingPointSet(); - this->TransformFixedAndCreateVirtualPointSet(); - this->InitializePointsLocators(); -} - -template -void -PointSetToPointSetMetricv4::InitializeForIteration() - const -{ - this->InitializePointSets(); - this->m_NumberOfValidPoints = this->CalculateNumberOfValidFixedPoints(); - if (this->m_NumberOfValidPoints < this->GetNumberOfComponents() && !this->m_HaveWarnedAboutNumberOfValidPoints) - { - itkWarningMacro("Only " << this->m_NumberOfValidPoints << " of " << this->GetNumberOfComponents() - << " points are within the virtual domain, and will be used in the evaluation."); - this->m_HaveWarnedAboutNumberOfValidPoints = true; - } -} - -template -SizeValueType -PointSetToPointSetMetricv4::GetNumberOfComponents() - const -{ - return this->m_FixedTransformedPointSet->GetNumberOfPoints(); -} - -template -typename PointSetToPointSetMetricv4::MeasureType -PointSetToPointSetMetricv4::GetValue() const -{ - this->InitializeForIteration(); - - - // Virtual point set will be the same size as fixed point set as long as it's - // generated from the fixed point set. - if (this->m_VirtualTransformedPointSet->GetNumberOfPoints() != this->m_FixedTransformedPointSet->GetNumberOfPoints()) - { - itkExceptionMacro("Expected FixedTransformedPointSet to be the same size as VirtualTransformedPointSet."); - } - /* - * Split pointset in nWorkUnit ranges and sum individually - * This splitting is required in order to avoid having the threads - * repeatedly write to same location causing false sharing - */ - // Use STL container to make sure no unesecarry checks are performed - using FixedTransformedVectorContainer = typename FixedPointsContainer::STLContainerType; - using VirtualPointsContainer = typename VirtualPointSetType::PointsContainer; - using VirtualVectorContainer = typename VirtualPointsContainer::STLContainerType; - const VirtualVectorContainer & virtualTransformedPointSet = - this->m_VirtualTransformedPointSet->GetPoints()->CastToSTLConstContainer(); - const FixedTransformedVectorContainer & fixedTransformedPointSet = - this->m_FixedTransformedPointSet->GetPoints()->CastToSTLConstContainer(); - - PointIdentifierRanges ranges = this->CreateRanges(); - std::vector> threadValues(ranges.size()); - std::function sumNeighborhoodValues = - [this, &threadValues, &ranges, &virtualTransformedPointSet, &fixedTransformedPointSet](SizeValueType rangeIndex) { - CompensatedSummation threadValue = 0; - PixelType pixel; - NumericTraits::SetLength(pixel, 1); - - - for (PointIdentifier index = ranges[rangeIndex].first; index < ranges[rangeIndex].second; index++) - { - if (this->IsInsideVirtualDomain(virtualTransformedPointSet[index])) - { - if (this->m_UsePointSetData) - { - bool doesPointDataExist = this->m_FixedPointSet->GetPointData(index, &pixel); - if (!doesPointDataExist) - { - itkExceptionMacro("The corresponding data for point (pointId = " << index << ") does not exist."); - } - } - threadValue += this->GetLocalNeighborhoodValue(fixedTransformedPointSet[index], pixel); - } - } - threadValues[rangeIndex] = threadValue; - }; - - // Sum per thread - MultiThreaderBase::New()->ParallelizeArray( - (SizeValueType)0, (SizeValueType)ranges.size(), sumNeighborhoodValues, nullptr); - // Join sums - CompensatedSummation value = 0; - for (unsigned int i = 0; i < threadValues.size(); i++) - { - value += threadValues[i]; - } - - DerivativeType derivative; - MeasureType valueSum = value.GetSum(); - if (this->VerifyNumberOfValidPoints(valueSum, derivative)) - { - valueSum /= static_cast(this->m_NumberOfValidPoints); - } - this->m_Value = valueSum; - - return valueSum; -} - -template -void -PointSetToPointSetMetricv4::GetDerivative( - DerivativeType & derivative) const -{ - MeasureType value = NumericTraits::ZeroValue(); - this->CalculateValueAndDerivative(value, derivative, false); -} - -template -void -PointSetToPointSetMetricv4::GetValueAndDerivative( - MeasureType & value, - DerivativeType & derivative) const -{ - this->CalculateValueAndDerivative(value, derivative, true); -} - -template -void -PointSetToPointSetMetricv4::CalculateValueAndDerivative( - MeasureType & calculatedValue, - DerivativeType & derivative, - bool calculateValue) const -{ - this->InitializeForIteration(); - - // Virtual point set will be the same size as fixed point set as long as it's - // generated from the fixed point set. - if (this->m_VirtualTransformedPointSet->GetNumberOfPoints() != this->m_FixedTransformedPointSet->GetNumberOfPoints()) - { - itkExceptionMacro("Expected FixedTransformedPointSet to be the same size as VirtualTransformedPointSet."); - } - - derivative.SetSize(this->GetNumberOfParameters()); - if (!this->GetStoreDerivativeAsSparseFieldForLocalSupportTransforms()) - { - derivative.SetSize(PointDimension * this->m_FixedTransformedPointSet->GetNumberOfPoints()); - } - derivative.Fill(NumericTraits::ZeroValue()); - - /* - * Split pointset in nWorkUnits ranges and sum individually - * This splitting is required in order to avoid having the threads - * repeatedly write to same location causing false sharing - */ - // GetNumberOfLocalParameters is not trhead safe in itkCompositeTransform - NumberOfParametersType numberOfLocalParameters = this->GetNumberOfLocalParameters(); - PointIdentifierRanges ranges = this->CreateRanges(); - std::vector> threadValues(ranges.size()); - using CompensatedDerivative = typename std::vector>; - std::vector threadDerivatives(ranges.size()); - std::function sumNeighborhoodValues = - [this, &derivative, &threadDerivatives, &threadValues, &ranges, &calculateValue, &numberOfLocalParameters]( - SizeValueType rangeIndex) { - // Use STL container to make sure no unesecarry checks are performed - using FixedTransformedVectorContainer = typename FixedPointsContainer::STLContainerType; - using VirtualPointsContainer = typename VirtualPointSetType::PointsContainer; - using VirtualVectorContainer = typename VirtualPointsContainer::STLContainerType; - const VirtualVectorContainer & virtualTransformedPointSet = - this->m_VirtualTransformedPointSet->GetPoints()->CastToSTLConstContainer(); - const FixedTransformedVectorContainer & fixedTransformedPointSet = - this->m_FixedTransformedPointSet->GetPoints()->CastToSTLConstContainer(); - - MovingTransformJacobianType jacobian(MovingPointDimension, numberOfLocalParameters); - MovingTransformJacobianType jacobianCache; - - DerivativeType threadLocalTransformDerivative(numberOfLocalParameters); - threadLocalTransformDerivative.Fill(NumericTraits::ZeroValue()); - - CompensatedDerivative threadDerivativeSum(numberOfLocalParameters); - - CompensatedSummation threadValue; - PixelType pixel; - NumericTraits::SetLength(pixel, 1); - for (PointIdentifier index = ranges[rangeIndex].first; index < ranges[rangeIndex].second; index++) - { - MeasureType pointValue = NumericTraits::ZeroValue(); - LocalDerivativeType pointDerivative; - - /* Verify the virtual point is in the virtual domain. - * If user hasn't defined a virtual space, and the active transform is not - * a displacement field transform type, then this will always return true. */ - if (!this->IsInsideVirtualDomain(virtualTransformedPointSet[index])) - { - continue; - } - - if (this->m_UsePointSetData) - { - bool doesPointDataExist = this->m_FixedPointSet->GetPointData(index, &pixel); - if (!doesPointDataExist) - { - itkExceptionMacro("The corresponding data for point with id " << index << " does not exist."); - } - } - - if (calculateValue) - { - this->GetLocalNeighborhoodValueAndDerivative( - fixedTransformedPointSet[index], pointValue, pointDerivative, pixel); - threadValue += pointValue; - } - else - { - pointDerivative = this->GetLocalNeighborhoodDerivative(fixedTransformedPointSet[index], pixel); - } - - // Map into parameter space - threadLocalTransformDerivative.Fill(NumericTraits::ZeroValue()); - - if (this->m_CalculateValueAndDerivativeInTangentSpace) - { - for (DimensionType d = 0; d < PointDimension; ++d) - { - threadLocalTransformDerivative[d] += pointDerivative[d]; - } - } - else - { - this->GetMovingTransform()->ComputeJacobianWithRespectToParametersCachedTemporaries( - virtualTransformedPointSet[index], jacobian, jacobianCache); - - for (NumberOfParametersType par = 0; par < numberOfLocalParameters; par++) - { - for (DimensionType d = 0; d < PointDimension; ++d) - { - threadLocalTransformDerivative[par] += jacobian(d, par) * pointDerivative[d]; - } - } - } - // For local-support transforms, store the per-point result - if (this->HasLocalSupport() || this->m_CalculateValueAndDerivativeInTangentSpace) - { - if (this->GetStoreDerivativeAsSparseFieldForLocalSupportTransforms()) - { - this->StorePointDerivative(virtualTransformedPointSet[index], threadLocalTransformDerivative, derivative); - } - else - { - for (NumberOfParametersType par = 0; par < numberOfLocalParameters; par++) - { - derivative[this->GetNumberOfLocalParameters() * index + par] = threadLocalTransformDerivative[par]; - } - } - } - for (NumberOfParametersType par = 0; par < numberOfLocalParameters; par++) - { - threadDerivativeSum[par] += threadLocalTransformDerivative[par]; - } - } - threadValues[rangeIndex] = threadValue; - threadDerivatives[rangeIndex] = threadDerivativeSum; - }; - - // Sum per thread - MultiThreaderBase::New()->ParallelizeArray( - (SizeValueType)0, (SizeValueType)ranges.size(), sumNeighborhoodValues, nullptr); - - // Sum thread results - CompensatedSummation value = 0; - for (unsigned int i = 0; i < threadValues.size(); i++) - { - value += threadValues[i]; - } - MeasureType valueSum = value.GetSum(); - - if (this->VerifyNumberOfValidPoints(valueSum, derivative)) - { - // For global-support transforms, average the accumulated derivative result - if (!this->HasLocalSupport() && !this->m_CalculateValueAndDerivativeInTangentSpace) - { - CompensatedDerivative localTransformDerivative(numberOfLocalParameters); - for (unsigned int i = 0; i < threadDerivatives.size(); i++) - { - for (NumberOfParametersType par = 0; par < numberOfLocalParameters; par++) - { - localTransformDerivative[par] += threadDerivatives[i][par]; - } - } - derivative.SetSize(numberOfLocalParameters); - for (NumberOfParametersType par = 0; par < numberOfLocalParameters; par++) - { - derivative[par] = - localTransformDerivative[par].GetSum() / static_cast(this->m_NumberOfValidPoints); - } - } - valueSum /= static_cast(this->m_NumberOfValidPoints); - } - calculatedValue = valueSum; - this->m_Value = valueSum; -} - -template -SizeValueType -PointSetToPointSetMetricv4:: - CalculateNumberOfValidFixedPoints() const -{ - // Determine the number of valid fixed points, using - // their positions in the virtual domain. - SizeValueType numberOfValidPoints = NumericTraits::ZeroValue(); - PointsConstIterator virtualIt = this->m_VirtualTransformedPointSet->GetPoints()->Begin(); - while (virtualIt != this->m_VirtualTransformedPointSet->GetPoints()->End()) - { - if (this->IsInsideVirtualDomain(virtualIt.Value())) - { - ++numberOfValidPoints; - } - ++virtualIt; - } - return numberOfValidPoints; -} - -template -void -PointSetToPointSetMetricv4::StorePointDerivative( - const VirtualPointType & virtualPoint, - const DerivativeType & pointDerivative, - DerivativeType & field) const -{ - // Update derivative field at some index. - // This requires the active transform displacement field to be the - // same size as virtual domain, and that VirtualImage PixelType - // is scalar (both of which are verified during Metric initialization). - try - { - OffsetValueType offset = - this->ComputeParameterOffsetFromVirtualPoint(virtualPoint, this->GetNumberOfLocalParameters()); - for (NumberOfParametersType i = 0; i < this->GetNumberOfLocalParameters(); i++) - { - /* Be sure to *add* here and not assign. Required for proper behavior - * with multi-variate metric. */ - field[offset + i] += pointDerivative[i]; - } - } - catch (ExceptionObject & exc) - { - std::string msg("Caught exception: \n"); - msg += exc.what(); - ExceptionObject err(__FILE__, __LINE__, msg); - throw err; - } -} - template typename PointSetToPointSetMetricv4::LocalDerivativeType PointSetToPointSetMetricv4:: @@ -484,209 +34,6 @@ PointSetToPointSetMetricv4 -void -PointSetToPointSetMetricv4::TransformMovingPointSet() - const -{ - // Transform the moving point set with the moving transform. - // We calculate the value and derivatives in the moving space. - bool update = !this->m_MovingTransformedPointSet; - update = update || this->m_MovingTransformedPointSetTime < this->GetMTime(); - update = update || (this->m_CalculateValueAndDerivativeInTangentSpace && - (this->m_MovingTransform->GetMTime() > this->m_MovingTransformedPointSetTime)); - if (update) - { - this->m_MovingTransformPointLocatorsNeedInitialization = true; - this->m_MovingTransformedPointSet = MovingTransformedPointSetType::New(); - this->m_MovingTransformedPointSet->Initialize(); - - typename MovingTransformType::InverseTransformBasePointer inverseTransform = - this->m_MovingTransform->GetInverseTransform(); - - typename MovingPointsContainer::ConstIterator It = this->m_MovingPointSet->GetPoints()->Begin(); - while (It != this->m_MovingPointSet->GetPoints()->End()) - { - if (this->m_CalculateValueAndDerivativeInTangentSpace) - { - PointType point = inverseTransform->TransformPoint(It.Value()); - this->m_MovingTransformedPointSet->SetPoint(It.Index(), point); - } - else - { - // evaluation is performed in moving space, so just copy - this->m_MovingTransformedPointSet->SetPoint(It.Index(), It.Value()); - } - ++It; - } - this->m_MovingTransformedPointSetTime = this->GetMTime(); - if (!this->m_CalculateValueAndDerivativeInTangentSpace) - { - this->m_MovingTransformedPointSetTime = - std::max(this->m_MovingTransformedPointSetTime, this->m_MovingTransform->GetMTime()); - } - } -} - -template -void -PointSetToPointSetMetricv4:: - TransformFixedAndCreateVirtualPointSet() const -{ - // Transform the fixed point set through the virtual domain, and into the moving domain - bool update = !this->m_FixedTransformedPointSet || !this->m_VirtualTransformedPointSet; - update = update || this->m_FixedTransformedPointSetTime < this->GetMTime(); - update = update || (this->m_CalculateValueAndDerivativeInTangentSpace && - (this->m_FixedTransform->GetMTime() > this->m_FixedTransformedPointSetTime)); - update = update || (!this->m_CalculateValueAndDerivativeInTangentSpace && - ((this->m_FixedTransform->GetMTime() > this->m_FixedTransformedPointSetTime) || - (this->m_MovingTransform->GetMTime() > this->m_FixedTransformedPointSetTime))); - if (update) - { - this->m_FixedTransformPointLocatorsNeedInitialization = true; - this->m_FixedTransformedPointSet = FixedTransformedPointSetType::New(); - this->m_FixedTransformedPointSet->Initialize(); - this->m_VirtualTransformedPointSet = VirtualPointSetType::New(); - this->m_VirtualTransformedPointSet->Initialize(); - - using InverseTransformBasePointer = typename FixedTransformType::InverseTransformBasePointer; - InverseTransformBasePointer inverseTransform = this->m_FixedTransform->GetInverseTransform(); - - typename FixedPointsContainer::ConstIterator It = this->m_FixedPointSet->GetPoints()->Begin(); - while (It != this->m_FixedPointSet->GetPoints()->End()) - { - if (this->m_CalculateValueAndDerivativeInTangentSpace) - { - // txf into virtual space - PointType point = inverseTransform->TransformPoint(It.Value()); - this->m_VirtualTransformedPointSet->SetPoint(It.Index(), point); - this->m_FixedTransformedPointSet->SetPoint(It.Index(), point); - } - else - { - // txf into virtual space - PointType point = inverseTransform->TransformPoint(It.Value()); - this->m_VirtualTransformedPointSet->SetPoint(It.Index(), point); - // txf into moving space - point = this->m_MovingTransform->TransformPoint(point); - this->m_FixedTransformedPointSet->SetPoint(It.Index(), point); - } - ++It; - } - this->m_FixedTransformedPointSetTime = std::max(this->GetMTime(), this->m_FixedTransform->GetMTime()); - if (!this->m_CalculateValueAndDerivativeInTangentSpace) - { - this->m_FixedTransformedPointSetTime = - std::max(this->m_FixedTransformedPointSetTime, this->m_MovingTransform->GetMTime()); - } - } -} - -template -const typename PointSetToPointSetMetricv4:: - VirtualPointSetType * - PointSetToPointSetMetricv4:: - GetVirtualTransformedPointSet() const -{ - // First make sure the virtual point set is current. - this->TransformFixedAndCreateVirtualPointSet(); - return this->m_VirtualTransformedPointSet.GetPointer(); -} - -template -void -PointSetToPointSetMetricv4::InitializePointsLocators() - const -{ - if (this->RequiresFixedPointsLocator() && this->m_FixedTransformPointLocatorsNeedInitialization) - { - if (!this->m_FixedTransformedPointSet) - { - itkExceptionMacro("The fixed transformed point set does not exist."); - } - if (!this->m_FixedTransformedPointsLocator) - { - this->m_FixedTransformedPointsLocator = PointsLocatorType::New(); - } - this->m_FixedTransformedPointsLocator->SetPoints(this->m_FixedTransformedPointSet->GetPoints()); - this->m_FixedTransformedPointsLocator->Initialize(); - this->m_FixedTransformPointLocatorsNeedInitialization = false; - } - - if (this->RequiresMovingPointsLocator() && this->m_MovingTransformPointLocatorsNeedInitialization) - { - if (!this->m_MovingTransformedPointSet) - { - itkExceptionMacro("The moving transformed point set does not exist."); - } - if (!this->m_MovingTransformedPointsLocator) - { - this->m_MovingTransformedPointsLocator = PointsLocatorType::New(); - } - this->m_MovingTransformedPointsLocator->SetPoints(this->m_MovingTransformedPointSet->GetPoints()); - this->m_MovingTransformedPointsLocator->Initialize(); - this->m_MovingTransformPointLocatorsNeedInitialization = false; - } -} - -template -const typename PointSetToPointSetMetricv4:: - PointIdentifierRanges - PointSetToPointSetMetricv4::CreateRanges() const -{ - PointIdentifier nPoints = this->m_FixedTransformedPointSet->GetNumberOfPoints(); - PointIdentifier nWorkUnits = MultiThreaderBase::New()->GetNumberOfWorkUnits(); - if (nWorkUnits > nPoints || MultiThreaderBase::New()->GetMaximumNumberOfThreads() <= 1) - { - nWorkUnits = 1; - } - PointIdentifier startRange = 0; - PointIdentifierRanges ranges; - for (PointIdentifier p = 1; p < nWorkUnits; ++p) - { - PointIdentifier endRange = (p * nPoints) / (double)nWorkUnits; - ranges.push_back(PointIdentifierPair(startRange, endRange)); - startRange = endRange; - } - ranges.push_back(PointIdentifierPair(startRange, nPoints)); - - return ranges; -} - - -/** PrintSelf */ -template -void -PointSetToPointSetMetricv4::PrintSelf( - std::ostream & os, - Indent indent) const -{ - Superclass::PrintSelf(os, indent); - os << indent << "Fixed PointSet: " << this->m_FixedPointSet.GetPointer() << std::endl; - os << indent << "Fixed Transform: " << this->m_FixedTransform.GetPointer() << std::endl; - os << indent << "Moving PointSet: " << this->m_MovingPointSet.GetPointer() << std::endl; - os << indent << "Moving Transform: " << this->m_MovingTransform.GetPointer() << std::endl; - - os << indent << "Store derivative as sparse field = "; - if (this->m_StoreDerivativeAsSparseFieldForLocalSupportTransforms) - { - os << "true." << std::endl; - } - else - { - os << "false." << std::endl; - } - - os << indent << "Calculate in tangent space = "; - if (this->m_CalculateValueAndDerivativeInTangentSpace) - { - os << "true." << std::endl; - } - else - { - os << "false." << std::endl; - } -} } // end namespace itk #endif diff --git a/Modules/Registration/Metricsv4/wrapping/itkEuclideanDistancePointSetToPointSetMetric.wrap b/Modules/Registration/Metricsv4/wrapping/itkEuclideanDistancePointSetToPointSetMetric.wrap index 518a749afda..2c4e46d0536 100644 --- a/Modules/Registration/Metricsv4/wrapping/itkEuclideanDistancePointSetToPointSetMetric.wrap +++ b/Modules/Registration/Metricsv4/wrapping/itkEuclideanDistancePointSetToPointSetMetric.wrap @@ -3,7 +3,7 @@ itk_wrap_include("itkDefaultStaticMeshTraits.h") UNIQUE(types "${WRAP_ITK_SCALAR};D") -itk_wrap_class("itk::EuclideanDistancePointSetToPointSetMetricv4" POINTER_WITH_SUPERCLASS) +itk_wrap_class("itk::EuclideanDistancePointSetToPointSetMetricv4" POINTER_WITH_2_SUPERCLASSES) foreach(d ${ITK_WRAP_IMAGE_DIMS}) foreach(t ${types}) itk_wrap_template("PS${ITKM_${t}}${d}" "itk::PointSet< ${ITKT_${t}},${d} >") diff --git a/Modules/Registration/Metricsv4/wrapping/itkExpectationBasedPointSetToPointSetMetricv4.wrap b/Modules/Registration/Metricsv4/wrapping/itkExpectationBasedPointSetToPointSetMetricv4.wrap index c869c190e45..20b3836d6f8 100644 --- a/Modules/Registration/Metricsv4/wrapping/itkExpectationBasedPointSetToPointSetMetricv4.wrap +++ b/Modules/Registration/Metricsv4/wrapping/itkExpectationBasedPointSetToPointSetMetricv4.wrap @@ -3,7 +3,7 @@ itk_wrap_include("itkDefaultStaticMeshTraits.h") UNIQUE(types "${WRAP_ITK_SCALAR};D") -itk_wrap_class("itk::ExpectationBasedPointSetToPointSetMetricv4" POINTER) +itk_wrap_class("itk::ExpectationBasedPointSetToPointSetMetricv4" POINTER_WITH_2_SUPERCLASSES) foreach(d ${ITK_WRAP_IMAGE_DIMS}) foreach(t ${types}) itk_wrap_template("PS${ITKM_${t}}${d}" "itk::PointSet< ${ITKT_${t}},${d} >") diff --git a/Modules/Registration/Metricsv4/wrapping/itkJensenHavrdaCharvatTsallisPointSetToPointSetMetricv4.wrap b/Modules/Registration/Metricsv4/wrapping/itkJensenHavrdaCharvatTsallisPointSetToPointSetMetricv4.wrap index fadbb782211..b52b08f32e7 100644 --- a/Modules/Registration/Metricsv4/wrapping/itkJensenHavrdaCharvatTsallisPointSetToPointSetMetricv4.wrap +++ b/Modules/Registration/Metricsv4/wrapping/itkJensenHavrdaCharvatTsallisPointSetToPointSetMetricv4.wrap @@ -3,7 +3,7 @@ itk_wrap_include("itkDefaultStaticMeshTraits.h") UNIQUE(types "${WRAP_ITK_SCALAR};D") -itk_wrap_class("itk::JensenHavrdaCharvatTsallisPointSetToPointSetMetricv4" POINTER) +itk_wrap_class("itk::JensenHavrdaCharvatTsallisPointSetToPointSetMetricv4" POINTER_WITH_2_SUPERCLASSES) foreach(d ${ITK_WRAP_IMAGE_DIMS}) foreach(t ${types}) itk_wrap_template("PS${ITKM_${t}}${d}" "itk::PointSet< ${ITKT_${t}},${d} >") diff --git a/Modules/Registration/Metricsv4/wrapping/itkLabeledPointSetToPointSetMetricv4.wrap b/Modules/Registration/Metricsv4/wrapping/itkLabeledPointSetToPointSetMetricv4.wrap index 2cb8c1ce093..8c2cdec7b0d 100644 --- a/Modules/Registration/Metricsv4/wrapping/itkLabeledPointSetToPointSetMetricv4.wrap +++ b/Modules/Registration/Metricsv4/wrapping/itkLabeledPointSetToPointSetMetricv4.wrap @@ -3,7 +3,7 @@ itk_wrap_include("itkDefaultStaticMeshTraits.h") UNIQUE(types "${WRAP_ITK_INT}") -itk_wrap_class("itk::LabeledPointSetToPointSetMetricv4" POINTER) +itk_wrap_class("itk::LabeledPointSetToPointSetMetricv4" POINTER_WITH_2_SUPERCLASSES) foreach(d ${ITK_WRAP_IMAGE_DIMS}) foreach(t ${types}) itk_wrap_template("PS${ITKM_${t}}${d}" "itk::PointSet< ${ITKT_${t}},${d} >") diff --git a/Modules/Registration/Metricsv4/wrapping/itkPointSetToPointSetMetricWithIndexv4.wrap b/Modules/Registration/Metricsv4/wrapping/itkPointSetToPointSetMetricWithIndexv4.wrap new file mode 100644 index 00000000000..2c2b11b1b42 --- /dev/null +++ b/Modules/Registration/Metricsv4/wrapping/itkPointSetToPointSetMetricWithIndexv4.wrap @@ -0,0 +1,12 @@ +itk_wrap_include("itkPointSet.h") +itk_wrap_include("itkDefaultStaticMeshTraits.h") + +UNIQUE(types "${WRAP_ITK_SCALAR};D") + +itk_wrap_class("itk::PointSetToPointSetMetricWithIndexv4" POINTER_WITH_SUPERCLASS) + foreach(d ${ITK_WRAP_IMAGE_DIMS}) + foreach(t ${types}) + itk_wrap_template("PS${ITKM_${t}}${d}" "itk::PointSet< ${ITKT_${t}},${d} >") + endforeach() + endforeach() +itk_end_wrap_class() diff --git a/Modules/Registration/Metricsv4/wrapping/itkPointSetToPointSetMetricv4.wrap b/Modules/Registration/Metricsv4/wrapping/itkPointSetToPointSetMetricv4.wrap new file mode 100644 index 00000000000..504cf500d4a --- /dev/null +++ b/Modules/Registration/Metricsv4/wrapping/itkPointSetToPointSetMetricv4.wrap @@ -0,0 +1,12 @@ +itk_wrap_include("itkPointSet.h") +itk_wrap_include("itkDefaultStaticMeshTraits.h") + +UNIQUE(types "${WRAP_ITK_SCALAR};D") + +itk_wrap_class("itk::PointSetToPointSetMetricv4" POINTER_WITH_2_SUPERCLASSES) + foreach(d ${ITK_WRAP_IMAGE_DIMS}) + foreach(t ${types}) + itk_wrap_template("PS${ITKM_${t}}${d}" "itk::PointSet< ${ITKT_${t}},${d} >") + endforeach() + endforeach() +itk_end_wrap_class() diff --git a/Modules/Registration/RegistrationMethodsv4/include/itkImageRegistrationMethodv4.h b/Modules/Registration/RegistrationMethodsv4/include/itkImageRegistrationMethodv4.h index f13f3f14bd2..6fd6edcf6cf 100644 --- a/Modules/Registration/RegistrationMethodsv4/include/itkImageRegistrationMethodv4.h +++ b/Modules/Registration/RegistrationMethodsv4/include/itkImageRegistrationMethodv4.h @@ -26,7 +26,7 @@ #include "itkObjectToObjectMultiMetricv4.h" #include "itkObjectToObjectOptimizerBase.h" #include "itkImageToImageMetricv4.h" -#include "itkPointSetToPointSetMetricv4.h" +#include "itkPointSetToPointSetMetricWithIndexv4.h" #include "itkShrinkImageFilter.h" #include "itkIdentityTransform.h" #include "itkTransformParametersAdaptorBase.h" @@ -173,7 +173,7 @@ class ITK_TEMPLATE_EXPORT ImageRegistrationMethodv4 : public ProcessObject using MultiMetricType = ObjectToObjectMultiMetricv4; using ImageMetricType = ImageToImageMetricv4; - using PointSetMetricType = PointSetToPointSetMetricv4; + using PointSetMetricType = PointSetToPointSetMetricWithIndexv4; using FixedImageMaskType = typename ImageMetricType::FixedImageMaskType; using FixedImageMaskConstPointer = typename FixedImageMaskType::ConstPointer;