00001
00002
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedRBF.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
00006
00007 #include <map>
00008 #include <utility>
00009 #include <cmath>
00010 #include <fstream>
00011 #include <sstream>
00012 #include <cassert>
00013
00014 using namespace std;
00015
00016
00017 bool SprTrainedRBF::readNet(const char* netfile)
00018 {
00019
00020 string fname = netfile;
00021 ifstream file(fname.c_str());
00022 if( !file ) {
00023 cerr << "Unable to open file " << fname.c_str() << endl;
00024 return false;
00025 }
00026
00027
00028 string line;
00029 unsigned nline = 1;
00030 int nempty = 5;
00031 for( int i=0;i<nempty;i++ ) {
00032 if( !getline(file,line) ) {
00033 cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00034 return false;
00035 }
00036 nline++;
00037 }
00038
00039
00040 unsigned nnodes(0), nlinks(0);
00041 if( !getline(file,line) ) {
00042 cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00043 return false;
00044 }
00045 else {
00046 nline++;
00047 line.erase( 0, line.find_first_of(':')+1 );
00048 istringstream ist(line);
00049 ist >> nnodes;
00050 if( nnodes == 0 ) {
00051 cerr << "No nodes found in " << fname.c_str() << endl;
00052 return false;
00053 }
00054 }
00055 if( !getline(file,line) ) {
00056 cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00057 return false;
00058 }
00059 else {
00060 nline++;
00061 line.erase( 0, line.find_first_of(':')+1 );
00062 istringstream ist(line);
00063 ist >> nlinks;
00064 if( nlinks == 0 ) {
00065 cerr << "No links found in " << fname.c_str() << endl;
00066 return false;
00067 }
00068 }
00069
00070
00071 nempty = 4;
00072 for( int i=0;i<nempty;i++ ) {
00073 if( !getline(file,line) ) {
00074 cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00075 return false;
00076 }
00077 nline++;
00078 }
00079
00080
00081 if( !getline(file,line) ) {
00082 cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00083 return false;
00084 }
00085 else {
00086 nline++;
00087 line.erase( 0, line.find_first_of(':')+1 );
00088 line.erase( line.find_last_not_of(' ')+1 );
00089 line.erase( 0, line.find_first_not_of(' ') );
00090 if( line != "RadialBasisLearning" ) {
00091 cerr << "Learning function is not RadialBasisLearning!!!" << endl;
00092 return false;
00093 }
00094 }
00095 if( !getline(file,line) ) {
00096 cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00097 return false;
00098 }
00099 else {
00100 nline++;
00101 line.erase( 0, line.find_first_of(':')+1 );
00102 line.erase( line.find_last_not_of(' ')+1 );
00103 line.erase( 0, line.find_first_not_of(' ') );
00104 if( line != "Topological_Order" ) {
00105 cerr << "Update function is not Topological_Order!!!" << endl;
00106 return false;
00107 }
00108 }
00109
00110
00111 nempty = 6;
00112 for( int i=0;i<nempty;i++ ) {
00113 if( !getline(file,line) ) {
00114 cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00115 return false;
00116 }
00117 nline++;
00118 }
00119
00120
00121 SprNNDefs::ActFun baseAct;
00122 SprNNDefs::OutFun baseOut;
00123 if( !getline(file,line) ) {
00124 cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00125 return false;
00126 }
00127 else {
00128 nline++;
00129 for( int i=0;i<5;i++ )
00130 line.erase( 0, line.find_first_of('|')+1 );
00131 string act = line.substr( 0, line.find_first_of('|') );
00132 string out = line.substr( line.find_first_of('|')+1 );
00133 act.erase( 0, act.find_first_not_of(' ') );
00134 act.erase( act.find_last_not_of(' ')+1 );
00135 out.erase( 0, out.find_first_not_of(' ') );
00136 out.erase( out.find_last_not_of(' ')+1 );
00137 if( act == "Act_Logistic" )
00138 baseAct = SprNNDefs::LOGISTIC;
00139 else if( act=="Act_Identity" || act=="Act_IdentityPlusBias" )
00140 baseAct = SprNNDefs::ID;
00141 else {
00142 cerr << "Unknown activation function " << act.c_str()
00143 << " in " << fname.c_str() << endl;
00144 return false;
00145 }
00146 if( out == "Out_Identity" )
00147 baseOut = SprNNDefs::OUTID;
00148 else {
00149 cerr << "Unknown output function " << out.c_str()
00150 << " in " << fname.c_str() << endl;
00151 return false;
00152 }
00153 }
00154
00155
00156 nempty = 7;
00157 for( int i=0;i<nempty;i++ ) {
00158 if( !getline(file,line) ) {
00159 cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00160 return false;
00161 }
00162 nline++;
00163 }
00164
00165
00166 for( int i=0;i<nnodes;i++ ) {
00167
00168 if( !getline(file,line) ) {
00169 cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00170 return false;
00171 }
00172 else {
00173 nline++;
00174
00175 string piece = line.substr( 0, line.find_first_of('|') );
00176 line.erase( 0, line.find_first_of('|')+1 );
00177 istringstream stindex(piece);
00178 unsigned index(0);
00179 stindex >> index;
00180 assert( index != 0 );
00181
00182 for( int j=0;j<2;j++ )
00183 line.erase( 0, line.find_first_of('|')+1 );
00184 piece = line.substr( 0, line.find_first_of('|') );
00185 line.erase( 0, line.find_first_of('|')+1 );
00186 istringstream stact(piece);
00187 double act = 0;
00188 stact >> act;
00189
00190 piece = line.substr( 0, line.find_first_of('|') );
00191 line.erase( 0, line.find_first_of('|')+1 );
00192 istringstream stbias(piece);
00193 double bias = 0;
00194 stbias >> bias;
00195
00196 piece = line.substr( 0, line.find_first_of('|') );
00197 piece.erase( 0, piece.find_first_not_of(' ') );
00198 piece.erase( piece.find_last_not_of(' ')+1 );
00199 line.erase( 0, line.find_first_of('|')+1 );
00200 SprNNDefs::NodeType type;
00201 if( piece == "i" )
00202 type = SprNNDefs::INPUT;
00203 else if( piece == "h" )
00204 type = SprNNDefs::HIDDEN;
00205 else if( piece == "o" )
00206 type = SprNNDefs::OUTPUT;
00207 else {
00208 cerr << "Unknown node type " << piece.c_str()
00209 << " in " << fname.c_str() << endl;
00210 return false;
00211 }
00212
00213 line.erase( 0, line.find_first_of('|')+1 );
00214 piece = line.substr( 0, line.find_first_of('|') );
00215 piece.erase( 0, piece.find_first_not_of(' ') );
00216 piece.erase( piece.find_last_not_of(' ')+1 );
00217 line.erase( 0, line.find_first_of('|')+1 );
00218 SprNNDefs::ActFun actfun = baseAct;
00219 ActRBF actrbf = Gauss;
00220 if( type == SprNNDefs::HIDDEN ) {
00221 if( piece == "Act_RBF_Gaussian" )
00222 actrbf = Gauss;
00223 else if( piece == "Act_RBF_MultiQuadratic" )
00224 actrbf = MultiQ;
00225 else if( piece == "Act_RBF_ThinPlateSpline" )
00226 actrbf = ThinPlate;
00227 else {
00228 cerr << "Unknown RBF activation function " << piece.c_str()
00229 << " in " << fname.c_str() << endl;
00230 return false;
00231 }
00232 }
00233 else {
00234 if( !piece.empty() ) {
00235 if( piece == "Act_Logistic" )
00236 actfun = SprNNDefs::LOGISTIC;
00237 else if( piece=="Act_Identity" || piece=="Act_IdentityPlusBias" )
00238 actfun = SprNNDefs::ID;
00239 else {
00240 cerr << "Unknown activation function " << piece.c_str()
00241 << " in " << fname.c_str() << endl;
00242 return false;
00243 }
00244 }
00245 }
00246
00247 piece = line.substr( 0, line.find_first_of('|') );
00248 piece.erase( 0, piece.find_first_not_of(' ') );
00249 piece.erase( piece.find_last_not_of(' ')+1 );
00250 line.erase( 0, line.find_first_of('|')+1 );
00251 SprNNDefs::OutFun outfun = baseOut;
00252 if( !piece.empty() ) {
00253 if( piece == "Out_Identity" )
00254 outfun = SprNNDefs::OUTID;
00255 else {
00256 cerr << "Unknown output function " << piece.c_str()
00257 << " in " << fname.c_str() << endl;
00258 return false;
00259 }
00260 }
00261
00262 Node* node = new Node();
00263 node->index_ = index;
00264 node->type_ = type;
00265 node->actFun_ = actfun;
00266 node->outFun_ = outfun;
00267 node->actRBF_ = actrbf;
00268 node->bias_ = bias;
00269 node->act_ = act;
00270 nodes_.push_back(node);
00271 assert( index == nodes_.size() );
00272 }
00273 }
00274
00275
00276 nempty = 7;
00277 for( int i=0;i<nempty;i++ ) {
00278 if( !getline(file,line) ) {
00279 cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00280 return false;
00281 }
00282 nline++;
00283 }
00284
00285
00286 unsigned readlinks = 0;
00287 while( readlinks < nlinks ) {
00288
00289 if( !getline(file,line) ) {
00290 cerr << "Error on line " << nline << " in " << fname.c_str() << endl;
00291 return false;
00292 }
00293 else {
00294 nline++;
00295
00296 string piece = line.substr( 0, line.find_first_of('|') );
00297 line.erase( 0, line.find_first_of('|')+1 );
00298 istringstream stindex(piece);
00299 unsigned index(0);
00300 stindex >> index;
00301 assert( index != 0 );
00302
00303 line.erase( 0, line.find_first_of('|')+1 );
00304 piece = line.substr( 0, line.find_first_of('|') );
00305 piece.erase( piece.find_last_not_of(' ')+1 );
00306 while( !piece.empty() ) {
00307 string srcwt;
00308 if( piece.find(',') != string::npos ) {
00309 srcwt = piece.substr( 0, piece.find_first_of(',') );
00310 piece.erase( 0, piece.find_first_of(',')+1 );
00311 piece.erase( piece.find_last_not_of(' ')+1 );
00312 if( piece.empty() ) {
00313 if( !getline(file,piece) ) {
00314 cerr << "Unable to read line " << nline
00315 << " in " << fname.c_str() << endl;
00316 return false;
00317 }
00318 else
00319 nline++;
00320 }
00321 }
00322 else {
00323 srcwt = piece;
00324 piece.clear();
00325 }
00326 srcwt.erase( 0, srcwt.find_first_not_of(' ') );
00327 srcwt.erase( srcwt.find_last_not_of(' ')+1 );
00328 unsigned src = atoi(srcwt.substr(0,srcwt.find_first_of(':')).c_str());
00329 double wt = atof(srcwt.substr(srcwt.find_first_of(':')+1).c_str());
00330 assert( src != 0 );
00331
00332 Link* link = new Link();
00333 link->weight_ = wt;
00334 Node* target = nodes_[index-1];
00335 Node* source = nodes_[src-1];
00336 target->incoming_.push_back(link);
00337 source->outgoing_.push_back(link);
00338 link->source_ = source;
00339 link->target_ = target;
00340 links_.push_back(link);
00341 readlinks++;
00342 }
00343 }
00344 }
00345
00346
00347 return true;
00348 }
00349
00350
00351 void SprTrainedRBF::printNet(std::ostream& os) const
00352 {
00353 os << "Nodes of RBF network:" << endl;
00354 for( int i=0;i<nodes_.size();i++ ) {
00355 const Node* node = nodes_[i];
00356 os << node->index_
00357 << " Type " << int(node->type_)
00358 << " ActFun " << int(node->actFun_)
00359 << " ActRBF " << int(node->actRBF_)
00360 << " OutFun " << int(node->outFun_)
00361 << " activation " << node->act_
00362 << " bias " << node->bias_
00363 << endl;
00364 }
00365 os << "Links of RBF network:" << endl;
00366 for( int i=0;i<links_.size();i++ ) {
00367 const Link* link = links_[i];
00368 os << " Source " << link->source_->index_
00369 << " Target " << link->target_->index_
00370 << " weight " << link->weight_
00371 << endl;
00372 }
00373 }
00374
00375
00376 double SprTrainedRBF::response(const std::vector<double>& v) const
00377 {
00378
00379 map<unsigned,double> hidden;
00380 for( int i=0;i<nodes_.size();i++ ) {
00381 const Node* node = nodes_[i];
00382 if( node->type_ == SprNNDefs::HIDDEN ) {
00383
00384
00385
00386
00387 assert( node->incoming_.size() == v.size() );
00388 double r2 = 0;
00389 for( int j=0;j<node->incoming_.size();j++ ) {
00390 const Link* link = node->incoming_[j];
00391 assert( link->source_->type_ == SprNNDefs::INPUT );
00392 double x_t = v[link->source_->index_-1] - link->weight_;
00393 r2 += x_t * x_t;
00394 }
00395 hidden.insert(pair<const unsigned,
00396 double>(node->index_,
00397 this->rbf(r2,node->bias_,node->actRBF_)));
00398 }
00399 }
00400
00401
00402
00403 vector<double> output;
00404 for( int i=0;i<nodes_.size();i++ ) {
00405 const Node* node = nodes_[i];
00406 if( node->type_ == SprNNDefs::OUTPUT ) {
00407 output.push_back(0);
00408 int imax = output.size()-1;
00409 for( int j=0;j<node->incoming_.size();j++ ) {
00410 const Link* link = node->incoming_[j];
00411 if( link->source_->type_ == SprNNDefs::INPUT )
00412 output[imax] += v[link->source_->index_-1] * (link->weight_);
00413 else if( link->source_->type_ == SprNNDefs::HIDDEN )
00414 output[imax] += hidden[link->source_->index_] * (link->weight_);
00415 }
00416 output[imax] = this->act(output[imax],node->bias_,node->actFun_);
00417 }
00418 }
00419 assert( !output.empty() );
00420
00421
00422 return output[0];
00423 }
00424
00425
00426
00427 void SprTrainedRBF::destroy()
00428 {
00429 for( int i=0;i<nodes_.size();i++ )
00430 delete nodes_[i];
00431 for( int i=0;i<links_.size();i++ )
00432 delete links_[i];
00433 }
00434
00435
00436 void SprTrainedRBF::correspondence(const SprTrainedRBF& other)
00437 {
00438
00439 map<Link*,Link*> ltol;
00440 for( int i=0;i<other.links_.size();i++ ) {
00441 Link* old = other.links_[i];
00442 Link* link = new Link(*old);
00443 links_.push_back(link);
00444 ltol.insert(pair<Link* const,Link*>(old,link));
00445 }
00446
00447
00448 map<Node*,Node*> nton;
00449 for( int i=0;i<other.nodes_.size();i++ ) {
00450 Node* old = other.nodes_[i];
00451 Node* node = new Node(*old);
00452 nodes_.push_back(node);
00453 nton.insert(pair<Node* const,Node*>(old,node));
00454 }
00455
00456
00457 for( int i=0;i<nodes_.size();i++ ) {
00458 Node* node = nodes_[i];
00459 for( int j=0;j<(node->incoming_).size();j++ )
00460 (node->incoming_)[j] = ltol[(node->incoming_)[j]];
00461 for( int j=0;j<(node->outgoing_).size();j++ )
00462 (node->outgoing_)[j] = ltol[(node->outgoing_)[j]];
00463 }
00464
00465
00466 for( int i=0;i<links_.size();i++ ) {
00467 Link* link = links_[i];
00468 link->source_ = nton[link->source_];
00469 link->target_ = nton[link->target_];
00470 }
00471 }
00472
00473
00474 double SprTrainedRBF::rbf(double r2, double p, ActRBF act) const
00475 {
00476 switch( act )
00477 {
00478 case Gauss :
00479 return exp(-r2*p);
00480 break;
00481 case MultiQ :
00482 return ( (r2+p)>0 ? sqrt(r2+p) : 0 );
00483 break;
00484 case ThinPlate :
00485 return ( (r2>0&&p>0) ? p*p*r2*log(p*sqrt(r2)) : 0 );
00486 break;
00487 }
00488 return 0;
00489 }
00490
00491
00492 double SprTrainedRBF::act(double x, double p, SprNNDefs::ActFun act) const
00493 {
00494 switch( act )
00495 {
00496 case SprNNDefs::ID :
00497 return (x+p);
00498 break;
00499 case SprNNDefs::LOGISTIC :
00500 return SprTransformation::logit(x+p);
00501 break;
00502 }
00503 return 0;
00504 }