CMS 3D CMS Logo

SprTrainedStdBackprop.cc

Go to the documentation of this file.
00001 //$Id: SprTrainedStdBackprop.cc,v 1.2 2007/09/21 22:32:10 narsky Exp $
00002 
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedStdBackprop.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
00008 
00009 #include <cmath>
00010 #include <iomanip>
00011 #include <cassert>
00012 
00013 using namespace std;
00014 
00015 
00016 SprTrainedStdBackprop::SprTrainedStdBackprop()
00017   : 
00018   SprAbsTrainedClassifier()
00019   , nNodes_(0)
00020   , nLinks_(0)
00021   , structure_()
00022   , nodeType_()
00023   , nodeActFun_()
00024   , nodeNInputLinks_()
00025   , nodeFirstInputLink_()
00026   , linkSource_()
00027   , nodeBias_()
00028   , linkWeight_()
00029 {
00030   this->setCut(SprUtils::lowerBound(0.5));
00031 }
00032 
00033 
00034 SprTrainedStdBackprop::SprTrainedStdBackprop(
00035                         const char* structure,
00036                         const std::vector<SprNNDefs::NodeType>& nodeType,
00037                         const std::vector<SprNNDefs::ActFun>& nodeActFun,
00038                         const std::vector<int>& nodeNInputLinks,
00039                         const std::vector<int>& nodeFirstInputLink,
00040                         const std::vector<int>& linkSource,
00041                         const std::vector<double>& nodeBias,
00042                         const std::vector<double>& linkWeight)
00043   :
00044   SprAbsTrainedClassifier(),
00045   nNodes_(0),
00046   nLinks_(0),
00047   structure_(structure),
00048   nodeType_(nodeType),
00049   nodeActFun_(nodeActFun),
00050   nodeNInputLinks_(nodeNInputLinks),
00051   nodeFirstInputLink_(nodeFirstInputLink),
00052   linkSource_(linkSource),
00053   nodeBias_(nodeBias),
00054   linkWeight_(linkWeight)
00055 {
00056   nNodes_ = nodeType_.size();
00057   assert( nNodes_ == nodeActFun_.size() );
00058   assert( nNodes_ == nodeNInputLinks_.size() );
00059   assert( nNodes_ == nodeFirstInputLink_.size() );
00060   assert( nNodes_ == nodeBias_.size() );
00061   nLinks_ = linkSource_.size();
00062   assert( nLinks_ == linkWeight_.size() );
00063   this->setCut(SprUtils::lowerBound(0.5));
00064 }
00065 
00066 
00067 SprTrainedStdBackprop::SprTrainedStdBackprop(
00068                           const SprTrainedStdBackprop& other)
00069   :
00070   SprAbsTrainedClassifier(other)
00071   , nNodes_(other.nNodes_)
00072   , nLinks_(other.nLinks_)
00073   , structure_(other.structure_)
00074   , nodeType_(other.nodeType_)
00075   , nodeActFun_(other.nodeActFun_)
00076   , nodeNInputLinks_(other.nodeNInputLinks_)
00077   , nodeFirstInputLink_(other.nodeFirstInputLink_)
00078   , linkSource_(other.linkSource_)
00079   , nodeBias_(other.nodeBias_)
00080   , linkWeight_(other.linkWeight_)
00081 {}
00082 
00083 
00084 double SprTrainedStdBackprop::activate(double x, SprNNDefs::ActFun f) const 
00085 {
00086   switch (f) 
00087     {
00088     case SprNNDefs::ID :
00089       return x;
00090       break;
00091     case SprNNDefs::LOGISTIC :
00092       return SprTransformation::logit(x);
00093       break;
00094     default :
00095       cerr << "FATAL ERROR: Unknown activation function " 
00096            << f << " in SprTrainedStdBackprop::activate" << endl;
00097       return 0;
00098     }
00099   return 0;
00100 }
00101 
00102 
00103 void SprTrainedStdBackprop::print(std::ostream& os) const 
00104 {
00105   os << "Trained StdBackprop with configuration " 
00106      << structure_.c_str() << " " << SprVersion << endl; 
00107   os << "Activation functions: Identity=1, Logistic=2" << endl;
00108   os << "Cut: " << cut_.size();
00109   for( int i=0;i<cut_.size();i++ )
00110     os << "      " << cut_[i].first << " " << cut_[i].second;
00111   os << endl;
00112   os << "Nodes: " << nNodes_ << endl;
00113   for( int i=0;i<nNodes_;i++ ) {
00114     char nodeType;
00115     switch( nodeType_[i] )
00116       {
00117       case SprNNDefs::INPUT :
00118         nodeType = 'I';
00119         break;
00120       case SprNNDefs::HIDDEN :
00121         nodeType = 'H';
00122         break;
00123       case SprNNDefs::OUTPUT :
00124         nodeType = 'O';
00125         break;
00126       }
00127     int actFun = 0;
00128     switch( nodeActFun_[i] )
00129       {
00130       case SprNNDefs::ID :
00131         actFun = 1;
00132         break;
00133       case SprNNDefs::LOGISTIC :
00134         actFun = 2;
00135         break;
00136       }
00137     os << setw(6) << i
00138        << "    Type: "           << nodeType
00139        << "    ActFunction: "    << actFun
00140        << "    NInputLinks: "    << setw(6) << nodeNInputLinks_[i]
00141        << "    FirstInputLink: " << setw(6) << nodeFirstInputLink_[i]
00142        << "    Bias: "           << nodeBias_[i]
00143        << endl;
00144   }
00145   os << "Links: " << nLinks_ << endl;
00146   for( int i=0;i<nLinks_;i++ ) {
00147     os << setw(6) << i
00148        << "    Source: " << setw(6) << linkSource_[i]
00149        << "    Weight: " << linkWeight_[i]
00150        << endl;
00151   }
00152 }
00153 
00154 
00155 double SprTrainedStdBackprop::response(const std::vector<double>& v) const
00156 {
00157   // Initialize and process input nodes
00158   vector<double> nodeOut(nNodes_,0);
00159   int d = 0;
00160   for( int i=0;i<nNodes_;i++ ) {
00161     if( nodeType_[i] == SprNNDefs::INPUT ) {
00162       assert( d < v.size() );
00163       nodeOut[i] = v[d++];
00164     }
00165     else
00166       break;
00167   }
00168   assert( d == v.size() );
00169 
00170   // Process hidden and output nodes
00171   for( int i=0;i<nNodes_;i++ ) {
00172     double nodeAct = 0;
00173     if( nodeNInputLinks_[i] > 0 ) {
00174       for( int j=nodeFirstInputLink_[i];
00175            j<nodeFirstInputLink_[i]+nodeNInputLinks_[i];j++ ) {
00176         nodeAct += nodeOut[linkSource_[j]] * linkWeight_[j];
00177       }
00178       nodeOut[i] = this->activate(nodeAct+nodeBias_[i],nodeActFun_[i]);
00179     }
00180   }
00181 
00182   // Find output node and return result
00183   return nodeOut[nNodes_-1];
00184 }

Generated on Tue Jun 9 17:42:04 2009 for CMSSW by  doxygen 1.5.4