CMS 3D CMS Logo

SprTrainedMultiClassLearner.cc

Go to the documentation of this file.
00001 //$Id: SprTrainedMultiClassLearner.cc,v 1.2 2007/09/21 22:32:10 narsky Exp $
00002 
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedMultiClassLearner.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTrainedClassifier.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
00009 
00010 #include <algorithm>
00011 #include <functional>
00012 #include <iomanip>
00013 #include <cassert>
00014 
00015 using namespace std;
00016 
00017 
// Comparator ordering (class label, loss) map entries by their loss value
// (the pair's second member), so that std::min_element over a
// std::map<int,double> yields the entry with the smallest loss.
//
// std::binary_function was deprecated in C++11 and removed in C++17, so the
// struct no longer derives from it. The adaptable typedefs that base class
// used to provide are declared directly, so any legacy adaptor that looks
// for them keeps working.
struct STMCLCmpPairSecond
{
  typedef std::pair<const int,double> first_argument_type;
  typedef std::pair<const int,double> second_argument_type;
  typedef bool result_type;

  // Strict weak ordering on the loss value only; labels are ignored.
  bool operator()(const std::pair<const int,double>& l,
                  const std::pair<const int,double>& r) const {
    return (l.second < r.second);
  }
};
00026 
00027 
00028 SprTrainedMultiClassLearner::SprTrainedMultiClassLearner(
00029        const SprMatrix& indicator,
00030        const std::vector<int>& mapper,
00031        const std::vector<std::pair<const SprAbsTrainedClassifier*,bool> >& 
00032        classifiers)
00033   :
00034   SprAbsTrainedMultiClassLearner(),
00035   indicator_(indicator),
00036   mapper_(mapper),
00037   classifiers_(classifiers),
00038   loss_(0),
00039   trans_(0)
00040 {
00041   assert( !classifiers_.empty() );
00042   assert( indicator_.num_row() > 0 );
00043   assert( indicator_.num_col() == classifiers_.size() );
00044   if( mapper_.empty() ) {
00045     unsigned n = indicator_.num_row();
00046     mapper_.resize(n);
00047     for( int i=0;i<n;i++ ) mapper_[i] = i;
00048   }
00049   assert( mapper_.size() == indicator_.num_row() );
00050   // set default loss
00051   this->setLoss(&SprLoss::quadratic,
00052                 &SprTransformation::zeroOneToMinusPlusOne);
00053   cout << "Loss for trained multi-class learner by default set to "
00054        << "quadratic." << endl;
00055 }
00056 
00057 
00058 SprTrainedMultiClassLearner::SprTrainedMultiClassLearner(
00059                              const SprTrainedMultiClassLearner& other)
00060   :
00061   SprAbsTrainedMultiClassLearner(other),
00062   indicator_(other.indicator_),
00063   mapper_(other.mapper_),
00064   classifiers_(),
00065   loss_(other.loss_),
00066   trans_(other.trans_)
00067 {
00068   for( int i=0;i<other.classifiers_.size();i++ ) {
00069     const SprAbsTrainedClassifier* t = other.classifiers_[i].first->clone();
00070     classifiers_.push_back(pair<const SprAbsTrainedClassifier*,bool>(t,true));
00071   }
00072   assert( indicator_.num_col() == classifiers_.size() );
00073 }
00074 
00075 
00076 void SprTrainedMultiClassLearner::destroy()
00077 {
00078   for( int i=0;i<classifiers_.size();i++ ) {
00079     if( classifiers_[i].second )
00080       delete classifiers_[i].first;
00081   }
00082 }
00083  
00084 
00085 int SprTrainedMultiClassLearner::response(const std::vector<double>& input,
00086                                           std::map<int,double>& output) const
00087 {
00088   // sanity check
00089   assert( loss_ != 0 );
00090 
00091   // compute vector of responses
00092   vector<double> response(classifiers_.size());
00093   for( int i=0;i<classifiers_.size();i++ ) {
00094     double r = classifiers_[i].first->response(input);
00095     response[i] = ( trans_==0 ? r : trans_(r) );
00096   }
00097 
00098   // evaluate consistency with each row
00099   output.clear();
00100   unsigned ncol = indicator_.num_col();
00101   for( int i=0;i<indicator_.num_row();i++ ) {
00102     double rowLoss = 0;
00103     for( int j=0;j<ncol;j++ )
00104       rowLoss += loss_(int(indicator_[i][j]),response[j]);
00105     rowLoss /= ncol;
00106     output.insert(pair<const int,double>(mapper_[i],rowLoss));
00107   }
00108 
00109   // find minimal loss
00110   map<int,double>::const_iterator iter 
00111     = min_element(output.begin(),output.end(),STMCLCmpPairSecond());
00112   return iter->first;
00113 }
00114 
00115 
00116 void SprTrainedMultiClassLearner::print(std::ostream& os) const 
00117 {
00118   os << "Trained MultiClassLearner " << SprVersion << endl;
00119 
00120   // print matrix
00121   this->printIndicatorMatrix(os);
00122 
00123   // print classifiers
00124   assert( indicator_.num_col() == classifiers_.size() );
00125   for( int j=0;j<classifiers_.size();j++ ) {
00126     os << "Multi class learner subclassifier: " << j << endl;
00127     classifiers_[j].first->print(os);
00128   }
00129 }
00130 
00131 
00132 void SprTrainedMultiClassLearner::classes(std::vector<int>& classes) const 
00133 { 
00134   classes = mapper_; 
00135   stable_sort(classes.begin(),classes.end());
00136 }
00137 
00138 
00139 void SprTrainedMultiClassLearner::printIndicatorMatrix(std::ostream& os) const
00140 {
00141   os << "Indicator matrix:" << endl;
00142   os << setw(20) << "Classes/Classifiers" << " : " 
00143      << mapper_.size() << " " << classifiers_.size() << endl;
00144   os << "=========================================================" << endl;
00145   for( int i=0;i<indicator_.num_row();i++ ) {
00146     os << setw(20) << mapper_[i] << " : ";
00147     for( int j=0;j<indicator_.num_col();j++ ) 
00148       os << setw(2) << indicator_[i][j] << " ";
00149     os << endl;
00150   }
00151   os << "=========================================================" << endl;
00152 }

Generated on Tue Jun 9 17:42:04 2009 for CMSSW by  doxygen 1.5.4