CMS 3D CMS Logo

SprTrainedRBF.cc

Go to the documentation of this file.
00001 //$Id: SprTrainedRBF.cc,v 1.2 2007/09/21 22:32:10 narsky Exp $
00002 
#include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedRBF.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"

#include <cassert>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
00013 
00014 using namespace std;
00015 
00016 
00017 bool SprTrainedRBF::readNet(const char* netfile)
00018 {
00019   // open file
00020   string fname = netfile;
00021   ifstream file(fname.c_str());
00022   if( !file ) {
00023     cerr << "Unable to open file " << fname.c_str() << endl;
00024     return false;
00025   }
00026  
00027   // read junk on top of the file
00028   string line;
00029   unsigned nline = 1;
00030   int nempty = 5;
00031   for( int i=0;i<nempty;i++ ) {
00032     if( !getline(file,line) ) {
00033       cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00034       return false;
00035     }
00036     nline++;
00037   }
00038 
00039   // read number of nodes and links
00040   unsigned nnodes(0), nlinks(0);
00041   if( !getline(file,line) ) {
00042     cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00043     return false;
00044   }
00045   else {
00046     nline++;
00047     line.erase( 0, line.find_first_of(':')+1 );
00048     istringstream ist(line);
00049     ist >> nnodes;
00050     if( nnodes == 0 ) {
00051       cerr << "No nodes found in " << fname.c_str() << endl;
00052       return false;
00053     }
00054   }
00055   if( !getline(file,line) ) {
00056     cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00057     return false;
00058   }
00059   else {
00060     nline++;
00061     line.erase( 0, line.find_first_of(':')+1 );
00062     istringstream ist(line);
00063     ist >> nlinks;
00064     if( nlinks == 0 ) {
00065       cerr << "No links found in " << fname.c_str() << endl;
00066       return false;
00067     }
00068   }
00069 
00070   // more empty lines
00071   nempty = 4;
00072   for( int i=0;i<nempty;i++ ) {
00073     if( !getline(file,line) ) {
00074       cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00075       return false;
00076     }
00077     nline++;
00078   }
00079 
00080   // read learning and update function
00081   if( !getline(file,line) ) {
00082     cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00083     return false;
00084   }
00085   else {
00086     nline++;
00087     line.erase( 0, line.find_first_of(':')+1 );
00088     line.erase( line.find_last_not_of(' ')+1 );
00089     line.erase( 0, line.find_first_not_of(' ') );
00090     if( line != "RadialBasisLearning" ) {
00091       cerr << "Learning function is not RadialBasisLearning!!!" << endl;
00092       return false;
00093     }
00094   }
00095   if( !getline(file,line) ) {
00096     cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00097     return false;
00098   }
00099   else {
00100     nline++;
00101     line.erase( 0, line.find_first_of(':')+1 );
00102     line.erase( line.find_last_not_of(' ')+1 );
00103     line.erase( 0, line.find_first_not_of(' ') );
00104     if( line != "Topological_Order" ) {
00105       cerr << "Update function is not Topological_Order!!!" << endl;
00106       return false;
00107     }
00108   }
00109 
00110   // more empty lines
00111   nempty = 6;
00112   for( int i=0;i<nempty;i++ ) {
00113     if( !getline(file,line) ) {
00114       cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00115       return false;
00116     }
00117     nline++;
00118   }
00119 
00120   // read activation and output functions (6th and 7th fields)
00121   SprNNDefs::ActFun baseAct;
00122   SprNNDefs::OutFun baseOut;
00123   if( !getline(file,line) ) {
00124     cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00125     return false;
00126   }
00127   else {
00128     nline++;
00129     for( int i=0;i<5;i++ )
00130       line.erase( 0, line.find_first_of('|')+1 );
00131     string act = line.substr( 0, line.find_first_of('|') );
00132     string out = line.substr( line.find_first_of('|')+1 );
00133     act.erase( 0, act.find_first_not_of(' ') );
00134     act.erase( act.find_last_not_of(' ')+1 );
00135     out.erase( 0, out.find_first_not_of(' ') );
00136     out.erase( out.find_last_not_of(' ')+1 );
00137     if(      act == "Act_Logistic" )
00138       baseAct = SprNNDefs::LOGISTIC;
00139     else if( act=="Act_Identity" || act=="Act_IdentityPlusBias" )
00140       baseAct = SprNNDefs::ID;
00141     else {
00142       cerr << "Unknown activation function " << act.c_str() 
00143            << " in " << fname.c_str() << endl;
00144       return false;
00145     }
00146     if( out == "Out_Identity" )
00147       baseOut = SprNNDefs::OUTID;
00148     else {
00149       cerr << "Unknown output function " << out.c_str() 
00150            << " in " << fname.c_str() << endl;
00151       return false;
00152     }
00153   }
00154 
00155   // more empty lines
00156   nempty = 7;
00157   for( int i=0;i<nempty;i++ ) {
00158     if( !getline(file,line) ) {
00159       cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00160       return false;
00161     }
00162     nline++;
00163   }
00164 
00165   // read units section
00166   for( int i=0;i<nnodes;i++ ) {
00167     //    cout << "Reading node " << (i+1) << endl;
00168     if( !getline(file,line) ) {
00169       cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00170       return false;
00171     }
00172     else {
00173       nline++;
00174       // read index
00175       string piece = line.substr( 0, line.find_first_of('|') );
00176       line.erase( 0, line.find_first_of('|')+1 );
00177       istringstream stindex(piece);
00178       unsigned index(0);
00179       stindex >> index;
00180       assert( index != 0 );
00181       // read activation
00182       for( int j=0;j<2;j++ )
00183         line.erase( 0, line.find_first_of('|')+1 );
00184       piece = line.substr( 0, line.find_first_of('|') );
00185       line.erase( 0, line.find_first_of('|')+1 );
00186       istringstream stact(piece);
00187       double act = 0;
00188       stact >> act;
00189       // read bias
00190       piece = line.substr( 0, line.find_first_of('|') );
00191       line.erase( 0, line.find_first_of('|')+1 );
00192       istringstream stbias(piece);
00193       double bias = 0;
00194       stbias >> bias;
00195       // read node type
00196       piece = line.substr( 0, line.find_first_of('|') );
00197       piece.erase( 0, piece.find_first_not_of(' ') );
00198       piece.erase( piece.find_last_not_of(' ')+1 );
00199       line.erase( 0, line.find_first_of('|')+1 );
00200       SprNNDefs::NodeType type;
00201       if(      piece == "i" )
00202         type = SprNNDefs::INPUT;
00203       else if( piece == "h" )
00204         type = SprNNDefs::HIDDEN;
00205       else if( piece == "o" )
00206         type = SprNNDefs::OUTPUT;
00207       else {
00208         cerr << "Unknown node type " << piece.c_str() 
00209              << " in " << fname.c_str() << endl;
00210         return false;
00211       }
00212       // read activation function
00213       line.erase( 0, line.find_first_of('|')+1 );
00214       piece = line.substr( 0, line.find_first_of('|') );
00215       piece.erase( 0, piece.find_first_not_of(' ') );
00216       piece.erase( piece.find_last_not_of(' ')+1 );
00217       line.erase( 0, line.find_first_of('|')+1 );
00218       SprNNDefs::ActFun actfun = baseAct;
00219       ActRBF actrbf = Gauss;
00220       if( type == SprNNDefs::HIDDEN ) {
00221         if(      piece == "Act_RBF_Gaussian" )
00222           actrbf = Gauss;
00223         else if( piece == "Act_RBF_MultiQuadratic" )
00224           actrbf = MultiQ;
00225         else if( piece == "Act_RBF_ThinPlateSpline" )
00226           actrbf = ThinPlate;
00227         else {
00228           cerr << "Unknown RBF activation function " << piece.c_str() 
00229                << " in " << fname.c_str() << endl;
00230           return false;
00231         }
00232       }
00233       else {// not a hidden node
00234         if( !piece.empty() ) {
00235           if(      piece == "Act_Logistic" )
00236             actfun = SprNNDefs::LOGISTIC;
00237           else if( piece=="Act_Identity" || piece=="Act_IdentityPlusBias" )
00238             actfun = SprNNDefs::ID;
00239           else {
00240             cerr << "Unknown activation function " << piece.c_str() 
00241                  << " in " << fname.c_str() << endl;
00242             return false;
00243           }
00244         }
00245       }
00246       // read output function
00247       piece = line.substr( 0, line.find_first_of('|') );
00248       piece.erase( 0, piece.find_first_not_of(' ') );
00249       piece.erase( piece.find_last_not_of(' ')+1 );
00250       line.erase( 0, line.find_first_of('|')+1 );
00251       SprNNDefs::OutFun outfun = baseOut;
00252       if( !piece.empty() ) {
00253         if( piece == "Out_Identity" )
00254           outfun = SprNNDefs::OUTID;
00255         else {
00256           cerr << "Unknown output function " << piece.c_str() 
00257                << " in " << fname.c_str() << endl;
00258           return false;
00259         }
00260       }
00261       // make a node
00262       Node* node = new Node();
00263       node->index_ = index;
00264       node->type_ = type;
00265       node->actFun_ = actfun;
00266       node->outFun_ = outfun;
00267       node->actRBF_ = actrbf;
00268       node->bias_ = bias;
00269       node->act_ = act;
00270       nodes_.push_back(node);
00271       assert( index == nodes_.size() ); 
00272     }
00273   }
00274 
00275   // more empty lines
00276   nempty = 7;
00277   for( int i=0;i<nempty;i++ ) {
00278     if( !getline(file,line) ) {
00279       cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00280       return false;
00281     }
00282     nline++;
00283   }
00284 
00285   // read links
00286   unsigned readlinks = 0;
00287   while( readlinks < nlinks ) {
00288     //    cout << "Reading link " << (readlinks+1) << endl;
00289     if( !getline(file,line) ) {
00290       cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00291       return false;
00292     }
00293     else {
00294       nline++;
00295       // read index
00296       string piece = line.substr( 0, line.find_first_of('|') );
00297       line.erase( 0, line.find_first_of('|')+1 );
00298       istringstream stindex(piece);
00299       unsigned index(0);
00300       stindex >> index;
00301       assert( index != 0 );
00302       // read sources and weights
00303       line.erase( 0, line.find_first_of('|')+1 );
00304       piece = line.substr( 0, line.find_first_of('|') );
00305       piece.erase( piece.find_last_not_of(' ')+1 );
00306       while( !piece.empty() ) {
00307         string srcwt;
00308         if( piece.find(',') != string::npos ) {
00309           srcwt = piece.substr( 0, piece.find_first_of(',') );
00310           piece.erase( 0, piece.find_first_of(',')+1 );
00311           piece.erase( piece.find_last_not_of(' ')+1 );
00312           if( piece.empty() ) {
00313             if( !getline(file,piece) ) {
00314               cerr << "Unable to read line " << nline 
00315                    << " in " << fname.c_str() << endl;
00316               return false;
00317             }
00318             else
00319               nline++;
00320           }
00321         }
00322         else {
00323           srcwt = piece;
00324           piece.clear();
00325         }
00326         srcwt.erase( 0, srcwt.find_first_not_of(' ') );
00327         srcwt.erase( srcwt.find_last_not_of(' ')+1 );
00328         unsigned src = atoi(srcwt.substr(0,srcwt.find_first_of(':')).c_str());
00329         double wt = atof(srcwt.substr(srcwt.find_first_of(':')+1).c_str());
00330         assert( src != 0 );
00331         // insert link
00332         Link* link = new Link();
00333         link->weight_ = wt;
00334         Node* target = nodes_[index-1];
00335         Node* source = nodes_[src-1];
00336         target->incoming_.push_back(link);
00337         source->outgoing_.push_back(link);
00338         link->source_ = source;
00339         link->target_ = target;
00340         links_.push_back(link);
00341         readlinks++;
00342       }
00343     }
00344   }
00345 
00346   // success
00347   return true;
00348 }
00349 
00350 
00351 void SprTrainedRBF::printNet(std::ostream& os) const
00352 {
00353   os << "Nodes of RBF network:" << endl;
00354   for( int i=0;i<nodes_.size();i++ ) {
00355     const Node* node = nodes_[i];
00356     os << node->index_ 
00357        << " Type " << int(node->type_)
00358        << " ActFun " << int(node->actFun_)
00359        << " ActRBF " << int(node->actRBF_)
00360        << " OutFun " << int(node->outFun_)
00361        << " activation " << node->act_
00362        << " bias " << node->bias_
00363        << endl;
00364   }
00365   os << "Links of RBF network:" << endl;
00366   for( int i=0;i<links_.size();i++ ) {
00367     const Link* link = links_[i];
00368     os << " Source " << link->source_->index_
00369        << " Target " << link->target_->index_
00370        << " weight " << link->weight_
00371        << endl;
00372   }
00373 }
00374 
00375 
00376 double SprTrainedRBF::response(const std::vector<double>& v) const
00377 {
00378   // loop over hidden nodes and compute RBF values
00379   map<unsigned,double> hidden;// RBF values at hidden nodes
00380   for( int i=0;i<nodes_.size();i++ ) {
00381     const Node* node = nodes_[i];
00382     if( node->type_ == SprNNDefs::HIDDEN ) {
00383       /*
00384       cout << node->index_ << endl;
00385       cout << node->incoming_.size() << " " << v.size() << endl;
00386       */
00387       assert( node->incoming_.size() == v.size() );
00388       double r2 = 0;// r squared
00389       for( int j=0;j<node->incoming_.size();j++ ) {
00390         const Link* link = node->incoming_[j];
00391         assert( link->source_->type_ == SprNNDefs::INPUT );
00392         double x_t = v[link->source_->index_-1] - link->weight_;
00393         r2 += x_t * x_t;
00394       }
00395       hidden.insert(pair<const unsigned,
00396                     double>(node->index_,
00397                             this->rbf(r2,node->bias_,node->actRBF_)));
00398     }
00399   }
00400 
00401   // loop over output nodes and sum linear contributions from input nodes
00402   // and RBF contributions from hidden nodes
00403   vector<double> output;
00404   for( int i=0;i<nodes_.size();i++ ) {
00405     const Node* node = nodes_[i];
00406     if( node->type_ == SprNNDefs::OUTPUT ) {
00407       output.push_back(0);
00408       int imax = output.size()-1;
00409       for( int j=0;j<node->incoming_.size();j++ ) {
00410         const Link* link = node->incoming_[j];
00411         if(      link->source_->type_ == SprNNDefs::INPUT )
00412           output[imax] += v[link->source_->index_-1] * (link->weight_);
00413         else if( link->source_->type_ == SprNNDefs::HIDDEN )
00414           output[imax] += hidden[link->source_->index_] * (link->weight_);
00415       }
00416       output[imax] = this->act(output[imax],node->bias_,node->actFun_);
00417     }
00418   }
00419   assert( !output.empty() );
00420 
00421   // return value of the first output node
00422   return output[0];
00423 }
00424 
00425 
00426 
00427 void SprTrainedRBF::destroy()
00428 {
00429   for( int i=0;i<nodes_.size();i++ )
00430     delete nodes_[i];
00431   for( int i=0;i<links_.size();i++ )
00432     delete links_[i];
00433 }
00434 
00435 
00436 void SprTrainedRBF::correspondence(const SprTrainedRBF& other)
00437 {
00438   // links
00439   map<Link*,Link*> ltol;
00440   for( int i=0;i<other.links_.size();i++ ) {
00441     Link* old = other.links_[i];
00442     Link* link = new Link(*old);
00443     links_.push_back(link);
00444     ltol.insert(pair<Link* const,Link*>(old,link));
00445   }
00446 
00447   // nodes
00448   map<Node*,Node*> nton;
00449   for( int i=0;i<other.nodes_.size();i++ ) {
00450     Node* old = other.nodes_[i];
00451     Node* node = new Node(*old);
00452     nodes_.push_back(node);
00453     nton.insert(pair<Node* const,Node*>(old,node));
00454   }
00455 
00456   // adjust links
00457   for( int i=0;i<nodes_.size();i++ ) {
00458     Node* node = nodes_[i];
00459     for( int j=0;j<(node->incoming_).size();j++ )
00460       (node->incoming_)[j] = ltol[(node->incoming_)[j]];
00461     for( int j=0;j<(node->outgoing_).size();j++ )
00462       (node->outgoing_)[j] = ltol[(node->outgoing_)[j]];
00463   }
00464 
00465   // adjust nodes
00466   for( int i=0;i<links_.size();i++ ) {
00467     Link* link = links_[i];
00468     link->source_ = nton[link->source_];
00469     link->target_ = nton[link->target_];
00470   }
00471 }
00472 
00473 
00474 double SprTrainedRBF::rbf(double r2, double p, ActRBF act) const
00475 {
00476   switch( act )
00477     {
00478     case Gauss :
00479       return exp(-r2*p);
00480       break;
00481     case MultiQ :
00482       return ( (r2+p)>0 ? sqrt(r2+p) : 0 );
00483       break;
00484     case ThinPlate :
00485       return ( (r2>0&&p>0) ? p*p*r2*log(p*sqrt(r2)) : 0 );
00486       break;
00487     }
00488   return 0;
00489 }
00490 
00491 
00492 double SprTrainedRBF::act(double x, double p, SprNNDefs::ActFun act) const
00493 {
00494   switch( act )
00495     {
00496     case SprNNDefs::ID :
00497       return (x+p);
00498       break;
00499     case SprNNDefs::LOGISTIC :
00500       return SprTransformation::logit(x+p);
00501       break;
00502     }
00503   return 0;
00504 }

Generated on Tue Jun 9 17:42:04 2009 for CMSSW by  doxygen 1.5.4