CMS 3D CMS Logo

ElectronMVAEstimatorRun2.cc
Go to the documentation of this file.
2 
5  name_(conf.getParameter<std::string>("mvaName")),
6  tag_(conf.getParameter<std::string>("mvaTag")),
7  nCategories_ (conf.getParameter<int>("nCategories")),
8  methodName_ ("BDTG method"),
9  mvaVarMngr_(conf.getParameter<std::string>("variableDefinition")),
10  debug_(conf.getUntrackedParameter<bool>("debug", false))
11 {
12 
13  const std::vector <std::string> weightFileNames
14  = conf.getParameter<std::vector<std::string> >("weightFileNames");
15 
16  const std::vector <std::string> categoryCutStrings
17  = conf.getParameter<std::vector<std::string> >("categoryCuts");
18 
19  if( (int)(categoryCutStrings.size()) != nCategories_ )
20  throw cms::Exception("MVA config failure: ")
21  << "wrong number of category cuts in " << name_ << tag_ << std::endl;
22 
23  for (int i = 0; i < nCategories_; ++i) {
25  categoryFunctions_.push_back(select);
26  }
27 
28  // Initialize GBRForests from weight files
29  init(weightFileNames);
30 
31 }
32 
34  const std::string &mvaTag, const std::string &mvaName, const bool debug):
36  name_ (mvaName),
37  tag_ (mvaTag),
38  methodName_ ("BDTG method"),
39  debug_ (debug) {
40  }
41 
42 void ElectronMVAEstimatorRun2::init(const std::vector<std::string> &weightFileNames) {
43 
44  if(debug_) {
45  std::cout << " *** Inside " << name_ << tag_ << std::endl;
46  }
47 
48  // Initialize GBRForests
49  if( (int)(weightFileNames.size()) != nCategories_ )
50  throw cms::Exception("MVA config failure: ")
51  << "wrong number of weightfiles in " << name_ << tag_ << std::endl;
52 
53  gbrForests_.clear();
54  // Create a TMVA reader object for each category
55  for(int i=0; i<nCategories_; i++){
56 
57  std::vector<std::string> variableNamesInCategory;
58  std::vector<int> variablesInCategory;
59 
60  // Use unique_ptr so that all readers are properly cleaned up
61  // when the vector clear() is called in the destructor
62 
63  gbrForests_.push_back( GBRForestTools::createGBRForest( weightFileNames[i], variableNamesInCategory ) );
64 
65  nVariables_.push_back(variableNamesInCategory.size());
66 
67  variables_.push_back(variablesInCategory);
68 
69  if(debug_) {
70  std::cout << " *** Inside " << name_ << tag_ << std::endl;
71  std::cout << " category " << i << " with nVariables " << nVariables_[i] << std::endl;
72  }
73 
74  for (int j=0; j<nVariables_[i];++j) {
75  int index = mvaVarMngr_.getVarIndex(variableNamesInCategory[j]);
76  if(index == -1) {
77  throw cms::Exception("MVA config failure: ")
78  << "Concerning " << name_ << tag_ << std::endl
79  << "Variable " << variableNamesInCategory[j]
80  << " not found in variable definition file!" << std::endl;
81  }
82  variables_[i].push_back(index);
83 
84  }
85  }
86 }
87 
90 }
91 
93  // All tokens for event content needed by this MVA
94  // Tags from the variable helper
95  for (auto &tag : mvaVarMngr_.getHelperInputTags()) {
96  cc.consumes<edm::ValueMap<float>>(tag);
97  }
98  for (auto &tag : mvaVarMngr_.getGlobalInputTags()) {
99  cc.consumes<double>(tag);
100  }
101 }
102 
104 mvaValue( const edm::Ptr<reco::Candidate>& candPtr, const edm::EventBase & iEvent) const {
105 
106  const int iCategory = findCategory( candPtr );
107 
108  if (iCategory < 0) return -999;
109 
110  std::vector<float> vars;
111 
112  const edm::Ptr<reco::GsfElectron> gsfPtr{ candPtr };
113 
114  if( gsfPtr.get() == nullptr ) {
115  throw cms::Exception("MVA failure: ")
116  << " given particle is expected to be reco::GsfElectron or pat::Electron," << std::endl
117  << " but appears to be neither" << std::endl;
118  }
119 
120  for (int i = 0; i < nVariables_[iCategory]; ++i) {
121  vars.push_back(mvaVarMngr_.getValue(variables_[iCategory][i], gsfPtr, iEvent));
122  }
123 
124  if(debug_) {
125  std::cout << " *** Inside " << name_ << tag_ << std::endl;
126  std::cout << " category " << iCategory << std::endl;
127  for (int i = 0; i < nVariables_[iCategory]; ++i) {
128  std::cout << " " << mvaVarMngr_.getName(variables_[iCategory][i]) << " " << vars[i] << std::endl;
129  }
130  }
131  const float response = gbrForests_.at(iCategory)->GetResponse(vars.data()); // The BDT score
132 
133  if(debug_) {
134  std::cout << " ### MVA " << response << std::endl << std::endl;
135  }
136 
137  return response;
138 }
139 
141 
142  auto gsfEle = dynamic_cast<reco::GsfElectron const*>(candPtr.get());
143 
144  if( gsfEle == nullptr ) {
145  throw cms::Exception("MVA failure: ")
146  << " given particle is expected to be reco::GsfElectron or pat::Electron," << std::endl
147  << " but appears to be neither" << std::endl;
148  }
149 
150  for (int i = 0; i < nCategories_; ++i) {
151  if (categoryFunctions_[i](*gsfEle)) return i;
152  }
153 
154  edm::LogWarning ("MVA warning") <<
155  "category not defined for particle with pt " << gsfEle->pt() << " GeV, eta " <<
156  gsfEle->superCluster()->eta() << " in " << name_ << tag_;
157 
158  return -1;
159 
160 }
const std::string getName(int index) const
T getParameter(std::string const &) const
int findCategory(const edm::Ptr< reco::Candidate > &candPtr) const override
T const * get() const
Returns C++ pointer to the item.
Definition: Ptr.h:159
MVAVariableManager< reco::GsfElectron > mvaVarMngr_
std::vector< edm::InputTag > getHelperInputTags() const
#define nullptr
void setConsumes(edm::ConsumesCollector &&) const final
int iEvent
Definition: GenABIO.cc:230
std::vector< StringCutObjectSelector< reco::GsfElectron > > categoryFunctions_
static std::unique_ptr< const GBRForest > createGBRForest(const std::string &weightFile)
ElectronMVAEstimatorRun2(const edm::ParameterSet &conf)
void init(const std::vector< std::string > &weightFileNames)
int getVarIndex(std::string &name)
std::vector< std::vector< int > > variables_
std::vector< std::unique_ptr< const GBRForest > > gbrForests_
#define debug
Definition: HDRShower.cc:19
float getValue(int index, const edm::Ptr< ParticleType > &ptclPtr, const edm::EventBase &iEvent) const
HLT enums.
float mvaValue(const edm::Ptr< reco::Candidate > &candPtr, const edm::EventBase &iEvent) const override
std::vector< edm::InputTag > getGlobalInputTags() const