CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
ElectronMVAEstimatorRun2.cc
Go to the documentation of this file.
2 
8 
11  mvaVarMngr_{conf.getParameter<std::string>("variableDefinition"), MVAVariableHelper::indexMap()} {
12  const auto weightFileNames = conf.getParameter<std::vector<std::string> >("weightFileNames");
13  const auto categoryCutStrings = conf.getParameter<std::vector<std::string> >("categoryCuts");
14 
15  if ((int)(categoryCutStrings.size()) != getNCategories())
16  throw cms::Exception("MVA config failure: ")
17  << "wrong number of category cuts in ElectronMVAEstimatorRun2" << getTag() << std::endl;
18 
19  for (int i = 0; i < getNCategories(); ++i) {
20  categoryFunctions_.emplace_back(categoryCutStrings[i]);
21  }
22 
23  // Initialize GBRForests from weight files
24  init(weightFileNames);
25 }
26 
28  const std::string& mvaName,
29  int nCategories,
30  const std::string& variableDefinition,
31  const std::vector<std::string>& categoryCutStrings,
32  const std::vector<std::string>& weightFileNames,
33  bool debug)
34  : AnyMVAEstimatorRun2Base(mvaName, mvaTag, nCategories, debug),
35  mvaVarMngr_{variableDefinition, MVAVariableHelper::indexMap()} {
36  if ((int)(categoryCutStrings.size()) != getNCategories())
37  throw cms::Exception("MVA config failure: ")
38  << "wrong number of category cuts in " << getName() << getTag() << std::endl;
39 
40  for (auto const& cut : categoryCutStrings)
41  categoryFunctions_.emplace_back(cut);
42  init(weightFileNames);
43 }
44 
45 void ElectronMVAEstimatorRun2::init(const std::vector<std::string>& weightFileNames) {
46  if (isDebug()) {
47  std::cout << " *** Inside ElectronMVAEstimatorRun2" << getTag() << std::endl;
48  }
49 
50  // Initialize GBRForests
51  if ((int)(weightFileNames.size()) != getNCategories())
52  throw cms::Exception("MVA config failure: ")
53  << "wrong number of weightfiles in ElectronMVAEstimatorRun2" << getTag() << std::endl;
54 
55  // Create a TMVA reader object for each category
56  for (int i = 0; i < getNCategories(); i++) {
57  std::vector<int> variablesInCategory;
58 
59  // Use unique_ptr so that all readers are properly cleaned up
60  // when the vector clear() is called in the destructor
61 
62  std::vector<std::string> variableNamesInCategory;
63  gbrForests_.push_back(createGBRForest(weightFileNames[i], variableNamesInCategory));
64 
65  nVariables_.push_back(variableNamesInCategory.size());
66 
67  variables_.push_back(variablesInCategory);
68 
69  if (isDebug()) {
70  std::cout << " *** Inside ElectronMVAEstimatorRun2" << getTag() << std::endl;
71  std::cout << " category " << i << " with nVariables " << nVariables_[i] << std::endl;
72  }
73 
74  for (int j = 0; j < nVariables_[i]; ++j) {
75  int index = mvaVarMngr_.getVarIndex(variableNamesInCategory[j]);
76  if (index == -1) {
77  throw cms::Exception("MVA config failure: ")
78  << "Concerning ElectronMVAEstimatorRun2" << getTag() << std::endl
79  << "Variable " << variableNamesInCategory[j] << " not found in variable definition file!" << std::endl;
80  }
81  variables_[i].push_back(index);
82  }
83  }
84 }
85 
87  const std::vector<float>& auxVariables,
88  int& iCategory) const {
89  const reco::GsfElectron* electron = dynamic_cast<const reco::GsfElectron*>(candidate);
90 
91  if (electron == nullptr) {
92  throw cms::Exception("MVA failure: ")
93  << " given particle is expected to be reco::GsfElectron or pat::Electron," << std::endl
94  << " but appears to be neither" << std::endl;
95  }
96 
97  iCategory = findCategory(electron);
98 
99  if (iCategory < 0)
100  return -999;
101 
102  std::vector<float> vars;
103 
104  vars.reserve(nVariables_[iCategory]);
105  for (int i = 0; i < nVariables_[iCategory]; ++i) {
106  vars.push_back(mvaVarMngr_.getValue(variables_[iCategory][i], *electron, auxVariables));
107  }
108 
109  if (isDebug()) {
110  std::cout << " *** Inside ElectronMVAEstimatorRun2" << getTag() << std::endl;
111  std::cout << " category " << iCategory << std::endl;
112  for (int i = 0; i < nVariables_[iCategory]; ++i) {
113  std::cout << " " << mvaVarMngr_.getName(variables_[iCategory][i]) << " " << vars[i] << std::endl;
114  }
115  }
116  const float response = gbrForests_.at(iCategory)->GetResponse(vars.data()); // The BDT score
117 
118  if (isDebug()) {
119  std::cout << " ### MVA " << response << std::endl << std::endl;
120  }
121 
122  return response;
123 }
124 
126  const reco::GsfElectron* electron = dynamic_cast<reco::GsfElectron const*>(candidate);
127 
128  if (electron == nullptr) {
129  throw cms::Exception("MVA failure: ")
130  << " given particle is expected to be reco::GsfElectron or pat::Electron," << std::endl
131  << " but appears to be neither" << std::endl;
132  }
133 
134  return findCategory(*electron);
135 }
136 
138  for (int i = 0; i < getNCategories(); ++i) {
139  if (categoryFunctions_[i](electron))
140  return i;
141  }
142 
143  edm::LogWarning("MVA warning") << "category not defined for particle with pt " << electron.pt() << " GeV, eta "
144  << electron.superCluster()->eta() << " in ElectronMVAEstimatorRun2" << getTag();
145 
146  return -1;
147 }
const std::string & getName(int index) const
double pt() const final
transverse momentum
MVAVariableManager< reco::GsfElectron > mvaVarMngr_
int init
Definition: HydjetWrapper.h:64
float getValue(int index, const ParticleType &particle, const std::vector< float > &auxVariables) const
const std::string & getTag() const
std::vector< ThreadSafeFunctor< StringCutObjectSelector< reco::GsfElectron > > > categoryFunctions_
ElectronMVAEstimatorRun2(const edm::ParameterSet &conf)
void init(const std::vector< std::string > &weightFileNames)
std::vector< std::vector< int > > variables_
TString getName(TString structure, int layer, TString geometry)
Definition: DMRtrends.cc:236
#define debug
Definition: HDRShower.cc:19
int getVarIndex(const std::string &name)
T getParameter(std::string const &) const
Definition: ParameterSet.h:303
float mvaValue(const reco::Candidate *candidate, std::vector< float > const &auxVariables, int &iCategory) const override
std::vector< std::unique_ptr< const GBRForest > > gbrForests_
tuple cout
Definition: gather_cfg.py:144
int findCategory(const reco::Candidate *candidate) const override
vars
Definition: DeepTauId.cc:164
Log< level::Warning, false > LogWarning
SuperClusterRef superCluster() const override
reference to a SuperCluster
Definition: GsfElectron.h:155
std::unique_ptr< const GBRForest > createGBRForest(const std::string &weightsFile)