00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef __itkRBFLayerBase_h
00018 #define __itkRBFLayerBase_h
00019
#include <vector>

#include "itkLayerBase.h"
#include "itkObject.h"
#include "itkMacro.h"
#include "itkRadialBasisFunctionBase.h"
#include "itkEuclideanDistance.h"
00025
00026 namespace itk
00027 {
00028 namespace Statistics
00029 {
00030
00031 template<class TVector, class TOutput>
00032 class RBFLayer : public LayerBase<TVector, TOutput>
00033 {
00034 public:
00035
00036 typedef RBFLayer Self;
00037 typedef LayerBase<TVector, TOutput> Superclass;
00038 typedef SmartPointer<Self> Pointer;
00039 typedef SmartPointer<const Self> ConstPointer;
00040
00042 itkTypeMacro(RBFLayer, LayerBase);
00043 itkNewMacro(Self) ;
00044
00045 typedef typename Superclass::ValueType ValueType;
00046 typedef typename Superclass::ValuePointer ValuePointer;
00047 typedef vnl_vector<ValueType> NodeVectorType;
00048 typedef Array<ValueType> NodeArrayType;
00049 typedef typename Superclass::OutputVectorType OutputVectorType;
00050
00051 typedef RadialBasisFunctionBase<ValueType> RBFType;
00052
00053
00054 typedef EuclideanDistance<TVector> DistanceMetricType;
00055 typedef typename DistanceMetricType::Pointer DistanceMetricPointer;
00056
00057 void SetNumberOfNodes(unsigned int);
00058 ValueType GetInputValue(unsigned int i);
00059 void SetInputValue(unsigned int i, ValueType value);
00060
00061 ValueType GetOutputValue(int);
00062 void SetOutputValue(int, ValueType);
00063
00064 ValuePointer GetOutputVector();
00065 void SetOutputVector(TVector value);
00066
00067 void ForwardPropagate();
00068 void ForwardPropagate(TVector input);
00069
00070 void BackwardPropagate();
00071 void BackwardPropagate(TOutput itkNotUsed(errors)){};
00072
00073 void SetOutputErrorValues(TOutput);
00074 ValueType GetOutputErrorValue(unsigned int node_id);
00075
00076
00077 ValueType GetInputErrorValue(int node_id);
00078 ValuePointer GetInputErrorVector();
00079 void SetInputErrorValue(ValueType, int node_id);
00080
00081 TVector GetCenter(int i);
00082 void SetCenter(TVector c,int i);
00083
00084 ValueType GetRadii(int i);
00085 void SetRadii(ValueType c,int i);
00086
00087
00088 ValueType Activation(ValueType);
00089 ValueType DActivation(ValueType);
00090
00091 void SetBias(ValueType b);
00092
00093 ValueType GetBias();
00094
00095 void SetDistanceMetric(DistanceMetricType* f);
00096 DistanceMetricPointer GetDistanceMetric(){return m_DistanceMetric;}
00097
00098 itkSetMacro(NumClasses, int);
00099 itkGetConstReferenceMacro(NumClasses,int);
00100
00101 void SetRBF(RBFType* f);
00102 itkGetObjectMacro(RBF, RBFType);
00103
00104 protected:
00105
00106 RBFLayer();
00107 ~RBFLayer();
00108
00110 virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00111
00112 private:
00113
00114 typename DistanceMetricType::Pointer m_DistanceMetric;
00115 NodeVectorType m_NodeInputValues;
00116 NodeVectorType m_NodeOutputValues;
00117 NodeVectorType m_InputErrorValues;
00118 NodeVectorType m_OutputErrorValues;
00119 std::vector<TVector> m_Centers;
00120 NodeArrayType m_Radii;
00121 int m_NumClasses;
00122 ValueType m_Bias;
00123 int m_RBF_Dim;
00124 typename RBFType::Pointer m_RBF;
00125 };
00126
00127 }
00128 }
00129
00130 #ifndef ITK_MANUAL_INSTANTIATION
00131 #include "itkRBFLayer.txx"
00132 #endif
00133
00134 #endif