CMS 3D CMS Logo

/data/refman/pasoursint/CMSSW_5_2_7_hltpatch2/src/PhysicsTools/MVATrainer/src/TreeTrainer.cc

Go to the documentation of this file.
#include <cstddef>
#include <functional>
#include <algorithm>
#include <string>
#include <vector>

#include <TString.h>
#include <TTree.h>

#include "FWCore/Utilities/interface/Exception.h"

#include "PhysicsTools/MVAComputer/interface/MVAComputer.h"
#include "PhysicsTools/MVAComputer/interface/TreeReader.h"

#include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
#include "PhysicsTools/MVATrainer/interface/TreeTrainer.h"
00016 
00017 namespace PhysicsTools {
00018 
00019 TreeTrainer::TreeTrainer()
00020 {
00021 }
00022 
00023 TreeTrainer::TreeTrainer(TTree *tree, double weight)
00024 {
00025         addTree(tree, -1, weight);
00026 }
00027 
00028 TreeTrainer::TreeTrainer(TTree *signal, TTree *background, double weight)
00029 {
00030         addTree(signal, true, weight);
00031         addTree(background, false, weight);
00032 }
00033 
00034 TreeTrainer::~TreeTrainer()
00035 {
00036         reset();
00037 }
00038 
00039 Calibration::MVAComputer *TreeTrainer::train(const std::string &trainFile,
00040                                              double crossValidation,
00041                                              bool useXSLT)
00042 {
00043         MVATrainer trainer(trainFile, useXSLT);
00044         trainer.setMonitoring(true);
00045         trainer.setCrossValidation(crossValidation);
00046         train(&trainer);
00047         return trainer.getCalibration();
00048 }
00049 
00050 void TreeTrainer::reset()
00051 {
00052         readers.clear();
00053         std::for_each(weights.begin(), weights.end(),
00054                       std::ptr_fun(&::operator delete));
00055         weights.clear();
00056 }
00057 
00058 void TreeTrainer::addTree(TTree *tree, int target, double weight)
00059 {
00060         static const bool targets[2] = { true, false };
00061 
00062         TreeReader reader(tree, false, weight > 0.0);
00063 
00064         if (target >= 0) {
00065                 if (tree->GetBranch("__TARGET__"))
00066                         throw cms::Exception("TreeTrainer")
00067                                 << "__TARGET__ branch already present in file."
00068                                 << std::endl;
00069 
00070                 reader.addSingle(MVATrainer::kTargetId, &targets[!target]);
00071         }
00072 
00073         if (weight > 0.0) {
00074                 double *ptr = new double(weight);
00075                 weights.push_back(ptr);
00076                 reader.addSingle(MVATrainer::kWeightId, ptr);
00077         }
00078 
00079         addReader(reader);
00080 }
00081 
00082 void TreeTrainer::addReader(const TreeReader &reader)
00083 {
00084         readers.push_back(reader);
00085 }
00086 
00087 bool TreeTrainer::iteration(MVATrainer *trainer)
00088 {
00089         Calibration::MVAComputer *calib = trainer->getTrainCalibration();   
00090         if (!calib)
00091                 return true;
00092 
00093         MVAComputer computer(calib, true);
00094 
00095         std::for_each(readers.begin(), readers.end(),
00096                       std::bind2nd(std::mem_fun_ref(&TreeReader::loop),
00097                                    &computer));
00098 
00099         return false;
00100 }
00101 
00102 void TreeTrainer::train(MVATrainer *trainer)
00103 {
00104         while(!iteration(trainer));
00105 }
00106 
00107 } // namespace PhysicsTools