Go to the documentation of this file.00001 #include <limits>
00002 #include <string>
00003
00004 #include <TH1.h>
00005
00006 #include "FWCore/PluginManager/interface/PluginManager.h"
00007 #include "FWCore/PluginManager/interface/PluginFactory.h"
00008
00009 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00010
00011 #include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
00012 #include "PhysicsTools/MVATrainer/interface/TrainProcessor.h"
00013
// Register the edm plugin factory type TrainProcessor::PluginFactory.
// ProcessRegistry<TrainProcessor, AtomicId, MVATrainer>::Factory::create in
// this file asks this factory for "TrainProcessor/<name>" plugins when a
// processor is not yet known to the registry.
// NOTE(review): the quoted string is presumably the factory's category name
// used by the plugin manager for lookups -- keep it stable.
00014 EDM_REGISTER_PLUGINFACTORY(PhysicsTools::TrainProcessor::PluginFactory,
00015 "PhysicsToolsMVATrainer");
00016
00017 namespace PhysicsTools {
00018
00019 TrainProcessor::TrainProcessor(const char *name,
00020 const AtomicId *id,
00021 MVATrainer *trainer) :
00022 Source(*id), name(name), trainer(trainer), monitoring(0), monModule(0)
00023 {
00024 }
00025
00026 TrainProcessor::~TrainProcessor()
00027 {
00028 }
00029
00030 void TrainProcessor::doTrainBegin()
00031 {
00032 bool booked = false;
00033 unsigned int nBins = 50;
00034
00035 if (!monitoring) {
00036 const char *source = getName();
00037 if (source) {
00038 monitoring = trainer->bookMonitor(name + "_" + source);
00039 monModule = trainer->bookMonitor(std::string("input_") +
00040 source);
00041 } else {
00042 monModule = trainer->bookMonitor("output");
00043 nBins = 400;
00044 }
00045
00046 booked = monModule != 0;
00047 }
00048
00049 if (booked) {
00050 std::vector<SourceVariable*> inputs = getInputs().get();
00051 for(std::vector<SourceVariable*>::const_iterator iter =
00052 inputs.begin(); iter != inputs.end(); ++iter) {
00053
00054 SourceVariable *var = *iter;
00055 std::string name =
00056 (const char*)var->getSource()->getName()
00057 + std::string("_")
00058 + (const char*)var->getName();
00059
00060 SigBkg pair;
00061 pair.entries[0] = pair.entries[1] = 0;
00062 pair.histo[0] = monModule->book<TH1F>(name + "_bkg",
00063 (name + "_bkg").c_str(),
00064 (name + " background").c_str(), nBins, 0, 0);
00065 pair.histo[1] = monModule->book<TH1F>(name + "_sig",
00066 (name + "_sig").c_str(),
00067 (name + " signal").c_str(), nBins, 0, 0);
00068 pair.underflow[0] = pair.underflow[1] = 0.0;
00069 pair.overflow[0] = pair.overflow[1] = 0.0;
00070
00071 pair.sameBinning = true;
00072 if (monitoring) {
00073 pair.min = -std::numeric_limits<double>::infinity();
00074 pair.max = +std::numeric_limits<double>::infinity();
00075 } else {
00076 pair.min = -99999.0;
00077 pair.max = +99999.0;
00078 }
00079
00080 monHistos.push_back(pair);
00081 }
00082 }
00083
00084 trainBegin();
00085 }
00086
00087 void TrainProcessor::doTrainData(const std::vector<double> *values,
00088 bool target, double weight,
00089 bool train, bool test)
00090 {
00091 if (monModule && test) {
00092 for(std::vector<SigBkg>::iterator iter = monHistos.begin();
00093 iter != monHistos.end(); ++iter) {
00094 const std::vector<double> &vals =
00095 values[iter - monHistos.begin()];
00096 for(std::vector<double>::const_iterator value =
00097 vals.begin(); value != vals.end(); ++value) {
00098
00099 iter->entries[target]++;
00100
00101 if (*value <= iter->min) {
00102 iter->underflow[target] += weight;
00103 continue;
00104 } else if (*value >= iter->max) {
00105 iter->overflow[target] += weight;
00106 continue;
00107 }
00108
00109 iter->histo[target]->Fill(*value, weight);
00110
00111 if (iter->sameBinning)
00112 iter->histo[!target]->Fill(*value, 0);
00113 }
00114 }
00115 }
00116
00117 if (train)
00118 trainData(values, target, weight);
00119 if (test)
00120 testData(values, target, weight, train);
00121 }
00122
00123 void TrainProcessor::doTrainEnd()
00124 {
00125 trainEnd();
00126
00127 if (monModule) {
00128 for(std::vector<SigBkg>::const_iterator iter =
00129 monHistos.begin(); iter != monHistos.end(); ++iter) {
00130
00131 for(unsigned int i = 0; i < 2; i++) {
00132 Int_t oBin = iter->histo[i]->GetNbinsX() + 1;
00133 iter->histo[i]->SetBinContent(0,
00134 iter->histo[i]->GetBinContent(0) +
00135 iter->underflow[i]);
00136 iter->histo[i]->SetBinContent(oBin,
00137 iter->histo[i]->GetBinContent(oBin) +
00138 iter->overflow[i]);
00139 iter->histo[i]->SetEntries(iter->entries[i]);
00140 }
00141 }
00142
00143 monModule = 0;
00144 }
00145 }
00146
00147 template<>
00148 TrainProcessor *ProcessRegistry<TrainProcessor, AtomicId,
00149 MVATrainer>::Factory::create(
00150 const char *name, const AtomicId *id, MVATrainer *trainer)
00151 {
00152 TrainProcessor *result = ProcessRegistry::create(name, id, trainer);
00153 if (!result) {
00154
00155 try {
00156 delete TrainProcessor::PluginFactory::get()->create(
00157 std::string("TrainProcessor/") + name);
00158 result = ProcessRegistry::create(name, id, trainer);
00159 } catch(const cms::Exception &e) {
00160
00161
00162
00163 }
00164 }
00165 return result;
00166 }
00167
00168 }