00001 /*========================================================================= 00002 00003 Program: Insight Segmentation & Registration Toolkit 00004 Module: $RCSfile: itkWeightSetBase.h,v $ 00005 Language: C++ 00006 Date: $Date: 2005/08/10 20:36:01 $ 00007 Version: $Revision: 1.2 $ 00008 00009 Copyright (c) Insight Software Consortium. All rights reserved. 00010 See ITKCopyright.txt or http://www.itk.org/HTML/Copyright.htm for details. 00011 00012 This software is distributed WITHOUT ANY WARRANTY; without even 00013 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR 00014 PURPOSE. See the above copyright notices for more information. 00015 00016 =========================================================================*/ 00017 00018 #ifndef __itkWeightSetBase_h 00019 #define __itkWeightSetBase_h 00020 00021 #include "itkLayerBase.h" 00022 #include "itkLightProcessObject.h" 00023 #include <vnl/vnl_matrix.h> 00024 #include <vnl/vnl_diag_matrix.h> 00025 #include "itkMacro.h" 00026 #include "itkVector.h" 00027 #include "itkMersenneTwisterRandomVariateGenerator.h" 00028 #include <math.h> 00029 #include <stdlib.h> 00030 00031 namespace itk 00032 { 00033 namespace Statistics 00034 { 00035 00036 template<class TVector, class TOutput> 00037 class WeightSetBase : public LightProcessObject 00038 { 00039 public: 00040 00041 typedef WeightSetBase Self; 00042 typedef LightProcessObject Superclass; 00043 typedef SmartPointer<Self> Pointer; 00044 typedef SmartPointer<const Self> ConstPointer; 00045 itkTypeMacro(WeightSetBase, LightProcessObject); 00046 00047 typedef MersenneTwisterRandomVariateGenerator RandomVariateGeneratorType; 00048 00049 typedef typename TVector::ValueType ValueType; 00050 typedef ValueType* ValuePointer; 00051 00052 void Initialize(); 00053 00054 ValueType RandomWeightValue(ValueType low, ValueType high); 00055 00056 void ForwardPropagate(ValuePointer inputlayeroutputvalues); 00057 00058 void BackwardPropagate(ValuePointer inputerror); 00059 00060 void SetConnectivityMatrix(vnl_matrix < int>); 00061 00062 void SetNumberOfInputNodes(int n); 00063 int GetNumberOfInputNodes(); 00064 00065 void SetNumberOfOutputNodes(int n); 00066 int GetNumberOfOutputNodes(); 00067 00068 void SetRange(ValueType Range); 00069 00070 ValuePointer GetOutputValues(); 00071 00072 ValuePointer GetInputValues(); 00073 00074 ValuePointer GetTotalDeltaValues(); 00075 00076 ValuePointer GetTotalDeltaBValues(); 00077 00078 ValuePointer GetDeltaValues(); 00079 00080 void SetDeltaValues(ValuePointer); 00081 00082 void SetDWValues(ValuePointer); 00083 00084 void SetDBValues(ValuePointer); 00085 00086 ValuePointer GetDeltaBValues(); 00087 00088 void SetDeltaBValues(ValuePointer); 00089 00090 ValuePointer GetDWValues(); 00091 00092 ValuePointer GetPrevDWValues(); 00093 00094 ValuePointer GetPrevDBValues(); 00095 00096 ValuePointer GetPrev_m_2DWValues(); 00097 00098 ValuePointer GetPrevDeltaValues(); 00099 00100 ValuePointer GetPrev_m_2DeltaValues(); 00101 00102 ValuePointer GetPrevDeltaBValues(); 00103 00104 ValuePointer GetWeightValues(); 00105 00106 void SetWeightValues(ValuePointer weights); 00107 00108 void UpdateWeights(ValueType LearningRate); 00109 00110 void SetMomentum(ValueType); 00111 00112 ValueType GetMomentum(); 00113 00114 void SetBias(ValueType); 00115 00116 ValueType GetBias(); 00117 00118 bool GetFirstPass(); 00119 00120 void SetFirstPass(bool); 00121 00122 bool GetSecondPass(); 00123 00124 void SetSecondPass(bool); 00125 00126 void InitializeWeights(); 00127 00128 protected: 00129 00130 WeightSetBase(); 00131 ~WeightSetBase(); 00132 00134 virtual void PrintSelf( std::ostream& os, Indent indent ) const; 00135 00136 typename RandomVariateGeneratorType::Pointer m_RandomGenerator; 00137 int m_NumberOfInputNodes; 00138 int m_NumberOfOutputNodes; 00139 vnl_matrix<ValueType> m_OutputValues; 00140 vnl_matrix<ValueType> m_InputErrorValues; 00141 00142 // weight updates dw=lr * del *y 00143 // DW= current 00144 // DW_m_1 = previous 00145 // DW_m_2= second to last 00146 // same applies for delta and bias values 00147 00148 vnl_matrix<ValueType> m_DW; // delta valies for weight update 00149 vnl_matrix<ValueType> m_DW_new; // delta valies for weight update 00150 vnl_matrix<ValueType> m_DW_m_1; // delta valies for weight update 00151 vnl_matrix<ValueType> m_DW_m_2; // delta valies for weight update 00152 vnl_matrix<ValueType> m_DW_m; // delta valies for weight update 00153 00154 vnl_vector<ValueType> m_DB; // delta values for bias update 00155 vnl_vector<ValueType> m_DB_new; // delta values for bias update 00156 vnl_vector<ValueType> m_DB_m_1; // delta values for bias update 00157 vnl_vector<ValueType> m_DB_m_2; // delta values for bias update 00158 00159 vnl_matrix<ValueType> m_del; // dw=lr * del * y 00160 vnl_matrix<ValueType> m_del_new; // dw=lr * del * y 00161 vnl_matrix<ValueType> m_del_m_1; // dw=lr * del * y 00162 vnl_matrix<ValueType> m_del_m_2; // dw=lr * del * y 00163 00164 vnl_vector<ValueType> m_delb; // delta values for bias update 00165 vnl_vector<ValueType> m_delb_new; // delta values for bias update 00166 vnl_vector<ValueType> m_delb_m_1; // delta values for bias update 00167 vnl_vector<ValueType> m_delb_m_2; // delta values for bias update 00168 00169 vnl_matrix<ValueType> m_InputLayerOutput; 00170 vnl_matrix<ValueType> m_WeightMatrix; // composed of weights and a column of biases 00171 vnl_matrix<int> m_ConnectivityMatrix; 00172 00173 ValueType m_Momentum; 00174 ValueType m_Bias; 00175 bool m_FirstPass; 00176 bool m_SecondPass; 00177 ValueType m_Range; 00178 }; //class 00179 00180 } // end namespace Statistics 00181 } // end namespace itk 00182 00183 #ifndef ITK_MANUAL_INSTANTIATION 00184 #include "itkWeightSetBase.txx" 00185 #endif 00186 00187 00188 #endif