00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef __itkKdTree_h
00018 #define __itkKdTree_h
00019
00020 #include <queue>
00021 #include <vector>
00022
00023 #include "itkMacro.h"
00024 #include "itkPoint.h"
00025 #include "itkSize.h"
00026 #include "itkObject.h"
00027 #include "itkNumericTraits.h"
00028 #include "itkArray.h"
00029
00030 #include "itkSample.h"
00031 #include "itkSubsample.h"
00032
00033 #include "itkEuclideanDistance.h"
00034
00035 namespace itk{
00036 namespace Statistics{
00037
00062 template< class TSample >
00063 struct KdTreeNode
00064 {
00066 typedef KdTreeNode< TSample> Self ;
00067
00069 typedef typename TSample::MeasurementType MeasurementType ;
00070
00072 typedef Array< double > CentroidType;
00073
00076 typedef typename TSample::InstanceIdentifier InstanceIdentifier ;
00077
00080 virtual bool IsTerminal() const = 0 ;
00081
00087 virtual void GetParameters(unsigned int &partitionDimension,
00088 MeasurementType &partitionValue) const = 0 ;
00089
00091 virtual Self* Left() = 0 ;
00092 virtual const Self* Left() const = 0 ;
00093
00095 virtual Self* Right() = 0 ;
00096 virtual const Self* Right() const = 0 ;
00097
00100 virtual unsigned int Size() const = 0 ;
00101
00103 virtual void GetWeightedCentroid(CentroidType ¢roid) = 0 ;
00104
00106 virtual void GetCentroid(CentroidType ¢roid) = 0 ;
00107
00109 virtual InstanceIdentifier GetInstanceIdentifier(size_t index) const = 0 ;
00110
00112 virtual void AddInstanceIdentifier(InstanceIdentifier id) = 0 ;
00113
00115 virtual ~KdTreeNode() {};
00116 } ;
00117
00129 template< class TSample >
00130 struct KdTreeNonterminalNode: public KdTreeNode< TSample >
00131 {
00132 typedef KdTreeNode< TSample > Superclass ;
00133 typedef typename Superclass::MeasurementType MeasurementType ;
00134 typedef typename Superclass::CentroidType CentroidType ;
00135 typedef typename Superclass::InstanceIdentifier InstanceIdentifier ;
00136
00137 KdTreeNonterminalNode(unsigned int partitionDimension,
00138 MeasurementType partitionValue,
00139 Superclass* left,
00140 Superclass* right) ;
00141
00142 virtual ~KdTreeNonterminalNode() {}
00143
00144 virtual bool IsTerminal() const
00145 { return false ; }
00146
00147 void GetParameters(unsigned int &partitionDimension,
00148 MeasurementType &partitionValue) const;
00149
00150 Superclass* Left()
00151 { return m_Left ; }
00152
00153 Superclass* Right()
00154 { return m_Right ; }
00155
00156 const Superclass* Left() const
00157 { return m_Left ; }
00158
00159 const Superclass* Right() const
00160 { return m_Right ; }
00161
00162 unsigned int Size() const
00163 { return 0 ; }
00164
00165 void GetWeightedCentroid(CentroidType &)
00166 { }
00167
00168 void GetCentroid(CentroidType &)
00169 { }
00170
00171 InstanceIdentifier GetInstanceIdentifier(size_t) const
00172 { return 0 ; }
00173
00174 void AddInstanceIdentifier(InstanceIdentifier) {}
00175
00176 private:
00177 unsigned int m_PartitionDimension ;
00178 MeasurementType m_PartitionValue ;
00179 Superclass* m_Left ;
00180 Superclass* m_Right ;
00181 } ;
00182
00197 template< class TSample >
00198 struct KdTreeWeightedCentroidNonterminalNode: public KdTreeNode< TSample >
00199 {
00200 typedef KdTreeNode< TSample > Superclass ;
00201 typedef typename Superclass::MeasurementType MeasurementType ;
00202 typedef typename Superclass::CentroidType CentroidType ;
00203 typedef typename Superclass::InstanceIdentifier InstanceIdentifier ;
00204 typedef typename TSample::MeasurementVectorSizeType MeasurementVectorSizeType;
00205
00206 KdTreeWeightedCentroidNonterminalNode(unsigned int partitionDimension,
00207 MeasurementType partitionValue,
00208 Superclass* left,
00209 Superclass* right,
00210 CentroidType ¢roid,
00211 unsigned int size) ;
00212 virtual ~KdTreeWeightedCentroidNonterminalNode() {}
00213
00214 virtual bool IsTerminal() const
00215 { return false ; }
00216
00217 void GetParameters(unsigned int &partitionDimension,
00218 MeasurementType &partitionValue) const ;
00219
00221 MeasurementVectorSizeType GetMeasurementVectorSize() const
00222 {
00223 return m_MeasurementVectorSize;
00224 }
00225
00226 Superclass* Left()
00227 { return m_Left ; }
00228
00229 Superclass* Right()
00230 { return m_Right ; }
00231
00232
00233 const Superclass* Left() const
00234 { return m_Left ; }
00235
00236 const Superclass* Right() const
00237 { return m_Right ; }
00238
00239 unsigned int Size() const
00240 { return m_Size ; }
00241
00242 void GetWeightedCentroid(CentroidType ¢roid)
00243 { centroid = m_WeightedCentroid ; }
00244
00245 void GetCentroid(CentroidType ¢roid)
00246 { centroid = m_Centroid ; }
00247
00248 InstanceIdentifier GetInstanceIdentifier(size_t) const
00249 { return 0 ; }
00250
00251 void AddInstanceIdentifier(InstanceIdentifier) {}
00252
00253 private:
00254 MeasurementVectorSizeType m_MeasurementVectorSize;
00255 unsigned int m_PartitionDimension ;
00256 MeasurementType m_PartitionValue ;
00257 CentroidType m_WeightedCentroid ;
00258 CentroidType m_Centroid ;
00259 unsigned int m_Size ;
00260 Superclass* m_Left ;
00261 Superclass* m_Right ;
00262 } ;
00263
00264
00276 template< class TSample >
00277 struct KdTreeTerminalNode: public KdTreeNode< TSample >
00278 {
00279 typedef KdTreeNode< TSample > Superclass ;
00280 typedef typename Superclass::MeasurementType MeasurementType ;
00281 typedef typename Superclass::CentroidType CentroidType ;
00282 typedef typename Superclass::InstanceIdentifier InstanceIdentifier ;
00283
00284 KdTreeTerminalNode() {}
00285
00286 virtual ~KdTreeTerminalNode() {}
00287
00288 bool IsTerminal() const
00289 { return true ; }
00290
00291 void GetParameters(unsigned int &,
00292 MeasurementType &) const {}
00293
00294 Superclass* Left()
00295 { return 0 ; }
00296
00297 Superclass* Right()
00298 { return 0 ; }
00299
00300
00301 const Superclass* Left() const
00302 { return 0 ; }
00303
00304 const Superclass* Right() const
00305 { return 0 ; }
00306
00307 unsigned int Size() const
00308 { return static_cast<unsigned int>( m_InstanceIdentifiers.size() ); }
00309
00310 void GetWeightedCentroid(CentroidType &)
00311 { }
00312
00313 void GetCentroid(CentroidType &)
00314 { }
00315
00316 InstanceIdentifier GetInstanceIdentifier(size_t index) const
00317 { return m_InstanceIdentifiers[index] ; }
00318
00319 void AddInstanceIdentifier(InstanceIdentifier id)
00320 { m_InstanceIdentifiers.push_back(id) ;}
00321
00322 private:
00323 std::vector< InstanceIdentifier > m_InstanceIdentifiers ;
00324 } ;
00325
00358 template < class TSample >
00359 class ITK_EXPORT KdTree : public Object
00360 {
00361 public:
00363 typedef KdTree Self ;
00364 typedef Object Superclass ;
00365 typedef SmartPointer<Self> Pointer;
00366 typedef SmartPointer<const Self> ConstPointer;
00367
00369 itkTypeMacro(KdTree, Object);
00370
00372 itkNewMacro(Self) ;
00373
00375 typedef TSample SampleType ;
00376 typedef typename TSample::MeasurementVectorType MeasurementVectorType ;
00377 typedef typename TSample::MeasurementType MeasurementType ;
00378 typedef typename TSample::InstanceIdentifier InstanceIdentifier ;
00379 typedef typename TSample::FrequencyType FrequencyType ;
00380
00381 typedef unsigned int MeasurementVectorSizeType;
00382
00385 itkGetConstMacro( MeasurementVectorSize, MeasurementVectorSizeType );
00386
00388 typedef EuclideanDistance< MeasurementVectorType > DistanceMetricType ;
00389
00391 typedef KdTreeNode< TSample > KdTreeNodeType ;
00392
00396 typedef std::pair< InstanceIdentifier, double > NeighborType ;
00397
00398 typedef std::vector< InstanceIdentifier > InstanceIdentifierVectorType ;
00399
00409 class NearestNeighbors
00410 {
00411 public:
00413 NearestNeighbors() {}
00414
00416 ~NearestNeighbors() {}
00417
00420 void resize(unsigned int k)
00421 {
00422 m_Identifiers.clear() ;
00423 m_Identifiers.resize(k, NumericTraits< unsigned long >::max()) ;
00424 m_Distances.clear() ;
00425 m_Distances.resize(k, NumericTraits< double >::max()) ;
00426 m_FarthestNeighborIndex = 0 ;
00427 }
00428
00430 double GetLargestDistance()
00431 { return m_Distances[m_FarthestNeighborIndex] ; }
00432
00435 void ReplaceFarthestNeighbor(InstanceIdentifier id, double distance)
00436 {
00437 m_Identifiers[m_FarthestNeighborIndex] = id ;
00438 m_Distances[m_FarthestNeighborIndex] = distance ;
00439 double farthestDistance = NumericTraits< double >::min() ;
00440 const unsigned int size = static_cast<unsigned int>( m_Distances.size() );
00441 for ( unsigned int i = 0 ; i < size; i++ )
00442 {
00443 if ( m_Distances[i] > farthestDistance )
00444 {
00445 farthestDistance = m_Distances[i] ;
00446 m_FarthestNeighborIndex = i ;
00447 }
00448 }
00449 }
00450
00452 InstanceIdentifierVectorType GetNeighbors()
00453 { return m_Identifiers ; }
00454
00457 InstanceIdentifier GetNeighbor(unsigned int index)
00458 { return m_Identifiers[index] ; }
00459
00461 std::vector< double >& GetDistances()
00462 { return m_Distances ; }
00463
00464 private:
00466 unsigned int m_FarthestNeighborIndex ;
00467
00469 InstanceIdentifierVectorType m_Identifiers ;
00470
00473 std::vector< double > m_Distances ;
00474 } ;
00475
00478 void SetBucketSize(unsigned int size) ;
00479
00482 void SetSample(const TSample* sample) ;
00483
00485 const TSample* GetSample() const
00486 { return m_Sample ; }
00487
00488 unsigned long Size() const
00489 { return m_Sample->Size() ; }
00490
00495 KdTreeNodeType* GetEmptyTerminalNode()
00496 { return m_EmptyTerminalNode ; }
00497
00500 void SetRoot(KdTreeNodeType* root)
00501 { m_Root = root ; }
00502
00504 KdTreeNodeType* GetRoot()
00505 { return m_Root ; }
00506
00509 const MeasurementVectorType & GetMeasurementVector(InstanceIdentifier id) const
00510 { return m_Sample->GetMeasurementVector(id) ; }
00511
00514 FrequencyType GetFrequency(InstanceIdentifier id) const
00515 { return m_Sample->GetFrequency( id ) ; }
00516
00518 DistanceMetricType* GetDistanceMetric()
00519 { return m_DistanceMetric.GetPointer() ; }
00520
00522 void Search(MeasurementVectorType &query,
00523 unsigned int k,
00524 InstanceIdentifierVectorType& result) const;
00525
00527 void Search(MeasurementVectorType &query,
00528 double radius,
00529 InstanceIdentifierVectorType& result) const;
00530
00533 int GetNumberOfVisits() const
00534 { return m_NumberOfVisits ; }
00535
00541 bool BallWithinBounds(MeasurementVectorType &query,
00542 MeasurementVectorType &lowerBound,
00543 MeasurementVectorType &upperBound,
00544 double radius) const ;
00545
00549 bool BoundsOverlapBall(MeasurementVectorType &query,
00550 MeasurementVectorType &lowerBound,
00551 MeasurementVectorType &upperBound,
00552 double radius) const ;
00553
00555 void DeleteNode(KdTreeNodeType *node) ;
00556
00558 void PrintTree(KdTreeNodeType *node, int level,
00559 unsigned int activeDimension) ;
00560
00561 typedef typename TSample::Iterator Iterator ;
00562 typedef typename TSample::ConstIterator ConstIterator ;
00563
00564 Iterator Begin()
00565 {
00566 typename TSample::ConstIterator iter = m_Sample->Begin() ;
00567 return iter;
00568 }
00569
00570 Iterator End()
00571 {
00572 Iterator iter = m_Sample->End() ;
00573 return iter;
00574 }
00575
00576 ConstIterator Begin() const
00577 {
00578 typename TSample::ConstIterator iter = m_Sample->Begin() ;
00579 return iter;
00580 }
00581
00582 ConstIterator End() const
00583 {
00584 ConstIterator iter = m_Sample->End() ;
00585 return iter;
00586 }
00587
00588
00589 protected:
00591 KdTree() ;
00592
00594 virtual ~KdTree() ;
00595
00596 void PrintSelf(std::ostream& os, Indent indent) const ;
00597
00599 int NearestNeighborSearchLoop(const KdTreeNodeType* node,
00600 MeasurementVectorType &query,
00601 MeasurementVectorType &lowerBound,
00602 MeasurementVectorType &upperBound) const;
00603
00605 int SearchLoop(const KdTreeNodeType* node, MeasurementVectorType &query,
00606 MeasurementVectorType &lowerBound,
00607 MeasurementVectorType &upperBound) const ;
00608 private:
00609 KdTree(const Self&) ;
00610 void operator=(const Self&) ;
00611
00613 const TSample* m_Sample ;
00614
00616 int m_BucketSize ;
00617
00619 KdTreeNodeType* m_Root ;
00620
00622 KdTreeNodeType* m_EmptyTerminalNode ;
00623
00625 typename DistanceMetricType::Pointer m_DistanceMetric ;
00626
00627 mutable bool m_IsNearestNeighborSearch ;
00628
00629 mutable double m_SearchRadius ;
00630
00631 mutable InstanceIdentifierVectorType m_Neighbors ;
00632
00634 mutable NearestNeighbors m_NearestNeighbors ;
00635
00637 mutable MeasurementVectorType m_LowerBound ;
00638
00640 mutable MeasurementVectorType m_UpperBound ;
00641
00643 mutable int m_NumberOfVisits ;
00644
00646 mutable bool m_StopSearch ;
00647
00649 mutable NeighborType m_TempNeighbor ;
00650
00652 MeasurementVectorSizeType m_MeasurementVectorSize;
00653 } ;
00654
00655 }
00656 }
00657
00658 #ifndef ITK_MANUAL_INSTANTIATION
00659 #include "itkKdTree.txx"
00660 #endif
00661
00662 #endif
00663
00664
00665