CMS 3D CMS Logo

/data/refman/pasoursint/CMSSW_5_3_0/src/PhysicsTools/MVATrainer/src/TrainProcessor.cc

Go to the documentation of this file.
00001 #include <limits>
00002 #include <string>
00003 
00004 #include <TH1.h>
00005 
00006 #include "FWCore/PluginManager/interface/PluginManager.h"
00007 #include "FWCore/PluginManager/interface/PluginFactory.h"
00008 
00009 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00010 
00011 #include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
00012 #include "PhysicsTools/MVATrainer/interface/TrainProcessor.h"
00013 #include "PhysicsTools/MVAComputer/interface/ProcessRegistry.icc"
00014 
00015 EDM_REGISTER_PLUGINFACTORY(PhysicsTools::TrainProcessor::PluginFactory,
00016                            "PhysicsToolsMVATrainer");
00017 
00018 namespace PhysicsTools {
00019 
00020 TrainProcessor::TrainProcessor(const char *name,
00021                                const AtomicId *id,
00022                                MVATrainer *trainer) :
00023         Source(*id), name(name), trainer(trainer), monitoring(0), monModule(0)
00024 {
00025 }
00026 
00027 TrainProcessor::~TrainProcessor()
00028 {
00029 }
00030 
00031 void TrainProcessor::doTrainBegin()
00032 {
00033         bool booked = false;
00034         unsigned int nBins = 50;
00035 
00036         if (!monitoring) {
00037                 const char *source = getName();
00038                 if (source) {
00039                         monitoring = trainer->bookMonitor(name + "_" + source);
00040                         monModule = trainer->bookMonitor(std::string("input_") +
00041                                                          source);
00042                 } else {
00043                         monModule = trainer->bookMonitor("output");
00044                         nBins = 400;
00045                 }
00046 
00047                 booked = monModule != 0;
00048         }
00049 
00050         if (booked) {
00051                 std::vector<SourceVariable*> inputs = getInputs().get();
00052                 for(std::vector<SourceVariable*>::const_iterator iter =
00053                         inputs.begin(); iter != inputs.end(); ++iter) {
00054 
00055                         SourceVariable *var = *iter;
00056                         std::string name =
00057                                 (const char*)var->getSource()->getName()
00058                                 + std::string("_")
00059                                 + (const char*)var->getName();
00060 
00061                         SigBkg pair;
00062                         pair.entries[0] = pair.entries[1] = 0;
00063                         pair.histo[0] = monModule->book<TH1F>(name + "_bkg",
00064                                 (name + "_bkg").c_str(),
00065                                 (name + " background").c_str(), nBins, 0, 0);
00066                         pair.histo[1] = monModule->book<TH1F>(name + "_sig",
00067                                 (name + "_sig").c_str(),
00068                                 (name + " signal").c_str(), nBins, 0, 0);
00069                         pair.underflow[0] = pair.underflow[1] = 0.0;
00070                         pair.overflow[0] = pair.overflow[1] = 0.0;
00071 
00072                         pair.sameBinning = true;        // use as default
00073                         if (monitoring) {
00074                                 pair.min = -std::numeric_limits<double>::infinity();
00075                                 pair.max = +std::numeric_limits<double>::infinity();
00076                         } else {
00077                                 pair.min = -99999.0;
00078                                 pair.max = +99999.0;
00079                         }
00080 
00081                         monHistos.push_back(pair);
00082                 }
00083         }
00084 
00085         trainBegin();
00086 }
00087 
00088 void TrainProcessor::doTrainData(const std::vector<double> *values,
00089                                  bool target, double weight,
00090                                  bool train, bool test)
00091 {
00092         if (monModule && test) {
00093                 for(std::vector<SigBkg>::iterator iter = monHistos.begin();
00094                     iter != monHistos.end(); ++iter) {
00095                         const std::vector<double> &vals =
00096                                         values[iter - monHistos.begin()];
00097                         for(std::vector<double>::const_iterator value =
00098                                 vals.begin(); value != vals.end(); ++value) {
00099 
00100                                 iter->entries[target]++;
00101 
00102                                 if (*value <= iter->min) {
00103                                         iter->underflow[target] += weight;
00104                                         continue;
00105                                 } else if (*value >= iter->max) {
00106                                         iter->overflow[target] += weight;
00107                                         continue;
00108                                 }
00109 
00110                                 iter->histo[target]->Fill(*value, weight);
00111 
00112                                 if (iter->sameBinning)
00113                                         iter->histo[!target]->Fill(*value, 0);
00114                         }
00115                 }
00116         }
00117 
00118         if (train)
00119                 trainData(values, target, weight);
00120         if (test)
00121                 testData(values, target, weight, train);
00122 }
00123 
00124 void TrainProcessor::doTrainEnd()
00125 {
00126         trainEnd();
00127 
00128         if (monModule) {
00129                 for(std::vector<SigBkg>::const_iterator iter =
00130                         monHistos.begin(); iter != monHistos.end(); ++iter) {
00131 
00132                         for(unsigned int i = 0; i < 2; i++) {
00133                                 Int_t oBin = iter->histo[i]->GetNbinsX() + 1;
00134                                 iter->histo[i]->SetBinContent(0,
00135                                         iter->histo[i]->GetBinContent(0) +
00136                                         iter->underflow[i]);
00137                                 iter->histo[i]->SetBinContent(oBin,
00138                                         iter->histo[i]->GetBinContent(oBin) +
00139                                         iter->overflow[i]);
00140                                 iter->histo[i]->SetEntries(iter->entries[i]);
00141                         }
00142                 }
00143 
00144                 monModule = 0;
00145         }
00146 }
00147 
00148 template<>
00149 TrainProcessor *ProcessRegistry<TrainProcessor, AtomicId,
00150                                 MVATrainer>::Factory::create(
00151                 const char *name, const AtomicId *id, MVATrainer *trainer)
00152 {
00153         TrainProcessor *result = ProcessRegistry::create(name, id, trainer);
00154         if (!result) {
00155                 // try to load the shared library and retry
00156                 try {
00157                         delete TrainProcessor::PluginFactory::get()->create(
00158                                 std::string("TrainProcessor/") + name);
00159                         result = ProcessRegistry::create(name, id, trainer);
00160                 } catch(const cms::Exception &e) {
00161                         // caller will have to deal with the null pointer
00162                         // in principle this will just give a slightly more
00163                         // descriptive error message (and will rethrow anyhow)
00164                 }
00165         }
00166         return result;
00167 }
00168 
00169 } // namespace PhysicsTools
00170 template void PhysicsTools::ProcessRegistry<PhysicsTools::TrainProcessor, PhysicsTools::AtomicId, PhysicsTools::MVATrainer>::unregisterProcess(char const*);
00171 template void PhysicsTools::ProcessRegistry<PhysicsTools::TrainProcessor, PhysicsTools::AtomicId, PhysicsTools::MVATrainer>::registerProcess(char const*, PhysicsTools::ProcessRegistry<PhysicsTools::TrainProcessor, PhysicsTools::AtomicId, PhysicsTools::MVATrainer> const*);