00001
00002
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedStdBackprop.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
00008
00009 #include <cmath>
00010 #include <iomanip>
00011 #include <cassert>
00012
00013 using namespace std;
00014
00015
00016 SprTrainedStdBackprop::SprTrainedStdBackprop()
00017 :
00018 SprAbsTrainedClassifier()
00019 , nNodes_(0)
00020 , nLinks_(0)
00021 , structure_()
00022 , nodeType_()
00023 , nodeActFun_()
00024 , nodeNInputLinks_()
00025 , nodeFirstInputLink_()
00026 , linkSource_()
00027 , nodeBias_()
00028 , linkWeight_()
00029 {
00030 this->setCut(SprUtils::lowerBound(0.5));
00031 }
00032
00033
00034 SprTrainedStdBackprop::SprTrainedStdBackprop(
00035 const char* structure,
00036 const std::vector<SprNNDefs::NodeType>& nodeType,
00037 const std::vector<SprNNDefs::ActFun>& nodeActFun,
00038 const std::vector<int>& nodeNInputLinks,
00039 const std::vector<int>& nodeFirstInputLink,
00040 const std::vector<int>& linkSource,
00041 const std::vector<double>& nodeBias,
00042 const std::vector<double>& linkWeight)
00043 :
00044 SprAbsTrainedClassifier(),
00045 nNodes_(0),
00046 nLinks_(0),
00047 structure_(structure),
00048 nodeType_(nodeType),
00049 nodeActFun_(nodeActFun),
00050 nodeNInputLinks_(nodeNInputLinks),
00051 nodeFirstInputLink_(nodeFirstInputLink),
00052 linkSource_(linkSource),
00053 nodeBias_(nodeBias),
00054 linkWeight_(linkWeight)
00055 {
00056 nNodes_ = nodeType_.size();
00057 assert( nNodes_ == nodeActFun_.size() );
00058 assert( nNodes_ == nodeNInputLinks_.size() );
00059 assert( nNodes_ == nodeFirstInputLink_.size() );
00060 assert( nNodes_ == nodeBias_.size() );
00061 nLinks_ = linkSource_.size();
00062 assert( nLinks_ == linkWeight_.size() );
00063 this->setCut(SprUtils::lowerBound(0.5));
00064 }
00065
00066
00067 SprTrainedStdBackprop::SprTrainedStdBackprop(
00068 const SprTrainedStdBackprop& other)
00069 :
00070 SprAbsTrainedClassifier(other)
00071 , nNodes_(other.nNodes_)
00072 , nLinks_(other.nLinks_)
00073 , structure_(other.structure_)
00074 , nodeType_(other.nodeType_)
00075 , nodeActFun_(other.nodeActFun_)
00076 , nodeNInputLinks_(other.nodeNInputLinks_)
00077 , nodeFirstInputLink_(other.nodeFirstInputLink_)
00078 , linkSource_(other.linkSource_)
00079 , nodeBias_(other.nodeBias_)
00080 , linkWeight_(other.linkWeight_)
00081 {}
00082
00083
00084 double SprTrainedStdBackprop::activate(double x, SprNNDefs::ActFun f) const
00085 {
00086 switch (f)
00087 {
00088 case SprNNDefs::ID :
00089 return x;
00090 break;
00091 case SprNNDefs::LOGISTIC :
00092 return SprTransformation::logit(x);
00093 break;
00094 default :
00095 cerr << "FATAL ERROR: Unknown activation function "
00096 << f << " in SprTrainedStdBackprop::activate" << endl;
00097 return 0;
00098 }
00099 return 0;
00100 }
00101
00102
00103 void SprTrainedStdBackprop::print(std::ostream& os) const
00104 {
00105 os << "Trained StdBackprop with configuration "
00106 << structure_.c_str() << " " << SprVersion << endl;
00107 os << "Activation functions: Identity=1, Logistic=2" << endl;
00108 os << "Cut: " << cut_.size();
00109 for( int i=0;i<cut_.size();i++ )
00110 os << " " << cut_[i].first << " " << cut_[i].second;
00111 os << endl;
00112 os << "Nodes: " << nNodes_ << endl;
00113 for( int i=0;i<nNodes_;i++ ) {
00114 char nodeType;
00115 switch( nodeType_[i] )
00116 {
00117 case SprNNDefs::INPUT :
00118 nodeType = 'I';
00119 break;
00120 case SprNNDefs::HIDDEN :
00121 nodeType = 'H';
00122 break;
00123 case SprNNDefs::OUTPUT :
00124 nodeType = 'O';
00125 break;
00126 }
00127 int actFun = 0;
00128 switch( nodeActFun_[i] )
00129 {
00130 case SprNNDefs::ID :
00131 actFun = 1;
00132 break;
00133 case SprNNDefs::LOGISTIC :
00134 actFun = 2;
00135 break;
00136 }
00137 os << setw(6) << i
00138 << " Type: " << nodeType
00139 << " ActFunction: " << actFun
00140 << " NInputLinks: " << setw(6) << nodeNInputLinks_[i]
00141 << " FirstInputLink: " << setw(6) << nodeFirstInputLink_[i]
00142 << " Bias: " << nodeBias_[i]
00143 << endl;
00144 }
00145 os << "Links: " << nLinks_ << endl;
00146 for( int i=0;i<nLinks_;i++ ) {
00147 os << setw(6) << i
00148 << " Source: " << setw(6) << linkSource_[i]
00149 << " Weight: " << linkWeight_[i]
00150 << endl;
00151 }
00152 }
00153
00154
00155 double SprTrainedStdBackprop::response(const std::vector<double>& v) const
00156 {
00157
00158 vector<double> nodeOut(nNodes_,0);
00159 int d = 0;
00160 for( int i=0;i<nNodes_;i++ ) {
00161 if( nodeType_[i] == SprNNDefs::INPUT ) {
00162 assert( d < v.size() );
00163 nodeOut[i] = v[d++];
00164 }
00165 else
00166 break;
00167 }
00168 assert( d == v.size() );
00169
00170
00171 for( int i=0;i<nNodes_;i++ ) {
00172 double nodeAct = 0;
00173 if( nodeNInputLinks_[i] > 0 ) {
00174 for( int j=nodeFirstInputLink_[i];
00175 j<nodeFirstInputLink_[i]+nodeNInputLinks_[i];j++ ) {
00176 nodeAct += nodeOut[linkSource_[j]] * linkWeight_[j];
00177 }
00178 nodeOut[i] = this->activate(nodeAct+nodeBias_[i],nodeActFun_[i]);
00179 }
00180 }
00181
00182
00183 return nodeOut[nNodes_-1];
00184 }