00001
00002
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprTrainedMultiClassLearner.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTrainedClassifier.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprDefs.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprLoss.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprTransformation.hh"
00009
00010 #include <algorithm>
00011 #include <functional>
00012 #include <iomanip>
00013 #include <cassert>
00014
00015 using namespace std;
00016
00017
// Strict-weak-ordering functor that compares two (class, loss) map entries
// by their second member only; the first member (the key) is ignored.
//
// std::binary_function was deprecated in C++11 and removed in C++17, so the
// adaptor typedefs it used to supply are declared here directly. This keeps
// the functor usable both with old-style function adaptors and with
// toolchains that no longer ship std::binary_function.
struct STMCLCmpPairSecond {
  typedef std::pair<const int,double> first_argument_type;
  typedef std::pair<const int,double> second_argument_type;
  typedef bool result_type;

  // Returns true if the value stored in l is strictly less than that in r.
  bool operator()(const std::pair<const int,double>& l,
                  const std::pair<const int,double>& r) const {
    return (l.second < r.second);
  }
};
00026
00027
00028 SprTrainedMultiClassLearner::SprTrainedMultiClassLearner(
00029 const SprMatrix& indicator,
00030 const std::vector<int>& mapper,
00031 const std::vector<std::pair<const SprAbsTrainedClassifier*,bool> >&
00032 classifiers)
00033 :
00034 SprAbsTrainedMultiClassLearner(),
00035 indicator_(indicator),
00036 mapper_(mapper),
00037 classifiers_(classifiers),
00038 loss_(0),
00039 trans_(0)
00040 {
00041 assert( !classifiers_.empty() );
00042 assert( indicator_.num_row() > 0 );
00043 assert( indicator_.num_col() == classifiers_.size() );
00044 if( mapper_.empty() ) {
00045 unsigned n = indicator_.num_row();
00046 mapper_.resize(n);
00047 for( int i=0;i<n;i++ ) mapper_[i] = i;
00048 }
00049 assert( mapper_.size() == indicator_.num_row() );
00050
00051 this->setLoss(&SprLoss::quadratic,
00052 &SprTransformation::zeroOneToMinusPlusOne);
00053 cout << "Loss for trained multi-class learner by default set to "
00054 << "quadratic." << endl;
00055 }
00056
00057
00058 SprTrainedMultiClassLearner::SprTrainedMultiClassLearner(
00059 const SprTrainedMultiClassLearner& other)
00060 :
00061 SprAbsTrainedMultiClassLearner(other),
00062 indicator_(other.indicator_),
00063 mapper_(other.mapper_),
00064 classifiers_(),
00065 loss_(other.loss_),
00066 trans_(other.trans_)
00067 {
00068 for( int i=0;i<other.classifiers_.size();i++ ) {
00069 const SprAbsTrainedClassifier* t = other.classifiers_[i].first->clone();
00070 classifiers_.push_back(pair<const SprAbsTrainedClassifier*,bool>(t,true));
00071 }
00072 assert( indicator_.num_col() == classifiers_.size() );
00073 }
00074
00075
00076 void SprTrainedMultiClassLearner::destroy()
00077 {
00078 for( int i=0;i<classifiers_.size();i++ ) {
00079 if( classifiers_[i].second )
00080 delete classifiers_[i].first;
00081 }
00082 }
00083
00084
00085 int SprTrainedMultiClassLearner::response(const std::vector<double>& input,
00086 std::map<int,double>& output) const
00087 {
00088
00089 assert( loss_ != 0 );
00090
00091
00092 vector<double> response(classifiers_.size());
00093 for( int i=0;i<classifiers_.size();i++ ) {
00094 double r = classifiers_[i].first->response(input);
00095 response[i] = ( trans_==0 ? r : trans_(r) );
00096 }
00097
00098
00099 output.clear();
00100 unsigned ncol = indicator_.num_col();
00101 for( int i=0;i<indicator_.num_row();i++ ) {
00102 double rowLoss = 0;
00103 for( int j=0;j<ncol;j++ )
00104 rowLoss += loss_(int(indicator_[i][j]),response[j]);
00105 rowLoss /= ncol;
00106 output.insert(pair<const int,double>(mapper_[i],rowLoss));
00107 }
00108
00109
00110 map<int,double>::const_iterator iter
00111 = min_element(output.begin(),output.end(),STMCLCmpPairSecond());
00112 return iter->first;
00113 }
00114
00115
00116 void SprTrainedMultiClassLearner::print(std::ostream& os) const
00117 {
00118 os << "Trained MultiClassLearner " << SprVersion << endl;
00119
00120
00121 this->printIndicatorMatrix(os);
00122
00123
00124 assert( indicator_.num_col() == classifiers_.size() );
00125 for( int j=0;j<classifiers_.size();j++ ) {
00126 os << "Multi class learner subclassifier: " << j << endl;
00127 classifiers_[j].first->print(os);
00128 }
00129 }
00130
00131
00132 void SprTrainedMultiClassLearner::classes(std::vector<int>& classes) const
00133 {
00134 classes = mapper_;
00135 stable_sort(classes.begin(),classes.end());
00136 }
00137
00138
00139 void SprTrainedMultiClassLearner::printIndicatorMatrix(std::ostream& os) const
00140 {
00141 os << "Indicator matrix:" << endl;
00142 os << setw(20) << "Classes/Classifiers" << " : "
00143 << mapper_.size() << " " << classifiers_.size() << endl;
00144 os << "=========================================================" << endl;
00145 for( int i=0;i<indicator_.num_row();i++ ) {
00146 os << setw(20) << mapper_[i] << " : ";
00147 for( int j=0;j<indicator_.num_col();j++ )
00148 os << setw(2) << indicator_[i][j] << " ";
00149 os << endl;
00150 }
00151 os << "=========================================================" << endl;
00152 }