CMS 3D CMS Logo

SprLogitR.cc

Go to the documentation of this file.
00001 //$Id: SprLogitR.cc,v 1.2 2007/09/21 22:32:10 narsky Exp $
00002 
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprLogitR.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprPoint.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprFisher.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
00010 
00011 #include "PhysicsTools/StatPatternRecognition/src/SprMatrix.hh"
00012 #include "PhysicsTools/StatPatternRecognition/src/SprSymMatrix.hh"
00013 
00014 #include <stdio.h>
00015 #include <iostream>
00016 #include <cassert>
00017 #include <iomanip>
00018 #include <cmath>
00019 #include <vector>
00020 
00021 using namespace std;
00022 
00023 
00024 SprLogitR::SprLogitR(SprAbsFilter* data, double eps, double updateFactor)
00025   :
00026   SprAbsClassifier(data),
00027   cls0_(0),
00028   cls1_(1),
00029   dim_(data->dim()),
00030   eps_(eps),
00031   updateFactor_(updateFactor),
00032   nIterAllowed_(100),
00033   beta0_(0),
00034   beta_(dim_),
00035   beta0Supplied_(0),
00036   betaSupplied_()
00037 {
00038   assert( eps_ > 0 );
00039   this->setClasses();
00040 }
00041 
00042 
00043 SprLogitR::SprLogitR(SprAbsFilter* data, 
00044                      double beta0, const SprVector& beta,
00045                      double eps, double updateFactor)
00046   :
00047   SprAbsClassifier(data),
00048   cls0_(0),
00049   cls1_(1),
00050   dim_(data->dim()),
00051   eps_(eps),
00052   updateFactor_(updateFactor),
00053   nIterAllowed_(100),
00054   beta0_(0),
00055   beta_(dim_),
00056   beta0Supplied_(beta0),
00057   betaSupplied_(beta)
00058 {
00059   assert( eps_ > 0 );
00060   assert( updateFactor_ > 0 );
00061   this->setClasses();
00062 }
00063 
00064 
00065 bool SprLogitR::setData(SprAbsFilter* data)
00066 {
00067   assert( data != 0 );
00068   data_ = data;
00069   return this->reset();
00070 }
00071 
00072 
00073 SprTrainedLogitR* SprLogitR::makeTrained() const
00074 {
00075   // make
00076   SprTrainedLogitR* t = new SprTrainedLogitR(beta0_,beta_);
00077 
00078   // vars
00079   vector<string> vars;
00080   data_->vars(vars);
00081   t->setVars(vars);
00082 
00083   // exit
00084   return t;
00085 }
00086 
00087 
00088 bool SprLogitR::train(int verbose)
00089 {
00090   // initialize to user-supplied values
00091   if( dim_ == betaSupplied_.num_row() ) {
00092     beta0_ = beta0Supplied_;
00093     beta_ = betaSupplied_;
00094   }
00095   else {// obtain initial estimates from LDA
00096     // message
00097     if( verbose > 0 ) {
00098       cout << "Obtaining initial estimates of Logit coefficients " 
00099            << "from LDA..." << endl;
00100     }
00101 
00102     // train LDA
00103     SprFisher fisher(data_,1);
00104     if( fisher.train(verbose) ) {
00105       beta0_ = fisher.cterm();
00106       fisher.linear(beta_);
00107       if( verbose > 0 ) {
00108         cout << "...Obtained estimates of Logit coefficients from LDA." << endl;
00109       }
00110     }
00111     else {
00112       cout << "Unable to train LDA. Will use zeros for initial estimates of " 
00113            << "Logit coefficients." << endl;
00114       for( int i=0;i<beta_.num_row();i++ ) beta_[i] = 0;
00115     }
00116 
00117   }// end of LDA
00118   assert( beta_.num_row() == dim_ );
00119 
00120   //
00121   // prepare matrices
00122   //
00123 
00124   // renormalize weights
00125   unsigned n0 = data_->ptsInClass(cls0_);
00126   unsigned n1 = data_->ptsInClass(cls1_);
00127   assert( n0>0 && n1>0 );
00128   unsigned N = n0 + n1;
00129   double w0 = data_->weightInClass(cls0_);
00130   double w1 = data_->weightInClass(cls1_);
00131   assert( w0>0 && w1>0 );
00132   double wFactor = double(N)/(w0+w1); 
00133   SprVector weights(N);
00134   for( int i=0;i<N;i++ )
00135     weights[i] = wFactor*data_->w(i);
00136 
00137   // vector of fitted beta
00138   SprVector betafit(dim_+1);
00139   betafit[0] = beta0_;
00140   for( int i=1;i<betafit.num_row();i++ )
00141     betafit[i] = beta_[i-1];
00142 
00143   // vector of fitted probabilities
00144   SprVector prob;
00145 
00146   // input data matrix
00147   SprMatrix X(N,dim_+1);
00148   for( int i=0;i<N;i++ ) {
00149     X[i][0] = 1;
00150     const SprPoint* p = (*data_)[i]; 
00151     for( int j=1;j<dim_+1;j++ ) X[i][j] = (p->x_)[j-1];
00152   }
00153 
00154   // vector of classes
00155   SprVector y(N);
00156   for( int i=0;i<N;i++ ) {
00157     const SprPoint* p = (*data_)[i]; 
00158     if(      p->class_ == cls0_ )
00159       y[i] = 0;
00160     else if( p->class_ == cls1_ )
00161       y[i] = 1;
00162   }
00163 
00164   //
00165   // minimize
00166   //
00167   double eps = 1;
00168   int iter = 0;
00169   while( true ) {
00170     if( ++iter > nIterAllowed_ ) {
00171       cerr << "Logit exiting because number of alowed iterations exceeded: " 
00172            << iter << " " << nIterAllowed_ << endl;
00173       return false;
00174     }
00175     if( !this->iterate(y,X,weights,prob,betafit,eps) ) {
00176       cerr << "Unable to iterate Logit coefficients at step " << iter << endl;
00177       return false;
00178     }
00179     if( verbose > 0 )
00180       cout << "Iteration " << iter << " obtains epsilon " << eps << endl;
00181     if( eps < eps_ ) break;
00182   }
00183 
00184   // get back optimized betas
00185   beta0_ = betafit[0];
00186   for( int i=1;i<betafit.num_row();i++ )
00187     beta_[i-1] = betafit[i];
00188   
00189   // exit
00190   return true;
00191 }
00192 
00193 
00194 bool SprLogitR::iterate(const SprVector& y,
00195                         const SprMatrix& X, 
00196                         const SprVector& weights, 
00197                         SprVector& prob, 
00198                         SprVector& betafit, 
00199                         double& eps)
00200 {
00201   // get sample size
00202   const unsigned N = X.num_row();
00203   const unsigned D = X.num_col();
00204 
00205   // compute probabilities
00206   SprVector pold(N);
00207   if( prob.num_row() == 0 ) {
00208     for( int i=0;i<N;i++ )
00209       pold[i] = SprTransformation::logit(dot(X.sub(i+1,i+1,1,D).T(),betafit));
00210   }
00211   else
00212     pold = prob;
00213 
00214   // fill out W vector
00215   SprVector W(N);
00216   for( int i=0;i<N;i++ ) {
00217     W[i] = weights[i]*pold[i]*(1.-pold[i]);
00218     if( W[i] < 0 ) W[i] = 0;
00219     if( W[i] > 1 ) W[i] = 1;
00220   }
00221 
00222   // iterate
00223   SprSymMatrix XTWX(D);
00224   for( int i=0;i<D;i++ ) {
00225     for( int j=i;j<D;j++ ) {
00226       double res = 0;
00227       for( int n=0;n<N;n++ )
00228         res += W[n]*X[n][i]*X[n][j];
00229       XTWX[i][j] = res;
00230     }
00231   }
00232   int ifail = 0;
00233   XTWX.invert(ifail);
00234   if( ifail != 0 ) {
00235     cerr << "Unable to invert matrix for Logit coefficients." << endl;
00236     return false;
00237   }
00238   betafit += updateFactor_ * (XTWX * (X.T()*(y-pold)));
00239 
00240   // update probabilities
00241   SprVector pnew(N);
00242   for( int i=0;i<N;i++ )
00243     pnew[i] = SprTransformation::logit(dot(X.sub(i+1,i+1,1,D).T(),betafit));
00244   
00245   // compute eps per event
00246   eps = 0;
00247   for( int i=0;i<N;i++ )
00248     eps += fabs(pnew[i]-pold[i]);
00249   eps /= N;
00250 
00251   // exit
00252   prob = pnew;
00253   return true;
00254 }
00255 
00256 
00257 void SprLogitR::print(std::ostream& os) const
00258 {
00259   os << "Trained LogitR " << SprVersion << endl;
00260   os << "LogitR dimensionality: " << beta_.num_row() << endl;
00261   os << "LogitR response: L = Beta0 + Beta*X" << endl;  
00262   os << "By default logit transform is applied: L <- 1/[1+exp(-L)]" << endl;
00263   os << "Beta0: " << beta0_ << endl;
00264   os << "Vector of Beta Coefficients:" << endl;
00265   for( int i=0;i<beta_.num_row();i++ )
00266     os << setw(10) << beta_[i] << " ";
00267   os << endl;
00268 }
00269 
00270 
00271 void SprLogitR::setClasses() 
00272 {
00273   vector<SprClass> classes;
00274   data_->classes(classes);
00275   int size = classes.size();
00276   if( size > 0 ) cls0_ = classes[0];
00277   if( size > 1 ) cls1_ = classes[1];
00278   cout << "Classes for LogitR are set to " << cls0_ << " " << cls1_ << endl;
00279 } 

Generated on Tue Jun 9 17:42:03 2009 for CMSSW by  doxygen 1.5.4