00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef __itkLayerBase_h
00018 #define __itkLayerBase_h
00019
00020 #include <iostream>
00021 #include "itkLightProcessObject.h"
00022 #include "itkWeightSetBase.h"
00023 #include "itkArray.h"
00024 #include "itkVector.h"
00025 #include "itkTransferFunctionBase.h"
00026 #include "itkInputFunctionBase.h"
00027
00028 #include "itkMacro.h"
00029
00030 namespace itk
00031 {
00032 namespace Statistics
00033 {
00034
00035 template<class TVector, class TOutput>
00036 class LayerBase : public LightProcessObject
00037 {
00038
00039 public:
00040 typedef LayerBase Self;
00041 typedef LightProcessObject Superclass;
00042 typedef SmartPointer<Self> Pointer;
00043 typedef SmartPointer<const Self> ConstPointer;
00044
00046 itkTypeMacro(LayerBase, LightProcessObject);
00047
00048 typedef TVector InputVectorType;
00049 typedef TOutput OutputVectorType;
00050
00051 typedef typename TVector::ValueType ValueType;
00052 typedef ValueType* ValuePointer;
00053 typedef vnl_vector<ValueType> NodeVectorType;
00054
00055 typedef WeightSetBase<TVector,TOutput> WeightSetType;
00056
00057 typedef TransferFunctionBase<ValueType> TransferFunctionType;
00058
00059 typedef InputFunctionBase<ValueType*, ValueType> InputFunctionType;
00060
00061 typedef typename InputFunctionType::Pointer InputFunctionPointer;
00062
00063 typedef typename TransferFunctionType::Pointer TransferFunctionPointer;
00064
00065 virtual void SetNumberOfNodes(unsigned int);
00066 unsigned int GetNumberOfNodes();
00067
00068 virtual ValueType GetInputValue(unsigned int) = 0;
00069 virtual ValueType GetOutputValue(int) = 0;
00070 virtual ValuePointer GetOutputVector() = 0;
00071
00072 virtual void ForwardPropagate(){};
00073
00074 virtual void ForwardPropagate(TVector){};
00075
00076 virtual void BackwardPropagate(TOutput){};
00077
00078 virtual void BackwardPropagate(){};
00079 virtual ValueType GetOutputErrorValue(unsigned int) = 0;
00080 virtual void SetOutputErrorValues(TOutput) {};
00081
00082 virtual ValueType GetInputErrorValue(int) = 0;
00083 virtual ValuePointer GetInputErrorVector() = 0;
00084 virtual void SetInputErrorValue(ValueType, int) {};
00085
00086 itkSetObjectMacro(InputWeightSet, WeightSetType);
00087 itkGetObjectMacro(InputWeightSet, WeightSetType);
00088
00089 itkSetObjectMacro(OutputWeightSet, WeightSetType);
00090 itkGetObjectMacro(OutputWeightSet, WeightSetType);
00091
00092 void SetNodeInputFunction(InputFunctionType* f);
00093 itkGetObjectMacro(NodeInputFunction, InputFunctionType);
00094
00095 void SetTransferFunction(TransferFunctionType* f);
00096 itkGetObjectMacro(ActivationFunction, TransferFunctionType);
00097
00098 virtual ValueType Activation(ValueType) = 0;
00099 virtual ValueType DActivation(ValueType) = 0;
00100
00101 itkSetMacro(LayerType, unsigned int);
00102 itkGetMacro(LayerType, unsigned int);
00103
00104 virtual void SetBias(ValueType) = 0;
00105 virtual ValueType GetBias() = 0;
00106
00107 protected:
00108
00109 LayerBase();
00110 ~LayerBase();
00111
00113 virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00114
00115 unsigned int m_LayerType;
00116 unsigned int m_NumberOfNodes;
00117
00118 typename WeightSetType::Pointer m_InputWeightSet;
00119 typename WeightSetType::Pointer m_OutputWeightSet;
00120
00121 TransferFunctionPointer m_ActivationFunction;
00122 InputFunctionPointer m_NodeInputFunction;
00123
00124 };
00125
00126 }
00127 }
00128
00129 #ifndef ITK_MANUAL_INSTANTIATION
00130 #include "itkLayerBase.txx"
00131 #endif
00132
00133 #endif