00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef __itkKdTreeBasedKmeansEstimator_h
00018 #define __itkKdTreeBasedKmeansEstimator_h
00019
00020 #include <vector>
00021 #include "itk_hash_map.h"
00022
00023 #include "itkObject.h"
00024 #include "itkMeasurementVectorTraits.h"
00025
00026 namespace itk {
00027 namespace Statistics {
00028
00066 template< class TKdTree >
00067 class ITK_EXPORT KdTreeBasedKmeansEstimator:
00068 public Object
00069 {
00070 public:
00072 typedef KdTreeBasedKmeansEstimator Self ;
00073 typedef Object Superclass;
00074 typedef SmartPointer<Self> Pointer;
00075 typedef SmartPointer<const Self> ConstPointer;
00076
00078 itkNewMacro(Self);
00079
00081 itkTypeMacro(KdTreeBasedKmeansEstimator, Obeject);
00082
00084 typedef typename TKdTree::KdTreeNodeType KdTreeNodeType ;
00085 typedef typename TKdTree::MeasurementType MeasurementType ;
00086 typedef typename TKdTree::MeasurementVectorType MeasurementVectorType ;
00087 typedef typename TKdTree::InstanceIdentifier InstanceIdentifier ;
00088 typedef typename TKdTree::SampleType SampleType ;
00089 typedef typename KdTreeNodeType::CentroidType CentroidType ;
00090
00091
00093 typedef unsigned int MeasurementVectorSizeType;
00094
00097 typedef Array< double > ParameterType ;
00098 typedef std::vector< ParameterType > InternalParametersType;
00099 typedef Array< double > ParametersType;
00100
00102 void SetParameters(ParametersType& params)
00103 { m_Parameters = params ; }
00104
00106 ParametersType& GetParameters()
00107 { return m_Parameters ; }
00108
00110 itkSetMacro( MaximumIteration, int );
00111 itkGetConstReferenceMacro( MaximumIteration, int );
00112
00115 itkSetMacro( CentroidPositionChangesThreshold, double );
00116 itkGetConstReferenceMacro( CentroidPositionChangesThreshold, double );
00117
00119 void SetKdTree(TKdTree* tree)
00120 {
00121 m_KdTree = tree ;
00122 m_MeasurementVectorSize = tree->GetMeasurementVectorSize();
00123 m_DistanceMetric->SetMeasurementVectorSize( m_MeasurementVectorSize );
00124 MeasurementVectorTraits::SetLength( m_TempVertex, m_MeasurementVectorSize );
00125 }
00126
00127 TKdTree* GetKdTree()
00128 { return m_KdTree.GetPointer() ; }
00129
00131 itkGetConstReferenceMacro( MeasurementVectorSize, MeasurementVectorSizeType );
00132
00133 itkGetConstReferenceMacro( CurrentIteration, int) ;
00134 itkGetConstReferenceMacro( CentroidPositionChanges, double) ;
00135
00140 void StartOptimization() ;
00141
00142 typedef itk::hash_map< InstanceIdentifier, unsigned int > ClusterLabelsType ;
00143
00144 void SetUseClusterLabels(bool flag)
00145 { m_UseClusterLabels = flag ; }
00146
00147 ClusterLabelsType* GetClusterLabels()
00148 { return &m_ClusterLabels ; }
00149
00150 protected:
00151 KdTreeBasedKmeansEstimator() ;
00152 virtual ~KdTreeBasedKmeansEstimator() {}
00153
00154 void PrintSelf(std::ostream& os, Indent indent) const;
00155
00156 void FillClusterLabels(KdTreeNodeType* node, int closestIndex) ;
00157
00159 class CandidateVector
00160 {
00161 public:
00162 CandidateVector() {}
00163
00164 struct Candidate
00165 {
00166 CentroidType Centroid ;
00167 CentroidType WeightedCentroid ;
00168 int Size ;
00169 } ;
00170
00171 virtual ~CandidateVector() {}
00172
00174 int Size() const
00175 { return static_cast<int>( m_Candidates.size() ); }
00176
00179 void SetCentroids(InternalParametersType& centroids)
00180 {
00181 this->m_MeasurementVectorSize = MeasurementVectorTraits::GetLength( centroids[0] );
00182 m_Candidates.resize(centroids.size()) ;
00183 for (unsigned int i = 0 ; i < centroids.size() ; i++)
00184 {
00185 Candidate candidate ;
00186 candidate.Centroid = centroids[i] ;
00187 MeasurementVectorTraits::SetLength( candidate.WeightedCentroid, m_MeasurementVectorSize );
00188 candidate.WeightedCentroid.Fill(0.0) ;
00189 candidate.Size = 0 ;
00190 m_Candidates[i] = candidate ;
00191 }
00192 }
00193
00195 void GetCentroids(InternalParametersType& centroids)
00196 {
00197 unsigned int i ;
00198 centroids.resize(this->Size()) ;
00199 for (i = 0 ; i < (unsigned int)this->Size() ; i++)
00200 {
00201 centroids[i] = m_Candidates[i].Centroid ;
00202 }
00203 }
00204
00207 void UpdateCentroids()
00208 {
00209 unsigned int i, j ;
00210 for (i = 0 ; i < (unsigned int)this->Size() ; i++)
00211 {
00212 if (m_Candidates[i].Size > 0)
00213 {
00214 for (j = 0 ; j < m_MeasurementVectorSize; j++)
00215 {
00216 m_Candidates[i].Centroid[j] =
00217 m_Candidates[i].WeightedCentroid[j] /
00218 double(m_Candidates[i].Size) ;
00219 }
00220 }
00221 }
00222 }
00223
00225 Candidate& operator[](int index)
00226 { return m_Candidates[index] ; }
00227
00228
00229 private:
00231 std::vector< Candidate > m_Candidates ;
00232
00234 MeasurementVectorSizeType m_MeasurementVectorSize;
00235 } ;
00236
00242 double GetSumOfSquaredPositionChanges(InternalParametersType &previous,
00243 InternalParametersType ¤t) ;
00244
00247 int GetClosestCandidate(ParameterType &measurements,
00248 std::vector< int > &validIndexes) ;
00249
00251 bool IsFarther(ParameterType &pointA,
00252 ParameterType &pointB,
00253 MeasurementVectorType &lowerBound,
00254 MeasurementVectorType &upperBound) ;
00255
00258 void Filter(KdTreeNodeType* node,
00259 std::vector< int > validIndexes,
00260 MeasurementVectorType &lowerBound,
00261 MeasurementVectorType &upperBound) ;
00262
00264 void CopyParameters(InternalParametersType &source, InternalParametersType &target) ;
00265
00267 void CopyParameters(ParametersType &source, InternalParametersType &target) ;
00268
00270 void CopyParameters(InternalParametersType &source, ParametersType &target) ;
00271
00273 void GetPoint(ParameterType &point,
00274 MeasurementVectorType measurements)
00275 {
00276 for (unsigned int i = 0 ; i < m_MeasurementVectorSize ; i++)
00277 {
00278 point[i] = measurements[i] ;
00279 }
00280 }
00281
00282 void PrintPoint(ParameterType &point)
00283 {
00284 std::cout << "[ " ;
00285 for (unsigned int i = 0 ; i < m_MeasurementVectorSize ; i++)
00286 {
00287 std::cout << point[i] << " " ;
00288 }
00289 std::cout << "]" ;
00290 }
00291
00292 private:
00294 int m_CurrentIteration ;
00296 int m_MaximumIteration ;
00298 double m_CentroidPositionChanges ;
00301 double m_CentroidPositionChangesThreshold ;
00303 typename TKdTree::Pointer m_KdTree ;
00305 typename EuclideanDistance< ParameterType >::Pointer m_DistanceMetric ;
00306
00308 ParametersType m_Parameters ;
00309
00310 CandidateVector m_CandidateVector ;
00311
00312 ParameterType m_TempVertex ;
00313
00314 bool m_UseClusterLabels ;
00315 bool m_GenerateClusterLabels ;
00316 ClusterLabelsType m_ClusterLabels ;
00317 MeasurementVectorSizeType m_MeasurementVectorSize;
00318 } ;
00319
00320 }
00321 }
00322
00323 #ifndef ITK_MANUAL_INSTANTIATION
00324 #include "itkKdTreeBasedKmeansEstimator.txx"
00325 #endif
00326
00327
00328 #endif