00001 #include <limits>
00002 #include <string>
00003
00004 #include <TH1.h>
00005
00006 #include "FWCore/PluginManager/interface/PluginManager.h"
00007 #include "FWCore/PluginManager/interface/PluginFactory.h"
00008
00009 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00010
00011 #include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
00012 #include "PhysicsTools/MVATrainer/interface/TrainProcessor.h"
00013 #include "PhysicsTools/MVAComputer/interface/ProcessRegistry.icc"
00014
00015 EDM_REGISTER_PLUGINFACTORY(PhysicsTools::TrainProcessor::PluginFactory,
00016 "PhysicsToolsMVATrainer");
00017
00018 namespace PhysicsTools {
00019
00020 TrainProcessor::TrainProcessor(const char *name,
00021 const AtomicId *id,
00022 MVATrainer *trainer) :
00023 Source(*id), name(name), trainer(trainer), monitoring(0), monModule(0)
00024 {
00025 }
00026
00027 TrainProcessor::~TrainProcessor()
00028 {
00029 }
00030
00031 void TrainProcessor::doTrainBegin()
00032 {
00033 bool booked = false;
00034 unsigned int nBins = 50;
00035
00036 if (!monitoring) {
00037 const char *source = getName();
00038 if (source) {
00039 monitoring = trainer->bookMonitor(name + "_" + source);
00040 monModule = trainer->bookMonitor(std::string("input_") +
00041 source);
00042 } else {
00043 monModule = trainer->bookMonitor("output");
00044 nBins = 400;
00045 }
00046
00047 booked = monModule != 0;
00048 }
00049
00050 if (booked) {
00051 std::vector<SourceVariable*> inputs = getInputs().get();
00052 for(std::vector<SourceVariable*>::const_iterator iter =
00053 inputs.begin(); iter != inputs.end(); ++iter) {
00054
00055 SourceVariable *var = *iter;
00056 std::string name =
00057 (const char*)var->getSource()->getName()
00058 + std::string("_")
00059 + (const char*)var->getName();
00060
00061 SigBkg pair;
00062 pair.entries[0] = pair.entries[1] = 0;
00063 pair.histo[0] = monModule->book<TH1F>(name + "_bkg",
00064 (name + "_bkg").c_str(),
00065 (name + " background").c_str(), nBins, 0, 0);
00066 pair.histo[1] = monModule->book<TH1F>(name + "_sig",
00067 (name + "_sig").c_str(),
00068 (name + " signal").c_str(), nBins, 0, 0);
00069 pair.underflow[0] = pair.underflow[1] = 0.0;
00070 pair.overflow[0] = pair.overflow[1] = 0.0;
00071
00072 pair.sameBinning = true;
00073 if (monitoring) {
00074 pair.min = -std::numeric_limits<double>::infinity();
00075 pair.max = +std::numeric_limits<double>::infinity();
00076 } else {
00077 pair.min = -99999.0;
00078 pair.max = +99999.0;
00079 }
00080
00081 monHistos.push_back(pair);
00082 }
00083 }
00084
00085 trainBegin();
00086 }
00087
00088 void TrainProcessor::doTrainData(const std::vector<double> *values,
00089 bool target, double weight,
00090 bool train, bool test)
00091 {
00092 if (monModule && test) {
00093 for(std::vector<SigBkg>::iterator iter = monHistos.begin();
00094 iter != monHistos.end(); ++iter) {
00095 const std::vector<double> &vals =
00096 values[iter - monHistos.begin()];
00097 for(std::vector<double>::const_iterator value =
00098 vals.begin(); value != vals.end(); ++value) {
00099
00100 iter->entries[target]++;
00101
00102 if (*value <= iter->min) {
00103 iter->underflow[target] += weight;
00104 continue;
00105 } else if (*value >= iter->max) {
00106 iter->overflow[target] += weight;
00107 continue;
00108 }
00109
00110 iter->histo[target]->Fill(*value, weight);
00111
00112 if (iter->sameBinning)
00113 iter->histo[!target]->Fill(*value, 0);
00114 }
00115 }
00116 }
00117
00118 if (train)
00119 trainData(values, target, weight);
00120 if (test)
00121 testData(values, target, weight, train);
00122 }
00123
00124 void TrainProcessor::doTrainEnd()
00125 {
00126 trainEnd();
00127
00128 if (monModule) {
00129 for(std::vector<SigBkg>::const_iterator iter =
00130 monHistos.begin(); iter != monHistos.end(); ++iter) {
00131
00132 for(unsigned int i = 0; i < 2; i++) {
00133 Int_t oBin = iter->histo[i]->GetNbinsX() + 1;
00134 iter->histo[i]->SetBinContent(0,
00135 iter->histo[i]->GetBinContent(0) +
00136 iter->underflow[i]);
00137 iter->histo[i]->SetBinContent(oBin,
00138 iter->histo[i]->GetBinContent(oBin) +
00139 iter->overflow[i]);
00140 iter->histo[i]->SetEntries(iter->entries[i]);
00141 }
00142 }
00143
00144 monModule = 0;
00145 }
00146 }
00147
00148 template<>
00149 TrainProcessor *ProcessRegistry<TrainProcessor, AtomicId,
00150 MVATrainer>::Factory::create(
00151 const char *name, const AtomicId *id, MVATrainer *trainer)
00152 {
00153 TrainProcessor *result = ProcessRegistry::create(name, id, trainer);
00154 if (!result) {
00155
00156 try {
00157 delete TrainProcessor::PluginFactory::get()->create(
00158 std::string("TrainProcessor/") + name);
00159 result = ProcessRegistry::create(name, id, trainer);
00160 } catch(const cms::Exception &e) {
00161
00162
00163
00164 }
00165 }
00166 return result;
00167 }
00168
00169 }
00170 template void PhysicsTools::ProcessRegistry<PhysicsTools::TrainProcessor, PhysicsTools::AtomicId, PhysicsTools::MVATrainer>::unregisterProcess(char const*);
00171 template void PhysicsTools::ProcessRegistry<PhysicsTools::TrainProcessor, PhysicsTools::AtomicId, PhysicsTools::MVATrainer>::registerProcess(char const*, PhysicsTools::ProcessRegistry<PhysicsTools::TrainProcessor, PhysicsTools::AtomicId, PhysicsTools::MVATrainer> const*);