CMS 3D CMS Logo

TrainProcessor.cc

Go to the documentation of this file.
00001 #include <limits>
00002 #include <string>
00003 
00004 #include <TH1.h>
00005 
00006 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00007 
00008 #include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
00009 #include "PhysicsTools/MVATrainer/interface/TrainProcessor.h"
00010 
00011 namespace PhysicsTools {
00012 
00013 TrainProcessor::TrainProcessor(const char *name,
00014                                const AtomicId *id,
00015                                MVATrainer *trainer) :
00016         Source(*id), name(name), trainer(trainer), monitoring(0), monModule(0)
00017 {
00018 }
00019 
00020 TrainProcessor::~TrainProcessor()
00021 {
00022 }
00023 
00024 void TrainProcessor::doTrainBegin()
00025 {
00026         bool booked = false;
00027         unsigned int nBins = 50;
00028 
00029         if (!monitoring) {
00030                 const char *source = getName();
00031                 if (source) {
00032                         monitoring = trainer->bookMonitor(name + "_" + source);
00033                         monModule = trainer->bookMonitor(std::string("input_") +
00034                                                          source);
00035                 } else {
00036                         monModule = trainer->bookMonitor("output");
00037                         nBins = 400;
00038                 }
00039 
00040                 booked = monModule != 0;
00041         }
00042 
00043         if (booked) {
00044                 std::vector<SourceVariable*> inputs = getInputs().get();
00045                 for(std::vector<SourceVariable*>::const_iterator iter =
00046                         inputs.begin(); iter != inputs.end(); ++iter) {
00047 
00048                         SourceVariable *var = *iter;
00049                         std::string name =
00050                                 (const char*)var->getSource()->getName()
00051                                 + std::string("_")
00052                                 + (const char*)var->getName();
00053 
00054                         SigBkg pair;
00055                         pair.entries[0] = pair.entries[1] = 0;
00056                         pair.histo[0] = monModule->book<TH1F>(name + "_bkg",
00057                                 (name + "_bkg").c_str(),
00058                                 (name + " background").c_str(), nBins, 0, 0);
00059                         pair.histo[1] = monModule->book<TH1F>(name + "_sig",
00060                                 (name + "_sig").c_str(),
00061                                 (name + " signal").c_str(), nBins, 0, 0);
00062                         pair.underflow[0] = pair.underflow[1] = 0.0;
00063                         pair.overflow[0] = pair.overflow[1] = 0.0;
00064 
00065                         pair.sameBinning = true;        // use as default
00066                         if (monitoring) {
00067                                 pair.min = -std::numeric_limits<double>::infinity();
00068                                 pair.max = +std::numeric_limits<double>::infinity();
00069                         } else {
00070                                 pair.min = -99999.0;
00071                                 pair.max = +99999.0;
00072                         }
00073 
00074                         monHistos.push_back(pair);
00075                 }
00076         }
00077 
00078         trainBegin();
00079 }
00080 
00081 void TrainProcessor::doTrainData(const std::vector<double> *values,
00082                                  bool target, double weight,
00083                                  bool train, bool test)
00084 {
00085         if (monModule && test) {
00086                 for(std::vector<SigBkg>::iterator iter = monHistos.begin();
00087                     iter != monHistos.end(); ++iter) {
00088                         const std::vector<double> &vals =
00089                                         values[iter - monHistos.begin()];
00090                         for(std::vector<double>::const_iterator value =
00091                                 vals.begin(); value != vals.end(); ++value) {
00092 
00093                                 iter->entries[target]++;
00094 
00095                                 if (*value <= iter->min) {
00096                                         iter->underflow[target] += weight;
00097                                         continue;
00098                                 } else if (*value >= iter->max) {
00099                                         iter->overflow[target] += weight;
00100                                         continue;
00101                                 }
00102 
00103                                 iter->histo[target]->Fill(*value, weight);
00104 
00105                                 if (iter->sameBinning)
00106                                         iter->histo[!target]->Fill(*value, 0);
00107                         }
00108                 }
00109         }
00110 
00111         if (train)
00112                 trainData(values, target, weight);
00113         if (test)
00114                 testData(values, target, weight, train);
00115 }
00116 
00117 void TrainProcessor::doTrainEnd()
00118 {
00119         trainEnd();
00120 
00121         if (monModule) {
00122                 for(std::vector<SigBkg>::const_iterator iter =
00123                         monHistos.begin(); iter != monHistos.end(); ++iter) {
00124 
00125                         for(unsigned int i = 0; i < 2; i++) {
00126                                 Int_t oBin = iter->histo[i]->GetNbinsX() + 1;
00127                                 iter->histo[i]->SetBinContent(0,
00128                                         iter->histo[i]->GetBinContent(0) +
00129                                         iter->underflow[i]);
00130                                 iter->histo[i]->SetBinContent(oBin,
00131                                         iter->histo[i]->GetBinContent(oBin) +
00132                                         iter->overflow[i]);
00133                                 iter->histo[i]->SetEntries(iter->entries[i]);
00134                         }
00135                 }
00136 
00137                 monModule = 0;
00138         }
00139 }
00140 
00141 } // namespace PhysicsTools

Generated on Tue Jun 9 17:41:32 2009 for CMSSW by doxygen 1.5.4