CMS 3D CMS Logo

SprFisher.cc

Go to the documentation of this file.
00001 //$Id: SprFisher.cc,v 1.2 2007/09/21 22:32:10 narsky Exp $
00002 
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprFisher.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprPoint.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
00009 
00010 #include <stdio.h>
00011 #include <iostream>
00012 #include <cassert>
00013 #include <iomanip>
00014 #include <cmath>
00015 #include <vector>
00016 
00017 using namespace std;
00018 
00019 
00020 SprFisher::SprFisher(SprAbsFilter* data, int mode) 
00021   : 
00022   SprAbsClassifier(data),
00023   mode_(mode),
00024   cls0_(0),
00025   cls1_(1),
00026   dim_(data->dim()),
00027   linear_(dim_),
00028   quadr_(dim_),
00029   cterm_(0)
00030 {
00031   this->setClasses();
00032 }
00033 
00034 
00035 bool SprFisher::setData(SprAbsFilter* data)
00036 {
00037   assert( data != 0 );
00038   data_ = data;
00039   return this->reset();
00040 }
00041 
00042 
00043 SprTrainedFisher* SprFisher::makeTrained() const
00044 {
00045   // make
00046   SprTrainedFisher* t = 0;
00047   if(      mode_ == 1 )
00048     t = new SprTrainedFisher(linear_,cterm_);
00049   else if( mode_ == 2 )
00050     t = new SprTrainedFisher(linear_,quadr_,cterm_);
00051 
00052   // vars
00053   vector<string> vars;
00054   data_->vars(vars);
00055   t->setVars(vars);
00056 
00057   // exit
00058   return t;
00059 }
00060 
00061 
00062 bool SprFisher::train(int verbose)
00063 {
00064   // init
00065   SprVector mean0(dim_), mean1(dim_);
00066   SprSymMatrix cov0(dim_), cov1(dim_);
00067 
00068   // be paranoid and fill with zeros
00069   for( int i=0;i<dim_;i++ ) {
00070     mean0[i] = 0;
00071     mean1[i] = 0;
00072     for( int j=i;j<dim_;j++ ) {
00073       cov0[i][j] = 0;
00074       cov1[i][j] = 0;
00075     }
00076   }
00077 
00078   // loop through points to compute mean vectors and covariance matrices
00079   unsigned size = data_->size();
00080   if( size == 0 ) {
00081     cerr << "No points in data." << endl;
00082     return false;
00083   }
00084   double size0(0), size1(0);
00085   double w = 0;
00086   double r1(0), r2(0);
00087   for( int i=0;i<size;i++ ) {
00088     const SprPoint* p = (*data_)[i];
00089     int cls = p->class_;
00090     if( cls==cls0_ || cls==cls1_ ) {
00091       w = data_->w(i);
00092       // increment weights
00093       if(      cls == cls0_ )
00094         size0 += w;
00095       else if( cls == cls1_ )
00096         size1 += w;
00097       // loop through dimensions
00098       for( int j=0;j<dim_;j++ ) {
00099         r1 = w * (p->x_)[j];
00100         if(      cls == cls0_ )
00101           mean0[j] += r1;
00102         else if( cls == cls1_ )
00103           mean1[j] += r1;
00104         for( int k=j;k<dim_;k++ ) {
00105           r2 = r1*((p->x_)[k]);
00106           if(      cls == cls0_ )
00107             cov0[j][k] += r2;
00108           else if( cls == cls1_ )
00109             cov1[j][k] += r2;
00110         }
00111       }
00112     }
00113   }
00114   double eps = SprUtils::eps();
00115   if( size0<eps || size1<eps ) {
00116     cerr << "Cannot find points for a class: " 
00117          << size0 << " " << size1 << endl;
00118     return false;
00119   }
00120 
00121   // normalize and compute covariances
00122   mean0 /= size0;
00123   mean1 /= size1;
00124   cov0 /= size0;
00125   cov1 /= size1;
00126   SprSymMatrix meansq0 = vT_times_v(mean0);
00127   SprSymMatrix meansq1 = vT_times_v(mean1);
00128   cov0 -= meansq0;
00129   cov1 -= meansq1;
00130   if( mode_ == 1 ) {
00131     cov0 = (size0*cov0+size1*cov1) / (size0+size1);
00132   }
00133 
00134   // print out
00135   if( verbose > 1 ) {
00136     cout << "Sample means:" << endl;
00137     for( int i=0;i<dim_;i++ )
00138       cout << i << ": " << mean0[i] << " " << mean1[i] << endl;
00139     cout << "Sample covariance matrices:" << endl;
00140     if( mode_ == 2 )
00141       cout << "Class " << cls0_ << endl;
00142     for( int i=0;i<dim_;i++ ) {
00143       for( int j=0;j<dim_;j++ )
00144         cout << " " << cov0[i][j];
00145       cout << endl;
00146     }
00147     if( mode_ == 2 ) {
00148       cout << "Class " << cls1_ << endl;
00149       for( int i=0;i<dim_;i++ ) {
00150         for( int j=0;j<dim_;j++ )
00151           cout << " " << cov1[i][j];
00152         cout << endl;
00153       }
00154     }
00155   }
00156 
00157   // compute Fisher coefficients
00158   int ifail = 0;
00159   cov0.invert(ifail);
00160   if( ifail != 0 ) {
00161     cerr << "Unable to invert matrix." << endl;
00162     return false;
00163   }
00164   if( verbose > 1 ) {
00165     cout << "Inverse matrices:" << endl;
00166     if( mode_ == 2 )
00167       cout << "Class " << cls0_ << endl;
00168     for( int i=0;i<dim_;i++ ) {
00169       for( int j=0;j<dim_;j++ )
00170         cout << " " << cov0[i][j];
00171       cout << endl;
00172     }
00173   }
00174   if( mode_ == 2 ) {
00175     cov1.invert(ifail);
00176     if( ifail != 0 ) {
00177       cerr << "Unable to invert matrix." << endl;
00178       return false;
00179     }
00180     if( verbose > 1 ) {
00181       cout << "Class " << cls1_ << endl;
00182       for( int i=0;i<dim_;i++ ) {
00183         for( int j=0;j<dim_;j++ )
00184           cout << " " << cov1[i][j];
00185         cout << endl;
00186       }
00187     }
00188   }
00189   cterm_ = log(size1/size0);
00190   if(      mode_ == 1 ) {
00191     linear_ = cov0 * (mean1-mean0);
00192     cterm_ += -0.5 * (dot(mean1,cov0*mean1) - dot(mean0,cov0*mean0));
00193   }
00194   else if( mode_ == 2 ) {
00195     linear_ = cov1*mean1 - cov0*mean0;
00196     quadr_ = -0.5 * (cov1-cov0);
00197     cterm_ += -0.5 * (dot(mean1,cov1*mean1) - dot(mean0,cov0*mean0));
00198     double d0 = cov0.determinant();
00199     double d1 = cov1.determinant();
00200     cterm_ += 0.5*log(d1/d0);//inverted matrices => must have "+", not "-"
00201   }
00202 
00203   // exit
00204   return true;
00205 }
00206 
00207 
00208 void SprFisher::print(std::ostream& os) const
00209 {
00210   os << "Trained Fisher " << SprVersion << endl;
00211   os << "Fisher dimensionality: " << linear_.num_row() << endl;
00212   os << "Fisher response: F = C + T(L)*X + T(X)*Q*X; T is transposition" 
00213      << endl;
00214   os << "By default logit transform is applied: F <- 1/[1+exp(-F)]" << endl;
00215   os << "Fisher order: " << mode_ << endl;
00216   os << "Const term (C): " << cterm_ << endl;
00217   os << "Linear Part (L):" << endl;
00218   for( int i=0;i<linear_.num_row();i++ )
00219     os << setw(10) << linear_[i] << " ";
00220   os << endl;
00221   if( mode_ == 2 ) {
00222     os << "Quadratic Part (Q):" << endl;
00223     for( int i=0;i<quadr_.num_row();i++ ) {
00224       for( int j=0;j<quadr_.num_col();j++ ) {
00225         os << setw(10) << quadr_[i][j] << " ";
00226       }
00227       os << endl;
00228     }
00229   }
00230 }
00231 
00232 
00233 void SprFisher::setClasses() 
00234 {
00235   vector<SprClass> classes;
00236   data_->classes(classes);
00237   int size = classes.size();
00238   if( size > 0 ) cls0_ = classes[0];
00239   if( size > 1 ) cls1_ = classes[1];
00240   cout << "Classes for Fisher are set to " << cls0_ << " " << cls1_ << endl;
00241 } 

Generated on Tue Jun 9 17:42:02 2009 for CMSSW by  doxygen 1.5.4