00001
00002
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprFisher.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprPoint.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprUtils.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
00009
00010 #include <stdio.h>
00011 #include <iostream>
00012 #include <cassert>
00013 #include <iomanip>
00014 #include <cmath>
00015 #include <vector>
00016
00017 using namespace std;
00018
00019
00020 SprFisher::SprFisher(SprAbsFilter* data, int mode)
00021 :
00022 SprAbsClassifier(data),
00023 mode_(mode),
00024 cls0_(0),
00025 cls1_(1),
00026 dim_(data->dim()),
00027 linear_(dim_),
00028 quadr_(dim_),
00029 cterm_(0)
00030 {
00031 this->setClasses();
00032 }
00033
00034
00035 bool SprFisher::setData(SprAbsFilter* data)
00036 {
00037 assert( data != 0 );
00038 data_ = data;
00039 return this->reset();
00040 }
00041
00042
00043 SprTrainedFisher* SprFisher::makeTrained() const
00044 {
00045
00046 SprTrainedFisher* t = 0;
00047 if( mode_ == 1 )
00048 t = new SprTrainedFisher(linear_,cterm_);
00049 else if( mode_ == 2 )
00050 t = new SprTrainedFisher(linear_,quadr_,cterm_);
00051
00052
00053 vector<string> vars;
00054 data_->vars(vars);
00055 t->setVars(vars);
00056
00057
00058 return t;
00059 }
00060
00061
00062 bool SprFisher::train(int verbose)
00063 {
00064
00065 SprVector mean0(dim_), mean1(dim_);
00066 SprSymMatrix cov0(dim_), cov1(dim_);
00067
00068
00069 for( int i=0;i<dim_;i++ ) {
00070 mean0[i] = 0;
00071 mean1[i] = 0;
00072 for( int j=i;j<dim_;j++ ) {
00073 cov0[i][j] = 0;
00074 cov1[i][j] = 0;
00075 }
00076 }
00077
00078
00079 unsigned size = data_->size();
00080 if( size == 0 ) {
00081 cerr << "No points in data." << endl;
00082 return false;
00083 }
00084 double size0(0), size1(0);
00085 double w = 0;
00086 double r1(0), r2(0);
00087 for( int i=0;i<size;i++ ) {
00088 const SprPoint* p = (*data_)[i];
00089 int cls = p->class_;
00090 if( cls==cls0_ || cls==cls1_ ) {
00091 w = data_->w(i);
00092
00093 if( cls == cls0_ )
00094 size0 += w;
00095 else if( cls == cls1_ )
00096 size1 += w;
00097
00098 for( int j=0;j<dim_;j++ ) {
00099 r1 = w * (p->x_)[j];
00100 if( cls == cls0_ )
00101 mean0[j] += r1;
00102 else if( cls == cls1_ )
00103 mean1[j] += r1;
00104 for( int k=j;k<dim_;k++ ) {
00105 r2 = r1*((p->x_)[k]);
00106 if( cls == cls0_ )
00107 cov0[j][k] += r2;
00108 else if( cls == cls1_ )
00109 cov1[j][k] += r2;
00110 }
00111 }
00112 }
00113 }
00114 double eps = SprUtils::eps();
00115 if( size0<eps || size1<eps ) {
00116 cerr << "Cannot find points for a class: "
00117 << size0 << " " << size1 << endl;
00118 return false;
00119 }
00120
00121
00122 mean0 /= size0;
00123 mean1 /= size1;
00124 cov0 /= size0;
00125 cov1 /= size1;
00126 SprSymMatrix meansq0 = vT_times_v(mean0);
00127 SprSymMatrix meansq1 = vT_times_v(mean1);
00128 cov0 -= meansq0;
00129 cov1 -= meansq1;
00130 if( mode_ == 1 ) {
00131 cov0 = (size0*cov0+size1*cov1) / (size0+size1);
00132 }
00133
00134
00135 if( verbose > 1 ) {
00136 cout << "Sample means:" << endl;
00137 for( int i=0;i<dim_;i++ )
00138 cout << i << ": " << mean0[i] << " " << mean1[i] << endl;
00139 cout << "Sample covariance matrices:" << endl;
00140 if( mode_ == 2 )
00141 cout << "Class " << cls0_ << endl;
00142 for( int i=0;i<dim_;i++ ) {
00143 for( int j=0;j<dim_;j++ )
00144 cout << " " << cov0[i][j];
00145 cout << endl;
00146 }
00147 if( mode_ == 2 ) {
00148 cout << "Class " << cls1_ << endl;
00149 for( int i=0;i<dim_;i++ ) {
00150 for( int j=0;j<dim_;j++ )
00151 cout << " " << cov1[i][j];
00152 cout << endl;
00153 }
00154 }
00155 }
00156
00157
00158 int ifail = 0;
00159 cov0.invert(ifail);
00160 if( ifail != 0 ) {
00161 cerr << "Unable to invert matrix." << endl;
00162 return false;
00163 }
00164 if( verbose > 1 ) {
00165 cout << "Inverse matrices:" << endl;
00166 if( mode_ == 2 )
00167 cout << "Class " << cls0_ << endl;
00168 for( int i=0;i<dim_;i++ ) {
00169 for( int j=0;j<dim_;j++ )
00170 cout << " " << cov0[i][j];
00171 cout << endl;
00172 }
00173 }
00174 if( mode_ == 2 ) {
00175 cov1.invert(ifail);
00176 if( ifail != 0 ) {
00177 cerr << "Unable to invert matrix." << endl;
00178 return false;
00179 }
00180 if( verbose > 1 ) {
00181 cout << "Class " << cls1_ << endl;
00182 for( int i=0;i<dim_;i++ ) {
00183 for( int j=0;j<dim_;j++ )
00184 cout << " " << cov1[i][j];
00185 cout << endl;
00186 }
00187 }
00188 }
00189 cterm_ = log(size1/size0);
00190 if( mode_ == 1 ) {
00191 linear_ = cov0 * (mean1-mean0);
00192 cterm_ += -0.5 * (dot(mean1,cov0*mean1) - dot(mean0,cov0*mean0));
00193 }
00194 else if( mode_ == 2 ) {
00195 linear_ = cov1*mean1 - cov0*mean0;
00196 quadr_ = -0.5 * (cov1-cov0);
00197 cterm_ += -0.5 * (dot(mean1,cov1*mean1) - dot(mean0,cov0*mean0));
00198 double d0 = cov0.determinant();
00199 double d1 = cov1.determinant();
00200 cterm_ += 0.5*log(d1/d0);
00201 }
00202
00203
00204 return true;
00205 }
00206
00207
00208 void SprFisher::print(std::ostream& os) const
00209 {
00210 os << "Trained Fisher " << SprVersion << endl;
00211 os << "Fisher dimensionality: " << linear_.num_row() << endl;
00212 os << "Fisher response: F = C + T(L)*X + T(X)*Q*X; T is transposition"
00213 << endl;
00214 os << "By default logit transform is applied: F <- 1/[1+exp(-F)]" << endl;
00215 os << "Fisher order: " << mode_ << endl;
00216 os << "Const term (C): " << cterm_ << endl;
00217 os << "Linear Part (L):" << endl;
00218 for( int i=0;i<linear_.num_row();i++ )
00219 os << setw(10) << linear_[i] << " ";
00220 os << endl;
00221 if( mode_ == 2 ) {
00222 os << "Quadratic Part (Q):" << endl;
00223 for( int i=0;i<quadr_.num_row();i++ ) {
00224 for( int j=0;j<quadr_.num_col();j++ ) {
00225 os << setw(10) << quadr_[i][j] << " ";
00226 }
00227 os << endl;
00228 }
00229 }
00230 }
00231
00232
00233 void SprFisher::setClasses()
00234 {
00235 vector<SprClass> classes;
00236 data_->classes(classes);
00237 int size = classes.size();
00238 if( size > 0 ) cls0_ = classes[0];
00239 if( size > 1 ) cls1_ = classes[1];
00240 cout << "Classes for Fisher are set to " << cls0_ << " " << cls1_ << endl;
00241 }