00001 #include <functional>
00002 #include <algorithm>
00003 #include <string>
00004 #include <vector>
00005
00006 #include <TString.h>
00007 #include <TTree.h>
00008
00009 #include "FWCore/Utilities/interface/Exception.h"
00010
00011 #include "PhysicsTools/MVAComputer/interface/MVAComputer.h"
00012 #include "PhysicsTools/MVAComputer/interface/TreeReader.h"
00013
00014 #include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
00015 #include "PhysicsTools/MVATrainer/interface/TreeTrainer.h"
00016
00017 namespace PhysicsTools {
00018
00019 TreeTrainer::TreeTrainer()
00020 {
00021 }
00022
00023 TreeTrainer::TreeTrainer(TTree *tree, double weight)
00024 {
00025 addTree(tree, -1, weight);
00026 }
00027
00028 TreeTrainer::TreeTrainer(TTree *signal, TTree *background, double weight)
00029 {
00030 addTree(signal, true, weight);
00031 addTree(background, false, weight);
00032 }
00033
00034 TreeTrainer::~TreeTrainer()
00035 {
00036 reset();
00037 }
00038
00039 Calibration::MVAComputer *TreeTrainer::train(const std::string &trainFile,
00040 double crossValidation)
00041 {
00042 MVATrainer trainer(trainFile);
00043 trainer.setMonitoring(true);
00044 trainer.setCrossValidation(crossValidation);
00045 train(&trainer);
00046 return trainer.getCalibration();
00047 }
00048
00049 void TreeTrainer::reset()
00050 {
00051 readers.clear();
00052 std::for_each(weights.begin(), weights.end(),
00053 std::ptr_fun(&::operator delete));
00054 weights.clear();
00055 }
00056
00057 void TreeTrainer::addTree(TTree *tree, int target, double weight)
00058 {
00059 static const bool targets[2] = { true, false };
00060
00061 TreeReader reader(tree, false, weight > 0.0);
00062
00063 if (target >= 0) {
00064 if (tree->GetBranch("__TARGET__"))
00065 throw cms::Exception("TreeTrainer")
00066 << "__TARGET__ branch already present in file."
00067 << std::endl;
00068
00069 reader.addSingle(MVATrainer::kTargetId, &targets[!target]);
00070 }
00071
00072 if (weight > 0.0) {
00073 double *ptr = new double(weight);
00074 weights.push_back(ptr);
00075 reader.addSingle(MVATrainer::kWeightId, ptr);
00076 }
00077
00078 addReader(reader);
00079 }
00080
00081 void TreeTrainer::addReader(const TreeReader &reader)
00082 {
00083 readers.push_back(reader);
00084 }
00085
00086 bool TreeTrainer::iteration(MVATrainer *trainer)
00087 {
00088 Calibration::MVAComputer *calib = trainer->getTrainCalibration();
00089 if (!calib)
00090 return true;
00091
00092 MVAComputer computer(calib, true);
00093
00094 std::for_each(readers.begin(), readers.end(),
00095 std::bind2nd(std::mem_fun_ref(&TreeReader::loop),
00096 &computer));
00097
00098 return false;
00099 }
00100
00101 void TreeTrainer::train(MVATrainer *trainer)
00102 {
00103 while(!iteration(trainer));
00104 }
00105
00106 }