Go to the documentation of this file.00001 #include <functional>
00002 #include <algorithm>
00003 #include <string>
00004 #include <vector>
00005
00006 #include <TString.h>
00007 #include <TTree.h>
00008
00009 #include "FWCore/Utilities/interface/Exception.h"
00010
00011 #include "PhysicsTools/MVAComputer/interface/MVAComputer.h"
00012 #include "PhysicsTools/MVAComputer/interface/TreeReader.h"
00013
00014 #include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
00015 #include "PhysicsTools/MVATrainer/interface/TreeTrainer.h"
00016
00017 namespace PhysicsTools {
00018
00019 TreeTrainer::TreeTrainer()
00020 {
00021 }
00022
00023 TreeTrainer::TreeTrainer(TTree *tree, double weight)
00024 {
00025 addTree(tree, -1, weight);
00026 }
00027
00028 TreeTrainer::TreeTrainer(TTree *signal, TTree *background, double weight)
00029 {
00030 addTree(signal, true, weight);
00031 addTree(background, false, weight);
00032 }
00033
00034 TreeTrainer::~TreeTrainer()
00035 {
00036 reset();
00037 }
00038
00039 Calibration::MVAComputer *TreeTrainer::train(const std::string &trainFile,
00040 double crossValidation,
00041 bool useXSLT)
00042 {
00043 MVATrainer trainer(trainFile, useXSLT);
00044 trainer.setMonitoring(true);
00045 trainer.setCrossValidation(crossValidation);
00046 train(&trainer);
00047 return trainer.getCalibration();
00048 }
00049
00050 void TreeTrainer::reset()
00051 {
00052 readers.clear();
00053 std::for_each(weights.begin(), weights.end(),
00054 std::ptr_fun(&::operator delete));
00055 weights.clear();
00056 }
00057
00058 void TreeTrainer::addTree(TTree *tree, int target, double weight)
00059 {
00060 static const bool targets[2] = { true, false };
00061
00062 TreeReader reader(tree, false, weight > 0.0);
00063
00064 if (target >= 0) {
00065 if (tree->GetBranch("__TARGET__"))
00066 throw cms::Exception("TreeTrainer")
00067 << "__TARGET__ branch already present in file."
00068 << std::endl;
00069
00070 reader.addSingle(MVATrainer::kTargetId, &targets[!target]);
00071 }
00072
00073 if (weight > 0.0) {
00074 double *ptr = new double(weight);
00075 weights.push_back(ptr);
00076 reader.addSingle(MVATrainer::kWeightId, ptr);
00077 }
00078
00079 addReader(reader);
00080 }
00081
00082 void TreeTrainer::addReader(const TreeReader &reader)
00083 {
00084 readers.push_back(reader);
00085 }
00086
00087 bool TreeTrainer::iteration(MVATrainer *trainer)
00088 {
00089 Calibration::MVAComputer *calib = trainer->getTrainCalibration();
00090 if (!calib)
00091 return true;
00092
00093 MVAComputer computer(calib, true);
00094
00095 std::for_each(readers.begin(), readers.end(),
00096 std::bind2nd(std::mem_fun_ref(&TreeReader::loop),
00097 &computer));
00098
00099 return false;
00100 }
00101
00102 void TreeTrainer::train(MVATrainer *trainer)
00103 {
00104 while(!iteration(trainer));
00105 }
00106
00107 }