CMS 3D CMS Logo

TreeTrainer.cc

Go to the documentation of this file.
00001 #include <functional>
00002 #include <algorithm>
00003 #include <string>
00004 #include <vector>
00005 
00006 #include <TString.h>
00007 #include <TTree.h>
00008 
00009 #include "FWCore/Utilities/interface/Exception.h"
00010 
00011 #include "PhysicsTools/MVAComputer/interface/MVAComputer.h"
00012 #include "PhysicsTools/MVAComputer/interface/TreeReader.h"
00013 
00014 #include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
00015 #include "PhysicsTools/MVATrainer/interface/TreeTrainer.h"
00016 
00017 namespace PhysicsTools {
00018 
00019 TreeTrainer::TreeTrainer()
00020 {
00021 }
00022 
00023 TreeTrainer::TreeTrainer(TTree *tree, double weight)
00024 {
00025         addTree(tree, -1, weight);
00026 }
00027 
00028 TreeTrainer::TreeTrainer(TTree *signal, TTree *background, double weight)
00029 {
00030         addTree(signal, true, weight);
00031         addTree(background, false, weight);
00032 }
00033 
00034 TreeTrainer::~TreeTrainer()
00035 {
00036         reset();
00037 }
00038 
00039 Calibration::MVAComputer *TreeTrainer::train(const std::string &trainFile,
00040                                              double crossValidation)
00041 {
00042         MVATrainer trainer(trainFile);
00043         trainer.setMonitoring(true);
00044         trainer.setCrossValidation(crossValidation);
00045         train(&trainer);
00046         return trainer.getCalibration();
00047 }
00048 
00049 void TreeTrainer::reset()
00050 {
00051         readers.clear();
00052         std::for_each(weights.begin(), weights.end(),
00053                       std::ptr_fun(&::operator delete));
00054         weights.clear();
00055 }
00056 
00057 void TreeTrainer::addTree(TTree *tree, int target, double weight)
00058 {
00059         static const bool targets[2] = { true, false };
00060 
00061         TreeReader reader(tree, false, weight > 0.0);
00062 
00063         if (target >= 0) {
00064                 if (tree->GetBranch("__TARGET__"))
00065                         throw cms::Exception("TreeTrainer")
00066                                 << "__TARGET__ branch already present in file."
00067                                 << std::endl;
00068 
00069                 reader.addSingle(MVATrainer::kTargetId, &targets[!target]);
00070         }
00071 
00072         if (weight > 0.0) {
00073                 double *ptr = new double(weight);
00074                 weights.push_back(ptr);
00075                 reader.addSingle(MVATrainer::kWeightId, ptr);
00076         }
00077 
00078         addReader(reader);
00079 }
00080 
00081 void TreeTrainer::addReader(const TreeReader &reader)
00082 {
00083         readers.push_back(reader);
00084 }
00085 
00086 bool TreeTrainer::iteration(MVATrainer *trainer)
00087 {
00088         Calibration::MVAComputer *calib = trainer->getTrainCalibration();   
00089         if (!calib)
00090                 return true;
00091 
00092         MVAComputer computer(calib, true);
00093 
00094         std::for_each(readers.begin(), readers.end(),
00095                       std::bind2nd(std::mem_fun_ref(&TreeReader::loop),
00096                                    &computer));
00097 
00098         return false;
00099 }
00100 
00101 void TreeTrainer::train(MVATrainer *trainer)
00102 {
00103         while(!iteration(trainer));
00104 }
00105 
00106 } // namespace PhysicsTools

Generated on Tue Jun 9 17:41:32 2009 for CMSSW by doxygen 1.5.4