CMS 3D CMS Logo

SprStdBackprop.cc

Go to the documentation of this file.
00001 //$Id: SprStdBackprop.cc,v 1.2 2007/09/21 22:32:10 narsky Exp $
00002 
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprStdBackprop.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
00011 #include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
00012 
00013 #include <stdio.h>
00014 #include <cmath>
00015 #include <iomanip>
00016 #include <sstream>
00017 #include <utility>
00018 #include <cassert>
00019 
00020 using namespace std;
00021 
00022 
00023 SprStdBackprop::~SprStdBackprop()
00024 {
00025   if( ownLoss_ ) {
00026     delete loss_;
00027     loss_ = 0;
00028     ownLoss_ = false;
00029   }  
00030 }
00031 
00032 SprStdBackprop::SprStdBackprop(SprAbsFilter* data)
00033   :
00034   SprAbsClassifier(data),
00035   structure_(),
00036   cls0_(0),
00037   cls1_(1),
00038   cycles_(0),
00039   eta_(0.1),
00040   configured_(false),
00041   initialized_(false),
00042   initEta_(0.1),
00043   initPoints_(data->size()),
00044   rndm_(),
00045   permu_(data->size()),
00046   allowPermu_(true),
00047   nNodes_(0),
00048   nLinks_(0),
00049   nodeType_(),
00050   nodeActFun_(),
00051   nodeAct_(),
00052   nodeOut_(),
00053   nodeNInputLinks_(),
00054   nodeFirstInputLink_(),
00055   linkSource_(),
00056   nodeBias_(),
00057   linkWeight_(),
00058   cut_(SprUtils::lowerBound(0.5)),
00059   valData_(0),
00060   valPrint_(0),
00061   loss_(0),
00062   ownLoss_(false),
00063   initialDataWeights_()
00064 {
00065   this->setClasses();
00066 }
00067 
00068 
00069 SprStdBackprop::SprStdBackprop(SprAbsFilter* data, 
00070                                unsigned cycles,
00071                                double eta)
00072   :
00073   SprAbsClassifier(data),
00074   structure_(),
00075   cls0_(0),
00076   cls1_(1),
00077   cycles_(cycles),
00078   eta_(eta),
00079   configured_(false),
00080   initialized_(false),
00081   initEta_(0.1),
00082   initPoints_(data->size()),
00083   rndm_(),
00084   permu_(data->size()),
00085   allowPermu_(true),
00086   nNodes_(0),
00087   nLinks_(0),
00088   nodeType_(),
00089   nodeActFun_(),
00090   nodeAct_(),
00091   nodeOut_(),
00092   nodeNInputLinks_(),
00093   nodeFirstInputLink_(),
00094   linkSource_(),
00095   nodeBias_(),
00096   linkWeight_(),
00097   cut_(SprUtils::lowerBound(0.5)),
00098   valData_(0),
00099   valPrint_(0),
00100   loss_(0),
00101   ownLoss_(false),
00102   initialDataWeights_()
00103 {
00104   this->setClasses();
00105   cout << "StdBackprop initialized with classes " << cls0_ << " " << cls1_
00106        << " nCycles=" << cycles_ << " LearningRate=" << eta_ << endl;
00107 }
00108 
00109 
00110 SprStdBackprop::SprStdBackprop(SprAbsFilter* data, 
00111                                const char* structure,
00112                                unsigned cycles,
00113                                double eta)
00114   :
00115   SprAbsClassifier(data),
00116   structure_(structure),
00117   cls0_(0),
00118   cls1_(1),
00119   cycles_(cycles),
00120   eta_(eta),
00121   configured_(false),
00122   initialized_(false),
00123   initEta_(0.1),
00124   initPoints_(data->size()),
00125   rndm_(),
00126   permu_(data->size()),
00127   allowPermu_(true),
00128   nNodes_(0),
00129   nLinks_(0),
00130   nodeType_(),
00131   nodeActFun_(),
00132   nodeAct_(),
00133   nodeOut_(),
00134   nodeNInputLinks_(),
00135   nodeFirstInputLink_(),
00136   linkSource_(),
00137   nodeBias_(),
00138   linkWeight_(),
00139   cut_(SprUtils::lowerBound(0.5)),
00140   valData_(0),
00141   valPrint_(0),
00142   loss_(0),
00143   ownLoss_(false),
00144   initialDataWeights_()
00145 {
00146   this->setClasses();
00147   bool status = this->createNet();
00148   assert( status );
00149   cout << "StdBackprop initialized with classes " << cls0_ << " " << cls1_
00150        << " nCycles=" << cycles_ << " structure=" << structure_.c_str()
00151        << " LearningRate=" << eta_ << endl;
00152 }
00153 
00154 
00155 SprTrainedStdBackprop* SprStdBackprop::makeTrained() const 
00156 {
00157   SprTrainedStdBackprop* t = new SprTrainedStdBackprop(structure_.c_str(),
00158                                                        nodeType_,nodeActFun_,
00159                                                        nodeNInputLinks_,
00160                                                        nodeFirstInputLink_,
00161                                                        linkSource_,nodeBias_,
00162                                                        linkWeight_);
00163   t->setCut(cut_);
00164 
00165   // vars
00166   vector<string> vars;
00167   data_->vars(vars);
00168   t->setVars(vars);
00169 
00170   // exit
00171   return t;
00172 }
00173 
00174 
00175 bool SprStdBackprop::createNet() 
00176 {
00177   // init
00178   configured_ = false;
00179 
00180   // sanity check
00181   if( structure_.empty() ) {
00182     cerr << "No network structure specified. Exiting." << endl;
00183     return false;
00184   }
00185 
00186   // parse
00187   vector<vector<int> > layers;
00188   SprStringParser::parseToInts(structure_.c_str(),layers);
00189 
00190   // check output
00191   if( layers.size() < 3 ) {
00192     cerr << "Not enough layers in the neural net: " << layers.size() 
00193          << " for structure " << structure_.c_str() << endl;
00194     return false;
00195   }
00196   if( layers[0].size()!=1 || layers[0][0]!=data_->dim() ) {
00197     cerr << "Size of the input layer " << layers[0][0]
00198          << " must be equal to the dimensionality of input data " 
00199          << data_->dim() << endl;
00200     return false;
00201   }
00202   for( int i=1;i<layers.size()-1;i++ ) {
00203     if( layers[i].size()!=1 || layers[i][0]<=0 ) {
00204       cerr << "Error in specifying hidden layer " << i << endl;
00205       return false;
00206     }
00207   }
00208   if( layers[layers.size()-1].size()!=1 || layers[layers.size()-1][0]!=1 ) {
00209     cerr << "This NN implementation can only handle "
00210          << "one node in the output layer." << endl;
00211     return false;
00212   }
00213 
00214   // create net
00215   nNodes_ = 0;
00216   for( int i=0;i<layers.size();i++ ) nNodes_ += layers[i][0];
00217   nodeType_.clear(); nodeType_.resize(nNodes_,SprNNDefs::INPUT);
00218   nodeActFun_.clear(); nodeActFun_.resize(nNodes_,SprNNDefs::ID);
00219   nodeAct_.clear(); nodeAct_.resize(nNodes_,0);
00220   nodeOut_.clear(); nodeOut_.resize(nNodes_,0);
00221   nodeNInputLinks_.clear(); nodeNInputLinks_.resize(nNodes_,0);
00222   nodeFirstInputLink_.clear(); nodeFirstInputLink_.resize(nNodes_,-1);
00223   nodeBias_.clear(); nodeBias_.resize(nNodes_,0);
00224   int index = 0;
00225 
00226   // input nodes
00227   // keep this commented out - this is just for clarity but in fact
00228   //   this code does nothing
00229   /*
00230   for( int i=0;i<layers[0][0];i++ ) {
00231     nodeType_[index]           = SprNNDefs::INPUT;
00232     nodeActFun_[index]         = SprNNDefs::ID;
00233     nodeNInputLinks_[index]    = 0; 
00234     nodeFirstInputLink_[index] = -1;
00235     index++;
00236   }
00237   */
00238 
00239   // hidden nodes
00240   index = layers[0][0];
00241   int firstLink = 0;
00242   linkSource_.clear();
00243   int nstart(0), nend(0);// flat node indices for the previous layer
00244   for( int i=1;i<layers.size()-1;i++ ) {
00245     nstart = nend;
00246     nend += layers[i-1][0];
00247     for( int j=0;j<layers[i][0];j++ ) {
00248       nodeType_[index]           = SprNNDefs::HIDDEN;
00249       nodeActFun_[index]         = SprNNDefs::LOGISTIC;
00250       nodeNInputLinks_[index]    = layers[i-1][0]; 
00251       nodeFirstInputLink_[index] = firstLink;
00252       firstLink += layers[i-1][0];
00253       index++;
00254       for( int n=nstart;n<nend;n++ ) linkSource_.push_back(n);
00255     }
00256   }
00257 
00258   // output nodes
00259   assert( index == (nNodes_-1) );
00260   nodeType_[index]           = SprNNDefs::OUTPUT;
00261   nodeActFun_[index]         = SprNNDefs::LOGISTIC;
00262   nodeNInputLinks_[index]    = layers[layers.size()-2][0]; 
00263   nodeFirstInputLink_[index] = firstLink;
00264   nstart = nend;
00265   nend += layers[layers.size()-2][0];
00266   for( int n=nstart;n<nend;n++ ) linkSource_.push_back(n);
00267 
00268   // links
00269   nLinks_ = linkSource_.size();
00270   linkWeight_.clear(); linkWeight_.resize(nLinks_,0);
00271 
00272   // exit
00273   configured_ = true;
00274   return true;
00275 }
00276 
00277 
00278 bool SprStdBackprop::init(double eta, unsigned nPoints)
00279 {
00280   if( initialized_ ) return true;
00281   initEta_ = eta;
00282   initPoints_ = nPoints;
00283   unsigned valPrint = valPrint_;
00284   valPrint_ = 0;
00285   initialized_ = this->doTrain(initPoints_,1,initEta_,true,1);
00286   valPrint_ = valPrint;
00287   return initialized_;
00288 }
00289 
00290 
00291 bool SprStdBackprop::train(int verbose)
00292 {
00293   // sanity check
00294   if( cycles_ == 0 ) {
00295     cout << "No training cycles for neural net requested. " 
00296          << "Will exit without training." << endl;
00297     return true;
00298   }
00299   if( !configured_ ) {
00300     cerr << "Neural net configuration not specified." << endl;
00301     return false;
00302   }
00303 
00304   // initialize
00305   if( !initialized_ ) {
00306     if( verbose > 0 ) {
00307       cout << "Initializing network with learning rate " << initEta_ 
00308            << " and number of points for initialization " << initPoints_ 
00309            << endl;
00310     }
00311     if( !this->init(initEta_,initPoints_) ) {
00312       cerr << "Unable to initialize network." << endl;
00313       return false;
00314     }
00315     if( verbose > 0 )
00316       cout << "Neural net initialized." << endl;
00317   }
00318 
00319   // train
00320   return this->doTrain(data_->size(),cycles_,eta_,false,verbose);
00321 }
00322 
00323 
00324 bool SprStdBackprop::doTrain(unsigned nPoints, unsigned nCycles, 
00325                              double eta, bool randomizeEta, int verbose)
00326 {
00327   // normalize data weights
00328   data_->weights(initialDataWeights_);
00329   vector<SprClass> classes(2);
00330   classes[0] = cls0_; classes[1] = cls1_;
00331   double wtot = data_->ptsInClass(cls0_) + data_->ptsInClass(cls1_);
00332   data_->normalizeWeights(classes,wtot);
00333 
00334   // permute input events
00335   unsigned size = data_->size();
00336   if( nPoints==0 || nPoints>size ) {
00337     if( verbose > 1 ) {
00338       cout << "Resetting the number of training points "
00339            << "to the max number of points available." << endl;
00340     }
00341     nPoints = size;
00342   }
00343   vector<unsigned> indices;
00344   if( allowPermu_ ) {
00345     if( !permu_.sequence(indices) ) {
00346       cerr << "Unable to permute input indices for training." << endl;
00347       return this->prepareExit(false);
00348     }
00349   }
00350   else {
00351     for( unsigned i=0;i<nPoints;i++ ) indices.push_back(i);
00352   }
00353 
00354   // validate before training starts
00355   if( valPrint_!=0 ) {
00356     if( !this->printValidation(0) ) {
00357       cerr << "Unable to print out validation data." << endl;
00358       return this->prepareExit(false);
00359     }
00360   }
00361 
00362   // train
00363   for( int ncycle=1;ncycle<=nCycles;ncycle++ ) {
00364     // message
00365     if( verbose > 0 ) {
00366       if( ncycle%10 == 0 )
00367         cout << "Training neural net at cycle " << ncycle << endl;
00368     }
00369 
00370     // do two passes of propagation
00371     for( int i=0;i<nPoints;i++ ) {
00372       unsigned ipt = indices[i];
00373       const SprPoint* p = (*data_)[ipt];
00374       int cls = -1;
00375       if(      p->class_ == cls0_ )
00376         cls = 0;
00377       else if( p->class_ == cls1_ )
00378         cls = 1;
00379       else
00380         continue;
00381 
00382       // forward pass
00383       double output = this->forward(p->x_);
00384 
00385       // generate random learning factors for first cycle
00386       double w = data_->w(ipt);
00387       vector<double> etaV(nLinks_+1,w*eta);
00388       if( randomizeEta ) {
00389         double* r = new double [nLinks_+1];
00390         rndm_.sequence(r,nLinks_);
00391         for( int j=0;j<=nLinks_;j++ ) etaV[j] = eta*r[j];
00392         delete [] r;
00393       }
00394 
00395       // backward pass
00396       if( !this->backward(cls,output,etaV) ) {
00397         cerr << "Unable to backward-propagate at cycle " << ncycle << endl;
00398         return this->prepareExit(false);
00399       }
00400     }// end of do two passes of propagation
00401 
00402     // validate
00403     if( valPrint_!=0 && (ncycle%valPrint_)==0 ) {
00404       if( !this->printValidation(ncycle) ) {
00405         cerr << "Unable to print out validation data." << endl;
00406         return this->prepareExit(false);
00407       }
00408     }
00409   }
00410 
00411   // exit
00412   return this->prepareExit(true);
00413 }
00414 
00415 
00416 double SprStdBackprop::forward(const std::vector<double>& v)
00417 {
00418   // Initialize and process input nodes
00419   nodeOut_.clear(); nodeOut_.resize(nNodes_,0);
00420   int d = 0;
00421   for( int i=0;i<nNodes_;i++ ) {
00422     if( nodeType_[i] == SprNNDefs::INPUT )
00423       nodeOut_[i] = v[d++];
00424     else
00425       break;
00426   }
00427 
00428   // Process hidden and output nodes
00429   for( int i=0;i<nNodes_;i++ ) {
00430     nodeAct_[i] = 0;
00431     if( nodeNInputLinks_[i] > 0 ) {
00432       for( int j=nodeFirstInputLink_[i];
00433            j<nodeFirstInputLink_[i]+nodeNInputLinks_[i];j++ ) {
00434         nodeAct_[i] += nodeOut_[linkSource_[j]] * linkWeight_[j];
00435       }
00436       nodeOut_[i] = this->activate(nodeAct_[i]+nodeBias_[i],nodeActFun_[i]);
00437     }
00438   }
00439 
00440   // Find output node and return result
00441   return nodeOut_[nNodes_-1];
00442 }
00443 
00444 
00445 bool SprStdBackprop::backward(int cls, double output, 
00446                               const std::vector<double>& etaV)
00447 {
00448   // make temp copies
00449   vector<double> tempLinkWeight(linkWeight_);
00450   vector<double> tempNodeBias(nodeBias_);
00451 
00452   // reset gradients
00453   vector<double> nodeGradient(nNodes_,0);
00454 
00455   // gradient in the output node
00456   nodeGradient[nNodes_-1] = (double(cls)-output) *
00457     this->act_deriv(nodeAct_[nNodes_-1]+nodeBias_[nNodes_-1],
00458                     nodeActFun_[nNodes_-1]);
00459   nodeBias_[nNodes_-1] += etaV[nLinks_] * nodeGradient[nNodes_-1];
00460 
00461   // propagate backwards thru hidden nodes
00462   for( int target=nNodes_-1;target>=0;target-- ) {
00463     if( nodeNInputLinks_[target] > 0 ) {
00464       for( int link=nodeFirstInputLink_[target];
00465            link<nodeFirstInputLink_[target]+nodeNInputLinks_[target];
00466            link++ ) {
00467         int source = linkSource_[link];
00468         linkWeight_[link] += etaV[link] 
00469           * nodeGradient[target] * nodeOut_[source];
00470         if( nodeType_[source] == SprNNDefs::HIDDEN ) {
00471           nodeGradient[source] += 
00472             this->act_deriv(nodeAct_[source]+tempNodeBias[source],
00473                             nodeActFun_[source]) 
00474             * tempLinkWeight[link] * nodeGradient[target];
00475           nodeBias_[source] += etaV[link] * nodeGradient[source];
00476         }
00477       }
00478     }
00479   }
00480 
00481   // exit
00482   return true;
00483 }
00484 
00485 
00486 bool SprStdBackprop::reset()
00487 {
00488   initialized_ = false;
00489   nodeBias_.clear(); nodeBias_.resize(nNodes_,0);
00490   nodeAct_.clear(); nodeAct_.resize(nNodes_,0);
00491   nodeOut_.clear(); nodeOut_.resize(nNodes_,0);
00492   linkWeight_.clear(); linkWeight_.resize(nLinks_,0);
00493   return true;
00494 }
00495 
00496 
00497 bool SprStdBackprop::setData(SprAbsFilter* data)
00498 {
00499   assert( data != 0 );
00500   data_ = data;
00501   return this->reset();
00502 }
00503 
00504 
00505 void SprStdBackprop::print(std::ostream& os) const 
00506 {
00507   os << "Trained StdBackprop with configuration " 
00508      << structure_.c_str() << " " << SprVersion << endl; 
00509   os << "Activation functions: Identity=1, Logistic=2" << endl;
00510   os << "Cut: " << cut_.size();
00511   for( int i=0;i<cut_.size();i++ )
00512     os << "      " << cut_[i].first << " " << cut_[i].second;
00513   os << endl;
00514   os << "Nodes: " << nNodes_ << endl;
00515   for( int i=0;i<nNodes_;i++ ) {
00516     char nodeType;
00517     switch( nodeType_[i] )
00518       {
00519       case SprNNDefs::INPUT :
00520         nodeType = 'I';
00521         break;
00522       case SprNNDefs::HIDDEN :
00523         nodeType = 'H';
00524         break;
00525       case SprNNDefs::OUTPUT :
00526         nodeType = 'O';
00527         break;
00528       }
00529     int actFun = 0;
00530     switch( nodeActFun_[i] )
00531       {
00532       case SprNNDefs::ID :
00533         actFun = 1;
00534         break;
00535       case SprNNDefs::LOGISTIC :
00536         actFun = 2;
00537         break;
00538       }
00539     os << setw(6) << i
00540        << "    Type: "           << nodeType
00541        << "    ActFunction: "    << actFun
00542        << "    NInputLinks: "    << setw(6) << nodeNInputLinks_[i]
00543        << "    FirstInputLink: " << setw(6) << nodeFirstInputLink_[i]
00544        << "    Bias: "           << nodeBias_[i]
00545        << endl;
00546   }
00547   os << "Links: " << nLinks_ << endl;
00548   for( int i=0;i<nLinks_;i++ ) {
00549     os << setw(6) << i
00550        << "    Source: " << setw(6) << linkSource_[i]
00551        << "    Weight: " << linkWeight_[i]
00552        << endl;
00553   }
00554 }
00555 
00556 
00557 void SprStdBackprop::setClasses()
00558 {
00559   vector<SprClass> classes;
00560   data_->classes(classes);
00561   int size = classes.size();
00562   if( size > 0 ) cls0_ = classes[0];
00563   if( size > 1 ) cls1_ = classes[1];
00564   cout << "Classes for StdBackprop are set to " 
00565        << cls0_ << " " << cls1_ << endl;
00566 }
00567 
00568 
00569 bool SprStdBackprop::setValidation(const SprAbsFilter* valData, 
00570                                    unsigned valPrint,
00571                                    SprAverageLoss* loss)
00572 {
00573   // set
00574   valData_ = valData;
00575   valPrint_ = valPrint;
00576 
00577   // if no loss specified, use quadratic by default
00578   loss_ = loss;
00579   ownLoss_ = false;
00580   if( loss_ == 0 ) {
00581     loss_ = new SprAverageLoss(&SprLoss::quadratic);
00582     ownLoss_ = true;
00583   }
00584 
00585   // exit
00586   return true;
00587 }
00588 
00589 
00590 bool SprStdBackprop::printValidation(unsigned cycle)
00591 {
00592   // reset loss
00593   assert( loss_ != 0 );
00594   loss_->reset();
00595 
00596   // make trained NN
00597   SprTrainedStdBackprop* t = this->makeTrained();
00598 
00599   // loop through validation data
00600   for( int i=0;i<valData_->size();i++ ) {
00601     const SprPoint* p = (*valData_)[i];
00602     double r = t->response(p->x_);
00603     double w = valData_->w(i);
00604     if( p->class_!=cls0_ && p->class_!=cls1_ ) w = 0;
00605     if(      p->class_ == cls0_ )
00606       loss_->update(0,r,w);
00607     else if( p->class_ == cls1_ )
00608       loss_->update(1,r,w);
00609   }
00610 
00611   // compute fom
00612   cout << "Validation Loss=" << loss_->value()
00613        << " at cycle " << cycle << endl;
00614 
00615   // exit
00616   return true;
00617 }
00618 
00619 
00620 double SprStdBackprop::activate(double x, SprNNDefs::ActFun f) const 
00621 {
00622   switch (f) 
00623     {
00624     case SprNNDefs::ID :
00625       return x;
00626       break;
00627     case SprNNDefs::LOGISTIC :
00628       return SprTransformation::logit(x);
00629       break;
00630     default :
00631       cerr << "Unknown activation function " 
00632            << f << " in SprTrainedStdBackprop::activate" << endl;
00633       return 0;
00634     }
00635   return 0;
00636 }
00637 
00638 
00639 double SprStdBackprop::act_deriv(double x, SprNNDefs::ActFun f) const 
00640 {
00641   switch (f) 
00642     {
00643     case SprNNDefs::ID :
00644       return 1;
00645       break;
00646     case SprNNDefs::LOGISTIC :
00647       return SprTransformation::logit_deriv(x);
00648       break;
00649     default :
00650       cerr << "Unknown activation function " 
00651            << f << " in SprTrainedStdBackprop::activate" << endl;
00652       return 0;
00653     }
00654   return 0;
00655 }
00656 
00657 
00658 bool SprStdBackprop::prepareExit(bool status)
00659 {
00660   data_->setWeights(initialDataWeights_);
00661   return status;
00662 }
00663 
00664 
00665 bool SprStdBackprop::readSNNS(const char* netfile) 
00666 {
00667   // sanity check and init
00668   if( 0 == netfile ) return false;
00669   structure_ = "Unknown";
00670   configured_ = false;
00671   initialized_ = false;
00672   string nfile = netfile;
00673   bool success = false;
00674 
00675   // open file
00676   ifstream file(nfile.c_str());
00677   if( !file ) {
00678     cerr << "Unable to open file " << nfile.c_str() << endl;
00679     return false;
00680   }
00681 
00682   // Read header of network definition file
00683   string line;
00684   unsigned nLine = 0;
00685   nLine++;
00686   nNodes_ = 0;
00687   while( getline(file,line) ) {
00688     const char* searchfor = "no. of units :";
00689     size_t pos = line.find(searchfor);
00690     if( pos != string::npos ) {
00691       line.erase(0,pos+strlen(searchfor)+1);
00692       istringstream istnodes(line);
00693       istnodes >> nNodes_;
00694       break;
00695     }
00696     nLine++;
00697   }
00698   if( nNodes_ <= 0 ) {
00699     cerr << "Can't find units line in file " << nfile.c_str() << endl;
00700     return false;
00701   }
00702   nLine++;
00703   if( !getline(file,line) ) {
00704     cerr << "Cannot read from " << nfile.c_str() << " line " << nLine << endl;
00705     return false;
00706   }
00707   nLinks_ = 0;
00708   const char* searchfor = "no. of connections :";
00709   size_t pos = line.find(searchfor);
00710   if( pos != string::npos ) {
00711     line.erase(0,pos+strlen(searchfor)+1);
00712     istringstream istconns(line);
00713     istconns >> nLinks_;
00714   }
00715   if( nLinks_ <= 0 ) {
00716     cerr << "Can't find connections line in file " << nfile.c_str() << endl;
00717     return false;
00718   }
00719   //  cout << "Nodes and links: " << nNodes_ << " " << nLinks_ << endl;
00720 
00721   // Allocate space for node and link data
00722   nodeType_.clear(); nodeType_.resize(nNodes_,SprNNDefs::INPUT);
00723   nodeActFun_.clear(); nodeActFun_.resize(nNodes_,SprNNDefs::ID);
00724   nodeAct_.clear(); nodeAct_.resize(nNodes_,0);
00725   nodeOut_.clear(); nodeOut_.resize(nNodes_,0);
00726   nodeNInputLinks_.clear(); nodeNInputLinks_.resize(nNodes_,0);
00727   nodeFirstInputLink_.clear(); nodeFirstInputLink_.resize(nNodes_,-1);
00728   nodeBias_.clear(); nodeBias_.resize(nNodes_,0);
00729   linkSource_.clear(); linkSource_.resize(nLinks_,0);
00730   linkWeight_.clear(); linkWeight_.resize(nLinks_,0);
00731     
00732   // Here we should check that we are reading the correct type of network,
00733   // i.e. one using the Act_Logistic activation function ...
00734 
00735   //
00736   // Read node information
00737   //
00738   nLine++;
00739   bool found = false;
00740   while( getline(file,line) ) {
00741     size_t pos = line.find("unit definition section :");
00742     if( pos != string::npos ) {
00743       found = true;
00744       break;
00745     }
00746     nLine++;
00747   }
00748   if( !found ) {
00749     cerr << "Can't find unit definition section in file " 
00750          << nfile.c_str() << endl;
00751     return false;
00752   }
00753   // skip 3 lines
00754   for( int i=0;i<3;i++ ) {
00755     nLine++;
00756     if( !getline(file,line) ) {
00757       cerr << "Cannot read from " << nfile.c_str() 
00758            << " line " << nLine << endl;
00759       return false;
00760     }
00761   }
00762   // read nodes one by one
00763   unsigned nOutput = 0;
00764   for( int node=0;node<nNodes_;node++ ) {
00765     nLine++;
00766     if( !getline(file,line) ) {
00767       cerr << "Cannot read from " << nfile.c_str() 
00768            << " line " << nLine << endl;
00769       return false;
00770     }
00771     istringstream istnode(line);
00772     int id = 0;
00773     istnode >> id;
00774     if( id != (node+1) ) {
00775       cerr << "Node ID does not match on line " << nLine << endl;
00776       return false;
00777     }
00778     char c;
00779     double dummy;
00780     for( int i=0;i<3;i++ ) istnode >> c;
00781     istnode >> dummy >> c >> nodeBias_[node] >> c;
00782     istnode >> c;
00783     switch( c ) 
00784       {
00785       case 'i' :
00786         nodeType_[node] = SprNNDefs::INPUT;
00787         nodeActFun_[node] = SprNNDefs::ID;
00788         break;
00789       case 'h' :
00790         nodeType_[node] = SprNNDefs::HIDDEN;
00791         nodeActFun_[node] = SprNNDefs::LOGISTIC;
00792         break;
00793       case 'o' :
00794         nodeType_[node] = SprNNDefs::OUTPUT;
00795         nodeActFun_[node] = SprNNDefs::LOGISTIC;
00796         nOutput++;
00797         break;
00798       default :
00799         cerr << "Unknown node type on line " << nLine << endl;
00800         return false;
00801       }
00802   }
00803   if( nOutput > 1 ) {
00804     cerr << "More than one output node cannot be handled "
00805          << "by this implementation" << endl;
00806     return false;
00807   }
00808   //  cout << "Unit definition section has been read " << nLine << endl;
00809 
00810   //
00811   // Read link information
00812   //
00813   nLine++;
00814   found = false;
00815   while( getline(file,line) ) {
00816     size_t pos = line.find("connection definition section :");
00817     if( pos != string::npos ) {
00818       found = true;
00819       break;
00820     }
00821     nLine++;
00822   }
00823   if( !found ) {
00824     cerr << "Can't find connection definition section in file " 
00825          << nfile.c_str() << endl;
00826     return false;
00827   }
00828   // skip 3 lines
00829   for( int i=0;i<3;i++ ) {
00830     nLine++;
00831     if( !getline(file,line) ) {
00832       cerr << "Cannot read from " << nfile.c_str() 
00833            << " line " << nLine << endl;
00834       return false;
00835     }
00836   }
00837   // read links one by one
00838   int link = 0;
00839   string prevLine;
00840   while( getline(file,line) ) {
00841     nLine++;
00842     // if the last symbol is comma, continue to next line
00843     if( line.at(line.find_last_not_of(' ')) == ',' ) {
00844       prevLine = line;
00845       continue;
00846     }
00847     line = prevLine+line;
00848     prevLine = "";
00849     // get target
00850     size_t separ_pos = line.find_first_of('|');
00851     if( separ_pos == string::npos ) {
00852       cerr << "Cannot read from " << nfile.c_str() 
00853            << " line " << nLine << endl;
00854       return false;
00855     }
00856     string target_str = line.substr(0,separ_pos);
00857     line.erase(0,separ_pos+1);
00858     int target = atoi(target_str.c_str());
00859     if( target<=0 || target>nNodes_ ) {
00860       cerr << "Unable to read target node from "
00861            << nfile.c_str() << " on line " << nLine 
00862            << " : nNodes=" << nNodes_ << " target=" << target << endl;
00863       return false;
00864     }
00865     target--;// offset by 1 to start numbering from 0 instead of 1
00866     // assign first link for the target
00867     nodeFirstInputLink_[target] = link;
00868     // skip one field
00869     separ_pos = line.find_first_of('|');
00870     if( separ_pos == string::npos ) {
00871       cerr << "Cannot read from " << nfile.c_str() 
00872            << " line " << nLine << endl;
00873       return false;
00874     }
00875     // get source
00876     string sources_str = line.substr(separ_pos+1);
00877     vector<string> sources;
00878     while( sources_str.find(',') != string::npos ) {
00879       size_t comma_pos = sources_str.find_first_of(',');
00880       sources.push_back(sources_str.substr(0,comma_pos));
00881       sources_str.erase(0,comma_pos+1);
00882     }
00883     sources.push_back(sources_str);// leftover to get the last source
00884     for( int i=0;i<sources.size();i++ ) {
00885       string current_source = sources[i];
00886       size_t doubledot_pos = current_source.find(':');
00887       if( doubledot_pos == string::npos ) {
00888         cerr << "Cannot read from " << nfile.c_str() 
00889              << " line " << nLine << endl;
00890         return false;
00891       }
00892       string source_id = current_source.substr(0,doubledot_pos);
00893       string source_weight = current_source.substr(doubledot_pos+1);
00894       int source = atoi(source_id.c_str());
00895       double weight = atof(source_weight.c_str());
00896       if( source<=0 || source>nNodes_ ) {
00897         cerr << "Unable to read source node from "
00898              << nfile.c_str() << " on line " << nLine << endl;
00899         return false;
00900       }
00901       source--;// offset by 1 to start numbering from 0 instead of 1
00902       // build link
00903       linkSource_[link] = source;
00904       linkWeight_[link] = weight;
00905       nodeNInputLinks_[target]++;
00906       // increment link
00907       link++;
00908     }
00909     if( link == nLinks_ ) {
00910       success = true;
00911       break;
00912     }
00913   }
00914 
00915   // exit
00916   if( success ) {
00917     configured_ = true;
00918     initialized_ = true;
00919   }
00920   return success;
00921 }
00922 
00923 
00924 bool SprStdBackprop::readSPR(const char* netfile)
00925 {
00926   // sanity check and init
00927   if( 0 == netfile ) return false;
00928   string nfile = netfile;
00929 
00930   // open file
00931   ifstream file(nfile.c_str());
00932   if( !file ) {
00933     cerr << "Unable to open file " << nfile.c_str() << endl;
00934     return false;
00935   }
00936 
00937   // read the file
00938   unsigned skipLines = 0;
00939   return this->resumeReadSPR(nfile.c_str(),file,skipLines);
00940 }
00941 
00942 bool SprStdBackprop::resumeReadSPR(const char* netfile,
00943                                    std::ifstream& file, 
00944                                    unsigned& skipLines)
00945 {
00946   // init
00947   unsigned& nLine = skipLines;
00948   structure_ = "Unknown";
00949   configured_ = false;
00950   initialized_ = false;
00951   string nfile = netfile;
00952 
00953   // read header
00954   string line;
00955   for( int i=0;i<2;i++ ) {
00956     nLine++;
00957     if( !getline(file,line) ) {
00958       cerr << "Unable to read line " << nLine 
00959            << " from " << nfile.c_str() << endl;
00960       return false;
00961     }
00962   }
00963 
00964   // read the cut
00965   string dummy;
00966   nLine++;
00967   if( !getline(file,line) ) {
00968     cerr << "Unable to read line " << nLine 
00969          << " from " << nfile.c_str() << endl;
00970     return false;
00971   }
00972   istringstream istcut(line);
00973   istcut >> dummy;
00974   int nCut = 0;
00975   istcut >> nCut;
00976   cut_.clear();
00977   double low(0), high(0);
00978   for( int i=0;i<nCut;i++ ) {
00979     istcut >> low >> high;
00980     cut_.push_back(SprInterval(low,high));
00981   }
00982 
00983   // read number of nodes
00984   nLine++;
00985   if( !getline(file,line) ) {
00986     cerr << "Unable to read line " << nLine 
00987          << " from " << nfile.c_str() << endl;
00988     return false;
00989   }
00990   istringstream istNnodes(line);
00991   istNnodes >> dummy >> nNodes_;
00992   if( nNodes_ <= 0 ) {
00993     cerr << "Rean an invalid number of NN nodes: " << nNodes_ << endl;
00994     return false;
00995   }
00996   
00997   // init nodes
00998   nodeType_.clear(); nodeType_.resize(nNodes_,SprNNDefs::INPUT);
00999   nodeActFun_.clear(); nodeActFun_.resize(nNodes_,SprNNDefs::ID);
01000   nodeAct_.clear(); nodeAct_.resize(nNodes_,0);
01001   nodeOut_.clear(); nodeOut_.resize(nNodes_,0);
01002   nodeNInputLinks_.clear(); nodeNInputLinks_.resize(nNodes_,0);
01003   nodeFirstInputLink_.clear(); nodeFirstInputLink_.resize(nNodes_,-1);
01004   nodeBias_.clear(); nodeBias_.resize(nNodes_,0);
01005 
01006   // read nodes
01007   for( int node=0;node<nNodes_;node++ ) {
01008     nLine++;
01009     if( !getline(file,line) ) {
01010       cerr << "Unable to read line " << nLine 
01011            << " from " << nfile.c_str() << endl;
01012       return false;
01013     }
01014     istringstream istnode(line);
01015     int index = -1;
01016     istnode >> index;
01017     if( index != node ) {
01018       cerr << "Incorrect node number on line " << nLine
01019            << ": Expect " << node << " Actual " << index << endl;
01020       return false;
01021     }
01022     istnode >> dummy;
01023     char nodeType;
01024     istnode >> nodeType;
01025     switch( nodeType )
01026       {
01027       case 'I' :
01028         nodeType_[node] = SprNNDefs::INPUT;
01029         break;
01030       case 'H' :
01031         nodeType_[node] = SprNNDefs::HIDDEN;
01032         break;
01033       case 'O' :
01034         nodeType_[node] = SprNNDefs::OUTPUT;
01035         break;
01036       default :
01037         cerr << "Unknown node type on line " << nLine 
01038              << " in " << nfile.c_str() << endl;
01039         return false;
01040       }
01041     istnode >> dummy;
01042     int actFun = 0;
01043     istnode >> actFun;
01044     switch( actFun )
01045       {
01046       case 1 :
01047         nodeActFun_[node] = SprNNDefs::ID;
01048         break;
01049       case 2 :
01050         nodeActFun_[node] = SprNNDefs::LOGISTIC;
01051         break;
01052       default :
01053         cerr << "Unknown activation function on line " << nLine
01054              << " in " << nfile.c_str() << endl;
01055         return false;
01056       }
01057     istnode >> dummy;
01058     istnode >> nodeNInputLinks_[node];
01059     istnode >> dummy;
01060     istnode >> nodeFirstInputLink_[node];
01061     istnode >> dummy;
01062     istnode >> nodeBias_[node];
01063   }// nodes done
01064 
01065   // read number of nodes
01066   nLine++;
01067   if( !getline(file,line) ) {
01068     cerr << "Unable to read line " << nLine 
01069          << " from " << nfile.c_str() << endl;
01070     return false;
01071   }
01072   istringstream istNlinks(line);
01073   istNlinks >> dummy >> nLinks_;
01074   if( nLinks_ <= 0 ) {
01075     cerr << "Rean an invalid number of NN links: " << nLinks_ << endl;
01076     return false;
01077   }
01078   
01079   // init links
01080   linkSource_.clear(); linkSource_.resize(nLinks_,0);
01081   linkWeight_.clear(); linkWeight_.resize(nLinks_,0);
01082 
01083   // read links
01084   for( int link=0;link<nLinks_;link++ ) {
01085     nLine++;
01086     if( !getline(file,line) ) {
01087       cerr << "Unable to read line " << nLine 
01088            << " from " << nfile.c_str() << endl;
01089       return false;
01090     }
01091     istringstream istlink(line);
01092     int index = -1;
01093     istlink >> index;
01094     if( index != link ) {
01095       cerr << "Incorrect link number on line " << nLine
01096            << ": Expect " << link << " Actual " << index << endl;
01097       return false;
01098     }
01099     istlink >> dummy;
01100     istlink >> linkSource_[link];
01101     istlink >> dummy;
01102     istlink >> linkWeight_[link];
01103   }// links done
01104 
01105   // exit
01106   configured_ = true;
01107   initialized_ = true;
01108   return true;
01109 }

Generated on Tue Jun 9 17:42:03 2009 for CMSSW by  doxygen 1.5.4