00001 /*========================================================================= 00002 00003 Program: Insight Segmentation & Registration Toolkit 00004 Module: $RCSfile: itkMultilayerNeuralNetworkBase.h,v $ 00005 Language: C++ 00006 Date: $Date: 2005/11/21 18:39:31 $ 00007 Version: $Revision: 1.2 $ 00008 00009 Copyright (c) Insight Software Consortium. All rights reserved. 00010 See ITKCopyright.txt or http://www.itk.org/HTML/Copyright.htm for details. 00011 00012 This software is distributed WITHOUT ANY WARRANTY; without even 00013 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR 00014 PURPOSE. See the above copyright notices for more information. 00015 00016 =========================================================================*/ 00017 #ifndef __MultiLayerNeuralNetworkBase_h 00018 #define __MultiLayerNeuralNetworkBase_h 00019 00020 #include "itkNeuralNetworkObject.h" 00021 #include "itkErrorBackPropagationLearningFunctionBase.h" 00022 #include "itkErrorBackPropagationLearningWithMomentum.h" 00023 #include "itkQuickPropLearningRule.h" 00024 00025 namespace itk 00026 { 00027 namespace Statistics 00028 { 00029 00030 template<class TVector, class TOutput> 00031 class MultilayerNeuralNetworkBase : public NeuralNetworkObject<TVector, TOutput> 00032 { 00033 public: 00034 00035 typedef MultilayerNeuralNetworkBase Self; 00036 typedef NeuralNetworkObject<TVector, TOutput> Superclass; 00037 typedef SmartPointer<Self> Pointer; 00038 typedef SmartPointer<const Self> ConstPointer; 00039 itkTypeMacro(MultilayerNeuralNetworkBase, NeuralNetworkObject); 00040 00042 itkNewMacro( Self ); 00043 00044 typedef typename Superclass::ValueType ValueType; 00045 00046 typedef typename Superclass::LayerType LayerType; 00047 typedef typename Superclass::WeightSetType WeightSetType; 00048 typedef typename Superclass::WeightSetPointer WeightSetPointer; 00049 typedef typename Superclass::LayerPointer LayerPointer; 00050 typedef typename Superclass::LearningFunctionType LearningFunctionType; 00051 typedef typename Superclass::LearningFunctionPointer LearningFunctionPointer; 00052 00053 typedef std::vector<WeightSetPointer> WeightVectorType; 00054 typedef std::vector<LayerPointer> LayerVectorType; 00055 00056 itkSetMacro(NumOfLayers, int); 00057 itkGetConstReferenceMacro(NumOfLayers, int); 00058 00059 void AddLayer(LayerType*); 00060 00061 void AddWeightSet(WeightSetType*); 00062 00063 void SetLearningFunction(LearningFunctionType* f); 00064 00065 virtual ValueType* GenerateOutput(TVector samplevector); 00066 00067 virtual void BackwardPropagate(TOutput errors); 00068 00069 virtual void UpdateWeights(ValueType); 00070 00071 void SetLearningRule(LearningFunctionType*); 00072 00073 void SetLearningRate(ValueType learningrate); 00074 00075 void InitializeWeights(); 00076 00077 protected: 00078 MultilayerNeuralNetworkBase(); 00079 ~MultilayerNeuralNetworkBase(); 00080 00081 LayerVectorType m_Layers; 00082 WeightVectorType m_Weights; 00083 LearningFunctionPointer m_LearningFunction; 00084 ValueType m_LearningRate; 00085 int m_NumOfLayers; 00086 00088 virtual void PrintSelf( std::ostream& os, Indent indent ) const; 00089 }; 00090 00091 } // end namespace Statistics 00092 } // end namespace itk 00093 00094 #ifndef ITK_MANUAL_INSTANTIATION 00095 #include "itkMultilayerNeuralNetworkBase.txx" 00096 #endif 00097 00098 #endif