00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018 #ifndef __itkTrainingFunctionBase_h
00019 #define __itkTrainingFunctionBase_h
00020
00021 #include <iostream>
00022 #include "itkLightProcessObject.h"
00023 #include "itkNeuralNetworkObject.h"
00024 #include "itkSquaredDifferenceErrorFunction.h"
00025 #include "itkMeanSquaredErrorFunction.h"
00026 namespace itk
00027 {
00028 namespace Statistics
00029 {
00030
00031 template<class TSample, class TOutput, class ScalarType>
00032 class TrainingFunctionBase : public LightProcessObject
00033 {
00034 public:
00035 typedef TrainingFunctionBase Self;
00036 typedef LightProcessObject Superclass;
00037 typedef SmartPointer<Self> Pointer;
00038 typedef SmartPointer<const Self> ConstPointer;
00039
00041 itkTypeMacro(TrainingFunctionBase, LightProcessObject);
00042
00044 itkNewMacro(Self);
00045
00046 typedef typename TSample::MeasurementVectorType VectorType;
00047 typedef typename TOutput::MeasurementVectorType OutputVectorType;
00048
00049 typedef std::vector<VectorType> InputSampleVectorType;
00050 typedef std::vector<OutputVectorType> OutputSampleVectorType;
00051 typedef ScalarType ValueType;
00052 typedef NeuralNetworkObject<VectorType, OutputVectorType> NetworkType;
00053 typedef ErrorFunctionBase<OutputVectorType, ScalarType> PerformanceFunctionType;
00054 typedef SquaredDifferenceErrorFunction<OutputVectorType, ScalarType> DefaultPerformanceType;
00055
00056 void SetTrainingSamples(TSample* samples);
00057 void SetTargetValues(TOutput* targets);
00058 void SetLearningRate(ValueType);
00059
00060 ValueType GetLearningRate();
00061
00062 itkSetMacro(Iterations, long);
00063 itkGetConstReferenceMacro(Iterations, long);
00064
00065 void SetPerformanceFunction(PerformanceFunctionType* f);
00066
00067 virtual void
00068 Train(NetworkType* itkNotUsed(net), TSample* itkNotUsed(samples), TOutput* itkNotUsed(targets))
00069 {
00070
00071 };
00072
00073 inline VectorType
00074 defaultconverter(typename TSample::MeasurementVectorType v)
00075 {
00076 VectorType temp;
00077 for (unsigned int i = 0; i < v.Size(); i++)
00078 {
00079 temp[i] = static_cast<ScalarType>(v[i]) ;
00080 }
00081 return temp;
00082 }
00083
00084 inline OutputVectorType
00085 targetconverter(typename TOutput::MeasurementVectorType v)
00086 {
00087 OutputVectorType temp;
00088 for (unsigned int i = 0; i < v.Size(); i++)
00089 {
00090 temp[i] = static_cast<ScalarType>(v[i]) ;
00091 }
00092 return temp;
00093 }
00094
00095 protected:
00096
00097 TrainingFunctionBase();
00098 ~TrainingFunctionBase(){};
00099
00101 virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00102
00103 TSample* m_TrainingSamples;
00104 TOutput* m_SampleTargets;
00105 InputSampleVectorType m_InputSamples;
00106 OutputSampleVectorType m_Targets;
00107 long m_Iterations;
00108 ValueType m_LearningRate;
00109 typename PerformanceFunctionType::Pointer m_PerformanceFunction;
00110 };
00111
00112 }
00113 }
00114 #ifndef ITK_MANUAL_INSTANTIATION
00115 #include "itkTrainingFunctionBase.txx"
00116 #endif
00117
00118 #endif