CMS 3D CMS Logo

/data/refman/pasoursint/CMSSW_4_1_8_patch13/src/PhysicsTools/MVATrainer/src/TrainProcessor.cc

Go to the documentation of this file.
00001 #include <limits>
00002 #include <string>
00003 
00004 #include <TH1.h>
00005 
00006 #include "FWCore/PluginManager/interface/PluginManager.h"
00007 #include "FWCore/PluginManager/interface/PluginFactory.h"
00008 
00009 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00010 
00011 #include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
00012 #include "PhysicsTools/MVATrainer/interface/TrainProcessor.h"
00013 
00014 EDM_REGISTER_PLUGINFACTORY(PhysicsTools::TrainProcessor::PluginFactory,
00015                            "PhysicsToolsMVATrainer");
00016 
00017 namespace PhysicsTools {
00018 
00019 TrainProcessor::TrainProcessor(const char *name,
00020                                const AtomicId *id,
00021                                MVATrainer *trainer) :
00022         Source(*id), name(name), trainer(trainer), monitoring(0), monModule(0)
00023 {
00024 }
00025 
00026 TrainProcessor::~TrainProcessor()
00027 {
00028 }
00029 
00030 void TrainProcessor::doTrainBegin()
00031 {
00032         bool booked = false;
00033         unsigned int nBins = 50;
00034 
00035         if (!monitoring) {
00036                 const char *source = getName();
00037                 if (source) {
00038                         monitoring = trainer->bookMonitor(name + "_" + source);
00039                         monModule = trainer->bookMonitor(std::string("input_") +
00040                                                          source);
00041                 } else {
00042                         monModule = trainer->bookMonitor("output");
00043                         nBins = 400;
00044                 }
00045 
00046                 booked = monModule != 0;
00047         }
00048 
00049         if (booked) {
00050                 std::vector<SourceVariable*> inputs = getInputs().get();
00051                 for(std::vector<SourceVariable*>::const_iterator iter =
00052                         inputs.begin(); iter != inputs.end(); ++iter) {
00053 
00054                         SourceVariable *var = *iter;
00055                         std::string name =
00056                                 (const char*)var->getSource()->getName()
00057                                 + std::string("_")
00058                                 + (const char*)var->getName();
00059 
00060                         SigBkg pair;
00061                         pair.entries[0] = pair.entries[1] = 0;
00062                         pair.histo[0] = monModule->book<TH1F>(name + "_bkg",
00063                                 (name + "_bkg").c_str(),
00064                                 (name + " background").c_str(), nBins, 0, 0);
00065                         pair.histo[1] = monModule->book<TH1F>(name + "_sig",
00066                                 (name + "_sig").c_str(),
00067                                 (name + " signal").c_str(), nBins, 0, 0);
00068                         pair.underflow[0] = pair.underflow[1] = 0.0;
00069                         pair.overflow[0] = pair.overflow[1] = 0.0;
00070 
00071                         pair.sameBinning = true;        // use as default
00072                         if (monitoring) {
00073                                 pair.min = -std::numeric_limits<double>::infinity();
00074                                 pair.max = +std::numeric_limits<double>::infinity();
00075                         } else {
00076                                 pair.min = -99999.0;
00077                                 pair.max = +99999.0;
00078                         }
00079 
00080                         monHistos.push_back(pair);
00081                 }
00082         }
00083 
00084         trainBegin();
00085 }
00086 
00087 void TrainProcessor::doTrainData(const std::vector<double> *values,
00088                                  bool target, double weight,
00089                                  bool train, bool test)
00090 {
00091         if (monModule && test) {
00092                 for(std::vector<SigBkg>::iterator iter = monHistos.begin();
00093                     iter != monHistos.end(); ++iter) {
00094                         const std::vector<double> &vals =
00095                                         values[iter - monHistos.begin()];
00096                         for(std::vector<double>::const_iterator value =
00097                                 vals.begin(); value != vals.end(); ++value) {
00098 
00099                                 iter->entries[target]++;
00100 
00101                                 if (*value <= iter->min) {
00102                                         iter->underflow[target] += weight;
00103                                         continue;
00104                                 } else if (*value >= iter->max) {
00105                                         iter->overflow[target] += weight;
00106                                         continue;
00107                                 }
00108 
00109                                 iter->histo[target]->Fill(*value, weight);
00110 
00111                                 if (iter->sameBinning)
00112                                         iter->histo[!target]->Fill(*value, 0);
00113                         }
00114                 }
00115         }
00116 
00117         if (train)
00118                 trainData(values, target, weight);
00119         if (test)
00120                 testData(values, target, weight, train);
00121 }
00122 
00123 void TrainProcessor::doTrainEnd()
00124 {
00125         trainEnd();
00126 
00127         if (monModule) {
00128                 for(std::vector<SigBkg>::const_iterator iter =
00129                         monHistos.begin(); iter != monHistos.end(); ++iter) {
00130 
00131                         for(unsigned int i = 0; i < 2; i++) {
00132                                 Int_t oBin = iter->histo[i]->GetNbinsX() + 1;
00133                                 iter->histo[i]->SetBinContent(0,
00134                                         iter->histo[i]->GetBinContent(0) +
00135                                         iter->underflow[i]);
00136                                 iter->histo[i]->SetBinContent(oBin,
00137                                         iter->histo[i]->GetBinContent(oBin) +
00138                                         iter->overflow[i]);
00139                                 iter->histo[i]->SetEntries(iter->entries[i]);
00140                         }
00141                 }
00142 
00143                 monModule = 0;
00144         }
00145 }
00146 
00147 template<>
00148 TrainProcessor *ProcessRegistry<TrainProcessor, AtomicId,
00149                                 MVATrainer>::Factory::create(
00150                 const char *name, const AtomicId *id, MVATrainer *trainer)
00151 {
00152         TrainProcessor *result = ProcessRegistry::create(name, id, trainer);
00153         if (!result) {
00154                 // try to load the shared library and retry
00155                 try {
00156                         delete TrainProcessor::PluginFactory::get()->create(
00157                                 std::string("TrainProcessor/") + name);
00158                         result = ProcessRegistry::create(name, id, trainer);
00159                 } catch(const cms::Exception &e) {
00160                         // caller will have to deal with the null pointer
00161                         // in principle this will just give a slightly more
00162                         // descriptive error message (and will rethrow anyhow)
00163                 }
00164         }
00165         return result;
00166 }
00167 
00168 } // namespace PhysicsTools