00001
00002
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprArcE4.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprBootstrap.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTwoClassCriterion.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTrainedClassifier.hh"
00010
00011 #include <cassert>
00012 #include <cmath>
00013
00014 using namespace std;
00015
00016
00017 SprArcE4::SprArcE4(SprAbsFilter* data,
00018 unsigned cycles, bool discrete)
00019 :
00020 SprBagger(data,cycles,discrete),
00021 initialDataWeights_(),
00022 response_(data->size(),pair<double,double>(0,0))
00023 {
00024 data_->weights(initialDataWeights_);
00025 cout << "ArcE4 initialized." << endl;
00026 }
00027
00028
00029 bool SprArcE4::setData(SprAbsFilter* data)
00030 {
00031
00032 if( !SprBagger::setData(data) ) {
00033 cerr << "Unable to set data for ArcE4." << endl;
00034 return false;
00035 }
00036
00037
00038 data_->weights(initialDataWeights_);
00039
00040
00041 response_.clear();
00042 response_.resize(data_->size(),pair<double,double>(0,0));
00043
00044
00045 return true;
00046 }
00047
00048
00049 bool SprArcE4::train(int verbose)
00050 {
00051
00052 if( cycles_==0 || trainable_.empty() ) {
00053 cout << "ArcE4 will exit without training." << endl;
00054 return this->prepareExit(true);
00055 }
00056
00057
00058 if( cycles_>0 && !trained_.empty() ) {
00059 delete bootstrap_;
00060 bootstrap_ = new SprBootstrap(data_,-1);
00061 assert( bootstrap_ != 0 );
00062 }
00063
00064
00065 assert( data_->size() == response_.size() );
00066 for( int i=0;i<data_->size();i++ ) {
00067 const SprPoint* p = (*data_)[i];
00068 for( int j=0;j<trained_.size();j++ ) {
00069 double& resp = response_[i].first;
00070 double& wresp = response_[i].second;
00071 resp = wresp*resp + trained_[j].first->response(p);
00072 wresp += 1.;
00073 resp /= wresp;
00074 }
00075 }
00076
00077
00078 if( valData_ != 0 ) {
00079
00080 valBeta_.clear();
00081 int vsize = valData_->size();
00082 valBeta_.resize(vsize,0);
00083 int tsize = trained_.size();
00084 for( int i=0;i<vsize;i++ ) {
00085 const SprPoint* p = (*valData_)[i];
00086 if( discrete_ ) {
00087 for( int j=0;j<tsize;j++ )
00088 valBeta_[i] += ( trained_[j].first->accept(p) ? 1 : -1 );
00089 }
00090 else {
00091 for( int j=0;j<tsize;j++ )
00092 valBeta_[i] += trained_[j].first->response(p);
00093 }
00094 if( tsize > 0 ) valBeta_[i] /= tsize;
00095 }
00096
00097
00098 if( valPrint_ > 0 ) {
00099 if( !this->printValidation(0) ) {
00100 cerr << "Unable to print out validation data." << endl;
00101 return this->prepareExit(false);
00102 }
00103 }
00104 }
00105
00106
00107 unsigned nCycle = 0;
00108 unsigned nFailed = 0;
00109 while( nCycle < cycles_ ) {
00110 for( int i=0;i<trainable_.size();i++ ) {
00111
00112 if( nCycle++ >= cycles_ ) return this->prepareExit((this->nTrained()>0));
00113
00114
00115 auto_ptr<SprEmptyFilter> temp(bootstrap_->weightedReplica());
00116 if( temp->size() != data_->size() ) {
00117 cerr << "Failed to generate bootstrap replica." << endl;
00118 return this->prepareExit(false);
00119 }
00120
00121
00122 SprAbsClassifier* c = trainable_[i];
00123 if( !c->setData(temp.get()) ) {
00124 cerr << "Unable to set data for classifier " << i << endl;
00125 return this->prepareExit(false);
00126 }
00127 if( !c->train(verbose) ) {
00128 cerr << "ArcE4 failed to train classifier " << i
00129 << ". Continuing..."<< endl;
00130 if( ++nFailed >= cycles_ ) {
00131 cout << "Exiting after failed to train " << nFailed
00132 << " classifiers." << endl;
00133 return this->prepareExit((this->nTrained()>0));
00134 }
00135 else
00136 continue;
00137 }
00138
00139
00140 SprAbsTrainedClassifier* t = c->makeTrained();
00141 if( t == 0 ) {
00142 cerr << "ArcE4 failed to train classifier " << i
00143 << ". Continuing..."<< endl;
00144 if( ++nFailed >= cycles_ ) {
00145 cout << "Exiting after failed to train " << nFailed
00146 << " classifiers." << endl;
00147 return this->prepareExit((this->nTrained()>0));
00148 }
00149 else
00150 continue;
00151 }
00152 trained_.push_back(pair<const SprAbsTrainedClassifier*,bool>(t,true));
00153
00154
00155 this->reweight(t);
00156 if( verbose > 1 ) {
00157 cout << "After reweighting: W1=" << data_->weightInClass(cls1_)
00158 << " W0=" << data_->weightInClass(cls0_)
00159 << " N1=" << data_->ptsInClass(cls1_)
00160 << " N0=" << data_->ptsInClass(cls0_) << endl;
00161 }
00162
00163
00164 if( verbose>1 || (nCycle%100)==0 ) {
00165 cout << "Done cycle " << nCycle << endl;
00166 }
00167
00168
00169 if( valData_ != 0 ) {
00170
00171 int tsize = trained_.size();
00172 for( int i=0;i<valData_->size();i++ ) {
00173 const SprPoint* p = (*valData_)[i];
00174 if( discrete_ ) {
00175 if( t->accept(p) )
00176 valBeta_[i] = ((tsize-1)*valBeta_[i] + 1)/tsize;
00177 else
00178 valBeta_[i] = ((tsize-1)*valBeta_[i] - 1)/tsize;
00179 }
00180 else
00181 valBeta_[i] = ((tsize-1)*valBeta_[i] + t->response(p))/tsize;
00182 }
00183
00184
00185 if( valPrint_!=0 && (nCycle%valPrint_)==0 ) {
00186 if( !this->printValidation(nCycle) ) {
00187 cerr << "Unable to print out validation data." << endl;
00188 return this->prepareExit(false);
00189 }
00190 }
00191 }
00192 }
00193 }
00194
00195
00196 return this->prepareExit((this->nTrained()>0));
00197 }
00198
00199
00200 bool SprArcE4::prepareExit(bool status)
00201 {
00202
00203 data_->setWeights(initialDataWeights_);
00204
00205
00206 return SprBagger::prepareExit(status);
00207 }
00208
00209
00210 void SprArcE4::reweight(const SprAbsTrainedClassifier* t)
00211 {
00212 unsigned size = data_->size();
00213 assert( size == initialDataWeights_.size() );
00214 assert( size == response_.size() );
00215 for( int i=0;i<size;i++ ) {
00216 const SprPoint* p = (*data_)[i];
00217
00218
00219 double& resp = response_[i].first;
00220 double& wresp = response_[i].second;
00221 resp = wresp*resp + t->response(p);
00222 wresp += 1.;
00223 resp /= wresp;
00224
00225
00226 int cls = -1;
00227 if( p->class_ == cls0_ )
00228 cls = 0;
00229 else if( p->class_ == cls1_ )
00230 cls = 1;
00231 if( cls > -1 ) {
00232 double error = wresp * (resp - cls);
00233 double w = initialDataWeights_[i] * (1.+pow(fabs(error),4));
00234 data_->setW(i,w);
00235 }
00236 }
00237 }