00001
00002
#include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprStdBackprop.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAverageLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprStringParser.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"

#include <stdio.h>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
00019
00020 using namespace std;
00021
00022
00023 SprStdBackprop::~SprStdBackprop()
00024 {
00025 if( ownLoss_ ) {
00026 delete loss_;
00027 loss_ = 0;
00028 ownLoss_ = false;
00029 }
00030 }
00031
00032 SprStdBackprop::SprStdBackprop(SprAbsFilter* data)
00033 :
00034 SprAbsClassifier(data),
00035 structure_(),
00036 cls0_(0),
00037 cls1_(1),
00038 cycles_(0),
00039 eta_(0.1),
00040 configured_(false),
00041 initialized_(false),
00042 initEta_(0.1),
00043 initPoints_(data->size()),
00044 rndm_(),
00045 permu_(data->size()),
00046 allowPermu_(true),
00047 nNodes_(0),
00048 nLinks_(0),
00049 nodeType_(),
00050 nodeActFun_(),
00051 nodeAct_(),
00052 nodeOut_(),
00053 nodeNInputLinks_(),
00054 nodeFirstInputLink_(),
00055 linkSource_(),
00056 nodeBias_(),
00057 linkWeight_(),
00058 cut_(SprUtils::lowerBound(0.5)),
00059 valData_(0),
00060 valPrint_(0),
00061 loss_(0),
00062 ownLoss_(false),
00063 initialDataWeights_()
00064 {
00065 this->setClasses();
00066 }
00067
00068
00069 SprStdBackprop::SprStdBackprop(SprAbsFilter* data,
00070 unsigned cycles,
00071 double eta)
00072 :
00073 SprAbsClassifier(data),
00074 structure_(),
00075 cls0_(0),
00076 cls1_(1),
00077 cycles_(cycles),
00078 eta_(eta),
00079 configured_(false),
00080 initialized_(false),
00081 initEta_(0.1),
00082 initPoints_(data->size()),
00083 rndm_(),
00084 permu_(data->size()),
00085 allowPermu_(true),
00086 nNodes_(0),
00087 nLinks_(0),
00088 nodeType_(),
00089 nodeActFun_(),
00090 nodeAct_(),
00091 nodeOut_(),
00092 nodeNInputLinks_(),
00093 nodeFirstInputLink_(),
00094 linkSource_(),
00095 nodeBias_(),
00096 linkWeight_(),
00097 cut_(SprUtils::lowerBound(0.5)),
00098 valData_(0),
00099 valPrint_(0),
00100 loss_(0),
00101 ownLoss_(false),
00102 initialDataWeights_()
00103 {
00104 this->setClasses();
00105 cout << "StdBackprop initialized with classes " << cls0_ << " " << cls1_
00106 << " nCycles=" << cycles_ << " LearningRate=" << eta_ << endl;
00107 }
00108
00109
00110 SprStdBackprop::SprStdBackprop(SprAbsFilter* data,
00111 const char* structure,
00112 unsigned cycles,
00113 double eta)
00114 :
00115 SprAbsClassifier(data),
00116 structure_(structure),
00117 cls0_(0),
00118 cls1_(1),
00119 cycles_(cycles),
00120 eta_(eta),
00121 configured_(false),
00122 initialized_(false),
00123 initEta_(0.1),
00124 initPoints_(data->size()),
00125 rndm_(),
00126 permu_(data->size()),
00127 allowPermu_(true),
00128 nNodes_(0),
00129 nLinks_(0),
00130 nodeType_(),
00131 nodeActFun_(),
00132 nodeAct_(),
00133 nodeOut_(),
00134 nodeNInputLinks_(),
00135 nodeFirstInputLink_(),
00136 linkSource_(),
00137 nodeBias_(),
00138 linkWeight_(),
00139 cut_(SprUtils::lowerBound(0.5)),
00140 valData_(0),
00141 valPrint_(0),
00142 loss_(0),
00143 ownLoss_(false),
00144 initialDataWeights_()
00145 {
00146 this->setClasses();
00147 bool status = this->createNet();
00148 assert( status );
00149 cout << "StdBackprop initialized with classes " << cls0_ << " " << cls1_
00150 << " nCycles=" << cycles_ << " structure=" << structure_.c_str()
00151 << " LearningRate=" << eta_ << endl;
00152 }
00153
00154
00155 SprTrainedStdBackprop* SprStdBackprop::makeTrained() const
00156 {
00157 SprTrainedStdBackprop* t = new SprTrainedStdBackprop(structure_.c_str(),
00158 nodeType_,nodeActFun_,
00159 nodeNInputLinks_,
00160 nodeFirstInputLink_,
00161 linkSource_,nodeBias_,
00162 linkWeight_);
00163 t->setCut(cut_);
00164
00165
00166 vector<string> vars;
00167 data_->vars(vars);
00168 t->setVars(vars);
00169
00170
00171 return t;
00172 }
00173
00174
00175 bool SprStdBackprop::createNet()
00176 {
00177
00178 configured_ = false;
00179
00180
00181 if( structure_.empty() ) {
00182 cerr << "No network structure specified. Exiting." << endl;
00183 return false;
00184 }
00185
00186
00187 vector<vector<int> > layers;
00188 SprStringParser::parseToInts(structure_.c_str(),layers);
00189
00190
00191 if( layers.size() < 3 ) {
00192 cerr << "Not enough layers in the neural net: " << layers.size()
00193 << " for structure " << structure_.c_str() << endl;
00194 return false;
00195 }
00196 if( layers[0].size()!=1 || layers[0][0]!=data_->dim() ) {
00197 cerr << "Size of the input layer " << layers[0][0]
00198 << " must be equal to the dimensionality of input data "
00199 << data_->dim() << endl;
00200 return false;
00201 }
00202 for( int i=1;i<layers.size()-1;i++ ) {
00203 if( layers[i].size()!=1 || layers[i][0]<=0 ) {
00204 cerr << "Error in specifying hidden layer " << i << endl;
00205 return false;
00206 }
00207 }
00208 if( layers[layers.size()-1].size()!=1 || layers[layers.size()-1][0]!=1 ) {
00209 cerr << "This NN implementation can only handle "
00210 << "one node in the output layer." << endl;
00211 return false;
00212 }
00213
00214
00215 nNodes_ = 0;
00216 for( int i=0;i<layers.size();i++ ) nNodes_ += layers[i][0];
00217 nodeType_.clear(); nodeType_.resize(nNodes_,SprNNDefs::INPUT);
00218 nodeActFun_.clear(); nodeActFun_.resize(nNodes_,SprNNDefs::ID);
00219 nodeAct_.clear(); nodeAct_.resize(nNodes_,0);
00220 nodeOut_.clear(); nodeOut_.resize(nNodes_,0);
00221 nodeNInputLinks_.clear(); nodeNInputLinks_.resize(nNodes_,0);
00222 nodeFirstInputLink_.clear(); nodeFirstInputLink_.resize(nNodes_,-1);
00223 nodeBias_.clear(); nodeBias_.resize(nNodes_,0);
00224 int index = 0;
00225
00226
00227
00228
00229
00230
00231
00232
00233
00234
00235
00236
00237
00238
00239
00240 index = layers[0][0];
00241 int firstLink = 0;
00242 linkSource_.clear();
00243 int nstart(0), nend(0);
00244 for( int i=1;i<layers.size()-1;i++ ) {
00245 nstart = nend;
00246 nend += layers[i-1][0];
00247 for( int j=0;j<layers[i][0];j++ ) {
00248 nodeType_[index] = SprNNDefs::HIDDEN;
00249 nodeActFun_[index] = SprNNDefs::LOGISTIC;
00250 nodeNInputLinks_[index] = layers[i-1][0];
00251 nodeFirstInputLink_[index] = firstLink;
00252 firstLink += layers[i-1][0];
00253 index++;
00254 for( int n=nstart;n<nend;n++ ) linkSource_.push_back(n);
00255 }
00256 }
00257
00258
00259 assert( index == (nNodes_-1) );
00260 nodeType_[index] = SprNNDefs::OUTPUT;
00261 nodeActFun_[index] = SprNNDefs::LOGISTIC;
00262 nodeNInputLinks_[index] = layers[layers.size()-2][0];
00263 nodeFirstInputLink_[index] = firstLink;
00264 nstart = nend;
00265 nend += layers[layers.size()-2][0];
00266 for( int n=nstart;n<nend;n++ ) linkSource_.push_back(n);
00267
00268
00269 nLinks_ = linkSource_.size();
00270 linkWeight_.clear(); linkWeight_.resize(nLinks_,0);
00271
00272
00273 configured_ = true;
00274 return true;
00275 }
00276
00277
00278 bool SprStdBackprop::init(double eta, unsigned nPoints)
00279 {
00280 if( initialized_ ) return true;
00281 initEta_ = eta;
00282 initPoints_ = nPoints;
00283 unsigned valPrint = valPrint_;
00284 valPrint_ = 0;
00285 initialized_ = this->doTrain(initPoints_,1,initEta_,true,1);
00286 valPrint_ = valPrint;
00287 return initialized_;
00288 }
00289
00290
00291 bool SprStdBackprop::train(int verbose)
00292 {
00293
00294 if( cycles_ == 0 ) {
00295 cout << "No training cycles for neural net requested. "
00296 << "Will exit without training." << endl;
00297 return true;
00298 }
00299 if( !configured_ ) {
00300 cerr << "Neural net configuration not specified." << endl;
00301 return false;
00302 }
00303
00304
00305 if( !initialized_ ) {
00306 if( verbose > 0 ) {
00307 cout << "Initializing network with learning rate " << initEta_
00308 << " and number of points for initialization " << initPoints_
00309 << endl;
00310 }
00311 if( !this->init(initEta_,initPoints_) ) {
00312 cerr << "Unable to initialize network." << endl;
00313 return false;
00314 }
00315 if( verbose > 0 )
00316 cout << "Neural net initialized." << endl;
00317 }
00318
00319
00320 return this->doTrain(data_->size(),cycles_,eta_,false,verbose);
00321 }
00322
00323
00324 bool SprStdBackprop::doTrain(unsigned nPoints, unsigned nCycles,
00325 double eta, bool randomizeEta, int verbose)
00326 {
00327
00328 data_->weights(initialDataWeights_);
00329 vector<SprClass> classes(2);
00330 classes[0] = cls0_; classes[1] = cls1_;
00331 double wtot = data_->ptsInClass(cls0_) + data_->ptsInClass(cls1_);
00332 data_->normalizeWeights(classes,wtot);
00333
00334
00335 unsigned size = data_->size();
00336 if( nPoints==0 || nPoints>size ) {
00337 if( verbose > 1 ) {
00338 cout << "Resetting the number of training points "
00339 << "to the max number of points available." << endl;
00340 }
00341 nPoints = size;
00342 }
00343 vector<unsigned> indices;
00344 if( allowPermu_ ) {
00345 if( !permu_.sequence(indices) ) {
00346 cerr << "Unable to permute input indices for training." << endl;
00347 return this->prepareExit(false);
00348 }
00349 }
00350 else {
00351 for( unsigned i=0;i<nPoints;i++ ) indices.push_back(i);
00352 }
00353
00354
00355 if( valPrint_!=0 ) {
00356 if( !this->printValidation(0) ) {
00357 cerr << "Unable to print out validation data." << endl;
00358 return this->prepareExit(false);
00359 }
00360 }
00361
00362
00363 for( int ncycle=1;ncycle<=nCycles;ncycle++ ) {
00364
00365 if( verbose > 0 ) {
00366 if( ncycle%10 == 0 )
00367 cout << "Training neural net at cycle " << ncycle << endl;
00368 }
00369
00370
00371 for( int i=0;i<nPoints;i++ ) {
00372 unsigned ipt = indices[i];
00373 const SprPoint* p = (*data_)[ipt];
00374 int cls = -1;
00375 if( p->class_ == cls0_ )
00376 cls = 0;
00377 else if( p->class_ == cls1_ )
00378 cls = 1;
00379 else
00380 continue;
00381
00382
00383 double output = this->forward(p->x_);
00384
00385
00386 double w = data_->w(ipt);
00387 vector<double> etaV(nLinks_+1,w*eta);
00388 if( randomizeEta ) {
00389 double* r = new double [nLinks_+1];
00390 rndm_.sequence(r,nLinks_);
00391 for( int j=0;j<=nLinks_;j++ ) etaV[j] = eta*r[j];
00392 delete [] r;
00393 }
00394
00395
00396 if( !this->backward(cls,output,etaV) ) {
00397 cerr << "Unable to backward-propagate at cycle " << ncycle << endl;
00398 return this->prepareExit(false);
00399 }
00400 }
00401
00402
00403 if( valPrint_!=0 && (ncycle%valPrint_)==0 ) {
00404 if( !this->printValidation(ncycle) ) {
00405 cerr << "Unable to print out validation data." << endl;
00406 return this->prepareExit(false);
00407 }
00408 }
00409 }
00410
00411
00412 return this->prepareExit(true);
00413 }
00414
00415
00416 double SprStdBackprop::forward(const std::vector<double>& v)
00417 {
00418
00419 nodeOut_.clear(); nodeOut_.resize(nNodes_,0);
00420 int d = 0;
00421 for( int i=0;i<nNodes_;i++ ) {
00422 if( nodeType_[i] == SprNNDefs::INPUT )
00423 nodeOut_[i] = v[d++];
00424 else
00425 break;
00426 }
00427
00428
00429 for( int i=0;i<nNodes_;i++ ) {
00430 nodeAct_[i] = 0;
00431 if( nodeNInputLinks_[i] > 0 ) {
00432 for( int j=nodeFirstInputLink_[i];
00433 j<nodeFirstInputLink_[i]+nodeNInputLinks_[i];j++ ) {
00434 nodeAct_[i] += nodeOut_[linkSource_[j]] * linkWeight_[j];
00435 }
00436 nodeOut_[i] = this->activate(nodeAct_[i]+nodeBias_[i],nodeActFun_[i]);
00437 }
00438 }
00439
00440
00441 return nodeOut_[nNodes_-1];
00442 }
00443
00444
00445 bool SprStdBackprop::backward(int cls, double output,
00446 const std::vector<double>& etaV)
00447 {
00448
00449 vector<double> tempLinkWeight(linkWeight_);
00450 vector<double> tempNodeBias(nodeBias_);
00451
00452
00453 vector<double> nodeGradient(nNodes_,0);
00454
00455
00456 nodeGradient[nNodes_-1] = (double(cls)-output) *
00457 this->act_deriv(nodeAct_[nNodes_-1]+nodeBias_[nNodes_-1],
00458 nodeActFun_[nNodes_-1]);
00459 nodeBias_[nNodes_-1] += etaV[nLinks_] * nodeGradient[nNodes_-1];
00460
00461
00462 for( int target=nNodes_-1;target>=0;target-- ) {
00463 if( nodeNInputLinks_[target] > 0 ) {
00464 for( int link=nodeFirstInputLink_[target];
00465 link<nodeFirstInputLink_[target]+nodeNInputLinks_[target];
00466 link++ ) {
00467 int source = linkSource_[link];
00468 linkWeight_[link] += etaV[link]
00469 * nodeGradient[target] * nodeOut_[source];
00470 if( nodeType_[source] == SprNNDefs::HIDDEN ) {
00471 nodeGradient[source] +=
00472 this->act_deriv(nodeAct_[source]+tempNodeBias[source],
00473 nodeActFun_[source])
00474 * tempLinkWeight[link] * nodeGradient[target];
00475 nodeBias_[source] += etaV[link] * nodeGradient[source];
00476 }
00477 }
00478 }
00479 }
00480
00481
00482 return true;
00483 }
00484
00485
00486 bool SprStdBackprop::reset()
00487 {
00488 initialized_ = false;
00489 nodeBias_.clear(); nodeBias_.resize(nNodes_,0);
00490 nodeAct_.clear(); nodeAct_.resize(nNodes_,0);
00491 nodeOut_.clear(); nodeOut_.resize(nNodes_,0);
00492 linkWeight_.clear(); linkWeight_.resize(nLinks_,0);
00493 return true;
00494 }
00495
00496
00497 bool SprStdBackprop::setData(SprAbsFilter* data)
00498 {
00499 assert( data != 0 );
00500 data_ = data;
00501 return this->reset();
00502 }
00503
00504
00505 void SprStdBackprop::print(std::ostream& os) const
00506 {
00507 os << "Trained StdBackprop with configuration "
00508 << structure_.c_str() << " " << SprVersion << endl;
00509 os << "Activation functions: Identity=1, Logistic=2" << endl;
00510 os << "Cut: " << cut_.size();
00511 for( int i=0;i<cut_.size();i++ )
00512 os << " " << cut_[i].first << " " << cut_[i].second;
00513 os << endl;
00514 os << "Nodes: " << nNodes_ << endl;
00515 for( int i=0;i<nNodes_;i++ ) {
00516 char nodeType;
00517 switch( nodeType_[i] )
00518 {
00519 case SprNNDefs::INPUT :
00520 nodeType = 'I';
00521 break;
00522 case SprNNDefs::HIDDEN :
00523 nodeType = 'H';
00524 break;
00525 case SprNNDefs::OUTPUT :
00526 nodeType = 'O';
00527 break;
00528 }
00529 int actFun = 0;
00530 switch( nodeActFun_[i] )
00531 {
00532 case SprNNDefs::ID :
00533 actFun = 1;
00534 break;
00535 case SprNNDefs::LOGISTIC :
00536 actFun = 2;
00537 break;
00538 }
00539 os << setw(6) << i
00540 << " Type: " << nodeType
00541 << " ActFunction: " << actFun
00542 << " NInputLinks: " << setw(6) << nodeNInputLinks_[i]
00543 << " FirstInputLink: " << setw(6) << nodeFirstInputLink_[i]
00544 << " Bias: " << nodeBias_[i]
00545 << endl;
00546 }
00547 os << "Links: " << nLinks_ << endl;
00548 for( int i=0;i<nLinks_;i++ ) {
00549 os << setw(6) << i
00550 << " Source: " << setw(6) << linkSource_[i]
00551 << " Weight: " << linkWeight_[i]
00552 << endl;
00553 }
00554 }
00555
00556
00557 void SprStdBackprop::setClasses()
00558 {
00559 vector<SprClass> classes;
00560 data_->classes(classes);
00561 int size = classes.size();
00562 if( size > 0 ) cls0_ = classes[0];
00563 if( size > 1 ) cls1_ = classes[1];
00564 cout << "Classes for StdBackprop are set to "
00565 << cls0_ << " " << cls1_ << endl;
00566 }
00567
00568
00569 bool SprStdBackprop::setValidation(const SprAbsFilter* valData,
00570 unsigned valPrint,
00571 SprAverageLoss* loss)
00572 {
00573
00574 valData_ = valData;
00575 valPrint_ = valPrint;
00576
00577
00578 loss_ = loss;
00579 ownLoss_ = false;
00580 if( loss_ == 0 ) {
00581 loss_ = new SprAverageLoss(&SprLoss::quadratic);
00582 ownLoss_ = true;
00583 }
00584
00585
00586 return true;
00587 }
00588
00589
00590 bool SprStdBackprop::printValidation(unsigned cycle)
00591 {
00592
00593 assert( loss_ != 0 );
00594 loss_->reset();
00595
00596
00597 SprTrainedStdBackprop* t = this->makeTrained();
00598
00599
00600 for( int i=0;i<valData_->size();i++ ) {
00601 const SprPoint* p = (*valData_)[i];
00602 double r = t->response(p->x_);
00603 double w = valData_->w(i);
00604 if( p->class_!=cls0_ && p->class_!=cls1_ ) w = 0;
00605 if( p->class_ == cls0_ )
00606 loss_->update(0,r,w);
00607 else if( p->class_ == cls1_ )
00608 loss_->update(1,r,w);
00609 }
00610
00611
00612 cout << "Validation Loss=" << loss_->value()
00613 << " at cycle " << cycle << endl;
00614
00615
00616 return true;
00617 }
00618
00619
00620 double SprStdBackprop::activate(double x, SprNNDefs::ActFun f) const
00621 {
00622 switch (f)
00623 {
00624 case SprNNDefs::ID :
00625 return x;
00626 break;
00627 case SprNNDefs::LOGISTIC :
00628 return SprTransformation::logit(x);
00629 break;
00630 default :
00631 cerr << "Unknown activation function "
00632 << f << " in SprTrainedStdBackprop::activate" << endl;
00633 return 0;
00634 }
00635 return 0;
00636 }
00637
00638
00639 double SprStdBackprop::act_deriv(double x, SprNNDefs::ActFun f) const
00640 {
00641 switch (f)
00642 {
00643 case SprNNDefs::ID :
00644 return 1;
00645 break;
00646 case SprNNDefs::LOGISTIC :
00647 return SprTransformation::logit_deriv(x);
00648 break;
00649 default :
00650 cerr << "Unknown activation function "
00651 << f << " in SprTrainedStdBackprop::activate" << endl;
00652 return 0;
00653 }
00654 return 0;
00655 }
00656
00657
00658 bool SprStdBackprop::prepareExit(bool status)
00659 {
00660 data_->setWeights(initialDataWeights_);
00661 return status;
00662 }
00663
00664
00665 bool SprStdBackprop::readSNNS(const char* netfile)
00666 {
00667
00668 if( 0 == netfile ) return false;
00669 structure_ = "Unknown";
00670 configured_ = false;
00671 initialized_ = false;
00672 string nfile = netfile;
00673 bool success = false;
00674
00675
00676 ifstream file(nfile.c_str());
00677 if( !file ) {
00678 cerr << "Unable to open file " << nfile.c_str() << endl;
00679 return false;
00680 }
00681
00682
00683 string line;
00684 unsigned nLine = 0;
00685 nLine++;
00686 nNodes_ = 0;
00687 while( getline(file,line) ) {
00688 const char* searchfor = "no. of units :";
00689 size_t pos = line.find(searchfor);
00690 if( pos != string::npos ) {
00691 line.erase(0,pos+strlen(searchfor)+1);
00692 istringstream istnodes(line);
00693 istnodes >> nNodes_;
00694 break;
00695 }
00696 nLine++;
00697 }
00698 if( nNodes_ <= 0 ) {
00699 cerr << "Can't find units line in file " << nfile.c_str() << endl;
00700 return false;
00701 }
00702 nLine++;
00703 if( !getline(file,line) ) {
00704 cerr << "Cannot read from " << nfile.c_str() << " line " << nLine << endl;
00705 return false;
00706 }
00707 nLinks_ = 0;
00708 const char* searchfor = "no. of connections :";
00709 size_t pos = line.find(searchfor);
00710 if( pos != string::npos ) {
00711 line.erase(0,pos+strlen(searchfor)+1);
00712 istringstream istconns(line);
00713 istconns >> nLinks_;
00714 }
00715 if( nLinks_ <= 0 ) {
00716 cerr << "Can't find connections line in file " << nfile.c_str() << endl;
00717 return false;
00718 }
00719
00720
00721
00722 nodeType_.clear(); nodeType_.resize(nNodes_,SprNNDefs::INPUT);
00723 nodeActFun_.clear(); nodeActFun_.resize(nNodes_,SprNNDefs::ID);
00724 nodeAct_.clear(); nodeAct_.resize(nNodes_,0);
00725 nodeOut_.clear(); nodeOut_.resize(nNodes_,0);
00726 nodeNInputLinks_.clear(); nodeNInputLinks_.resize(nNodes_,0);
00727 nodeFirstInputLink_.clear(); nodeFirstInputLink_.resize(nNodes_,-1);
00728 nodeBias_.clear(); nodeBias_.resize(nNodes_,0);
00729 linkSource_.clear(); linkSource_.resize(nLinks_,0);
00730 linkWeight_.clear(); linkWeight_.resize(nLinks_,0);
00731
00732
00733
00734
00735
00736
00737
00738 nLine++;
00739 bool found = false;
00740 while( getline(file,line) ) {
00741 size_t pos = line.find("unit definition section :");
00742 if( pos != string::npos ) {
00743 found = true;
00744 break;
00745 }
00746 nLine++;
00747 }
00748 if( !found ) {
00749 cerr << "Can't find unit definition section in file "
00750 << nfile.c_str() << endl;
00751 return false;
00752 }
00753
00754 for( int i=0;i<3;i++ ) {
00755 nLine++;
00756 if( !getline(file,line) ) {
00757 cerr << "Cannot read from " << nfile.c_str()
00758 << " line " << nLine << endl;
00759 return false;
00760 }
00761 }
00762
00763 unsigned nOutput = 0;
00764 for( int node=0;node<nNodes_;node++ ) {
00765 nLine++;
00766 if( !getline(file,line) ) {
00767 cerr << "Cannot read from " << nfile.c_str()
00768 << " line " << nLine << endl;
00769 return false;
00770 }
00771 istringstream istnode(line);
00772 int id = 0;
00773 istnode >> id;
00774 if( id != (node+1) ) {
00775 cerr << "Node ID does not match on line " << nLine << endl;
00776 return false;
00777 }
00778 char c;
00779 double dummy;
00780 for( int i=0;i<3;i++ ) istnode >> c;
00781 istnode >> dummy >> c >> nodeBias_[node] >> c;
00782 istnode >> c;
00783 switch( c )
00784 {
00785 case 'i' :
00786 nodeType_[node] = SprNNDefs::INPUT;
00787 nodeActFun_[node] = SprNNDefs::ID;
00788 break;
00789 case 'h' :
00790 nodeType_[node] = SprNNDefs::HIDDEN;
00791 nodeActFun_[node] = SprNNDefs::LOGISTIC;
00792 break;
00793 case 'o' :
00794 nodeType_[node] = SprNNDefs::OUTPUT;
00795 nodeActFun_[node] = SprNNDefs::LOGISTIC;
00796 nOutput++;
00797 break;
00798 default :
00799 cerr << "Unknown node type on line " << nLine << endl;
00800 return false;
00801 }
00802 }
00803 if( nOutput > 1 ) {
00804 cerr << "More than one output node cannot be handled "
00805 << "by this implementation" << endl;
00806 return false;
00807 }
00808
00809
00810
00811
00812
00813 nLine++;
00814 found = false;
00815 while( getline(file,line) ) {
00816 size_t pos = line.find("connection definition section :");
00817 if( pos != string::npos ) {
00818 found = true;
00819 break;
00820 }
00821 nLine++;
00822 }
00823 if( !found ) {
00824 cerr << "Can't find connection definition section in file "
00825 << nfile.c_str() << endl;
00826 return false;
00827 }
00828
00829 for( int i=0;i<3;i++ ) {
00830 nLine++;
00831 if( !getline(file,line) ) {
00832 cerr << "Cannot read from " << nfile.c_str()
00833 << " line " << nLine << endl;
00834 return false;
00835 }
00836 }
00837
00838 int link = 0;
00839 string prevLine;
00840 while( getline(file,line) ) {
00841 nLine++;
00842
00843 if( line.at(line.find_last_not_of(' ')) == ',' ) {
00844 prevLine = line;
00845 continue;
00846 }
00847 line = prevLine+line;
00848 prevLine = "";
00849
00850 size_t separ_pos = line.find_first_of('|');
00851 if( separ_pos == string::npos ) {
00852 cerr << "Cannot read from " << nfile.c_str()
00853 << " line " << nLine << endl;
00854 return false;
00855 }
00856 string target_str = line.substr(0,separ_pos);
00857 line.erase(0,separ_pos+1);
00858 int target = atoi(target_str.c_str());
00859 if( target<=0 || target>nNodes_ ) {
00860 cerr << "Unable to read target node from "
00861 << nfile.c_str() << " on line " << nLine
00862 << " : nNodes=" << nNodes_ << " target=" << target << endl;
00863 return false;
00864 }
00865 target--;
00866
00867 nodeFirstInputLink_[target] = link;
00868
00869 separ_pos = line.find_first_of('|');
00870 if( separ_pos == string::npos ) {
00871 cerr << "Cannot read from " << nfile.c_str()
00872 << " line " << nLine << endl;
00873 return false;
00874 }
00875
00876 string sources_str = line.substr(separ_pos+1);
00877 vector<string> sources;
00878 while( sources_str.find(',') != string::npos ) {
00879 size_t comma_pos = sources_str.find_first_of(',');
00880 sources.push_back(sources_str.substr(0,comma_pos));
00881 sources_str.erase(0,comma_pos+1);
00882 }
00883 sources.push_back(sources_str);
00884 for( int i=0;i<sources.size();i++ ) {
00885 string current_source = sources[i];
00886 size_t doubledot_pos = current_source.find(':');
00887 if( doubledot_pos == string::npos ) {
00888 cerr << "Cannot read from " << nfile.c_str()
00889 << " line " << nLine << endl;
00890 return false;
00891 }
00892 string source_id = current_source.substr(0,doubledot_pos);
00893 string source_weight = current_source.substr(doubledot_pos+1);
00894 int source = atoi(source_id.c_str());
00895 double weight = atof(source_weight.c_str());
00896 if( source<=0 || source>nNodes_ ) {
00897 cerr << "Unable to read source node from "
00898 << nfile.c_str() << " on line " << nLine << endl;
00899 return false;
00900 }
00901 source--;
00902
00903 linkSource_[link] = source;
00904 linkWeight_[link] = weight;
00905 nodeNInputLinks_[target]++;
00906
00907 link++;
00908 }
00909 if( link == nLinks_ ) {
00910 success = true;
00911 break;
00912 }
00913 }
00914
00915
00916 if( success ) {
00917 configured_ = true;
00918 initialized_ = true;
00919 }
00920 return success;
00921 }
00922
00923
00924 bool SprStdBackprop::readSPR(const char* netfile)
00925 {
00926
00927 if( 0 == netfile ) return false;
00928 string nfile = netfile;
00929
00930
00931 ifstream file(nfile.c_str());
00932 if( !file ) {
00933 cerr << "Unable to open file " << nfile.c_str() << endl;
00934 return false;
00935 }
00936
00937
00938 unsigned skipLines = 0;
00939 return this->resumeReadSPR(nfile.c_str(),file,skipLines);
00940 }
00941
00942 bool SprStdBackprop::resumeReadSPR(const char* netfile,
00943 std::ifstream& file,
00944 unsigned& skipLines)
00945 {
00946
00947 unsigned& nLine = skipLines;
00948 structure_ = "Unknown";
00949 configured_ = false;
00950 initialized_ = false;
00951 string nfile = netfile;
00952
00953
00954 string line;
00955 for( int i=0;i<2;i++ ) {
00956 nLine++;
00957 if( !getline(file,line) ) {
00958 cerr << "Unable to read line " << nLine
00959 << " from " << nfile.c_str() << endl;
00960 return false;
00961 }
00962 }
00963
00964
00965 string dummy;
00966 nLine++;
00967 if( !getline(file,line) ) {
00968 cerr << "Unable to read line " << nLine
00969 << " from " << nfile.c_str() << endl;
00970 return false;
00971 }
00972 istringstream istcut(line);
00973 istcut >> dummy;
00974 int nCut = 0;
00975 istcut >> nCut;
00976 cut_.clear();
00977 double low(0), high(0);
00978 for( int i=0;i<nCut;i++ ) {
00979 istcut >> low >> high;
00980 cut_.push_back(SprInterval(low,high));
00981 }
00982
00983
00984 nLine++;
00985 if( !getline(file,line) ) {
00986 cerr << "Unable to read line " << nLine
00987 << " from " << nfile.c_str() << endl;
00988 return false;
00989 }
00990 istringstream istNnodes(line);
00991 istNnodes >> dummy >> nNodes_;
00992 if( nNodes_ <= 0 ) {
00993 cerr << "Rean an invalid number of NN nodes: " << nNodes_ << endl;
00994 return false;
00995 }
00996
00997
00998 nodeType_.clear(); nodeType_.resize(nNodes_,SprNNDefs::INPUT);
00999 nodeActFun_.clear(); nodeActFun_.resize(nNodes_,SprNNDefs::ID);
01000 nodeAct_.clear(); nodeAct_.resize(nNodes_,0);
01001 nodeOut_.clear(); nodeOut_.resize(nNodes_,0);
01002 nodeNInputLinks_.clear(); nodeNInputLinks_.resize(nNodes_,0);
01003 nodeFirstInputLink_.clear(); nodeFirstInputLink_.resize(nNodes_,-1);
01004 nodeBias_.clear(); nodeBias_.resize(nNodes_,0);
01005
01006
01007 for( int node=0;node<nNodes_;node++ ) {
01008 nLine++;
01009 if( !getline(file,line) ) {
01010 cerr << "Unable to read line " << nLine
01011 << " from " << nfile.c_str() << endl;
01012 return false;
01013 }
01014 istringstream istnode(line);
01015 int index = -1;
01016 istnode >> index;
01017 if( index != node ) {
01018 cerr << "Incorrect node number on line " << nLine
01019 << ": Expect " << node << " Actual " << index << endl;
01020 return false;
01021 }
01022 istnode >> dummy;
01023 char nodeType;
01024 istnode >> nodeType;
01025 switch( nodeType )
01026 {
01027 case 'I' :
01028 nodeType_[node] = SprNNDefs::INPUT;
01029 break;
01030 case 'H' :
01031 nodeType_[node] = SprNNDefs::HIDDEN;
01032 break;
01033 case 'O' :
01034 nodeType_[node] = SprNNDefs::OUTPUT;
01035 break;
01036 default :
01037 cerr << "Unknown node type on line " << nLine
01038 << " in " << nfile.c_str() << endl;
01039 return false;
01040 }
01041 istnode >> dummy;
01042 int actFun = 0;
01043 istnode >> actFun;
01044 switch( actFun )
01045 {
01046 case 1 :
01047 nodeActFun_[node] = SprNNDefs::ID;
01048 break;
01049 case 2 :
01050 nodeActFun_[node] = SprNNDefs::LOGISTIC;
01051 break;
01052 default :
01053 cerr << "Unknown activation function on line " << nLine
01054 << " in " << nfile.c_str() << endl;
01055 return false;
01056 }
01057 istnode >> dummy;
01058 istnode >> nodeNInputLinks_[node];
01059 istnode >> dummy;
01060 istnode >> nodeFirstInputLink_[node];
01061 istnode >> dummy;
01062 istnode >> nodeBias_[node];
01063 }
01064
01065
01066 nLine++;
01067 if( !getline(file,line) ) {
01068 cerr << "Unable to read line " << nLine
01069 << " from " << nfile.c_str() << endl;
01070 return false;
01071 }
01072 istringstream istNlinks(line);
01073 istNlinks >> dummy >> nLinks_;
01074 if( nLinks_ <= 0 ) {
01075 cerr << "Rean an invalid number of NN links: " << nLinks_ << endl;
01076 return false;
01077 }
01078
01079
01080 linkSource_.clear(); linkSource_.resize(nLinks_,0);
01081 linkWeight_.clear(); linkWeight_.resize(nLinks_,0);
01082
01083
01084 for( int link=0;link<nLinks_;link++ ) {
01085 nLine++;
01086 if( !getline(file,line) ) {
01087 cerr << "Unable to read line " << nLine
01088 << " from " << nfile.c_str() << endl;
01089 return false;
01090 }
01091 istringstream istlink(line);
01092 int index = -1;
01093 istlink >> index;
01094 if( index != link ) {
01095 cerr << "Incorrect link number on line " << nLine
01096 << ": Expect " << link << " Actual " << index << endl;
01097 return false;
01098 }
01099 istlink >> dummy;
01100 istlink >> linkSource_[link];
01101 istlink >> dummy;
01102 istlink >> linkWeight_[link];
01103 }
01104
01105
01106 configured_ = true;
01107 initialized_ = true;
01108 return true;
01109 }