00001
00002
00003 #include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
00004 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsCombiner.hh"
00005 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsTrainedClassifier.hh"
00006 #include "PhysicsTools/StatPatternRecognition/interface/SprAbsFilter.hh"
00007 #include "PhysicsTools/StatPatternRecognition/interface/SprData.hh"
00008 #include "PhysicsTools/StatPatternRecognition/interface/SprPoint.hh"
00009 #include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
00010 #include "PhysicsTools/StatPatternRecognition/interface/SprClass.hh"
00011
00012 #include <cassert>
00013
00014 using namespace std;
00015
00016
00017 SprAbsCombiner::SprAbsCombiner(SprAbsFilter* data)
00018 :
00019 SprAbsClassifier(data),
00020 classifiers_(),
00021 classifierLabels_(),
00022 features_(0)
00023 {}
00024
00025
00026 SprAbsCombiner::SprAbsCombiner(SprAbsFilter* data,
00027 const std::vector<
00028 const SprAbsTrainedClassifier*>& c,
00029 const std::vector<std::string>& cLabels)
00030 :
00031 SprAbsClassifier(data),
00032 classifiers_(c),
00033 classifierLabels_(cLabels),
00034 features_(0)
00035 {
00036 assert( !classifiers_.empty() );
00037 assert( classifiers_.size() == classifierLabels_.size() );
00038 }
00039
00040
00041 bool SprAbsCombiner::makeFeatures()
00042 {
00043
00044 int nClassifiers = classifiers_.size();
00045 if( nClassifiers == 0 ) return false;
00046 assert( nClassifiers == classifierLabels_.size() );
00047
00048
00049 SprData* features = new SprData("features",classifierLabels_);
00050 vector<double> r(nClassifiers);
00051 for( int i=0;i<data_->size();i++ ) {
00052 const SprPoint* p = (*data_)[i];
00053 for( int j=0;j<nClassifiers;j++ )
00054 r[j] = classifiers_[j]->response(p);
00055 features->insert(p->class_,r);
00056 }
00057
00058
00059 vector<double> weights;
00060 data_->weights(weights);
00061
00062
00063 vector<SprClass> classes;
00064 data_->classes(classes);
00065
00066
00067 features_ = new SprEmptyFilter(features,classes,weights,true);
00068
00069
00070 return true;
00071 }