CMS 3D CMS Logo

ElectronMVAEstimatorRun2.cc
Go to the documentation of this file.
2 
// Constructor from an edm::ParameterSet (the signature line, original line 3, is not
// shown in this rendering; per the member list it is ElectronMVAEstimatorRun2(const edm::ParameterSet &conf)).
// Reads "variableDefinition" (handed to the variable manager), "weightFileNames" and
// "categoryCuts" from conf.
// Throws cms::Exception if the number of category cut strings differs from getNCategories().
// Builds one category selector in categoryFunctions_ per cut string, in category order.
4  : AnyMVAEstimatorRun2Base(conf), mvaVarMngr_(conf.getParameter<std::string>("variableDefinition")) {
5  const auto weightFileNames = conf.getParameter<std::vector<std::string> >("weightFileNames");
6  const auto categoryCutStrings = conf.getParameter<std::vector<std::string> >("categoryCuts");
7 
// Category i is selected by categoryCutStrings[i], so exactly one cut per category is required.
8  if ((int)(categoryCutStrings.size()) != getNCategories())
9  throw cms::Exception("MVA config failure: ")
10  << "wrong number of category cuts in ElectronMVAEstimatorRun2" << getTag() << std::endl;
11 
12  for (int i = 0; i < getNCategories(); ++i) {
13  categoryFunctions_.emplace_back(categoryCutStrings[i]);
14  }
15 
16  // Initialize GBRForests from weight files
// NOTE(review): the statement that followed this comment (original line 17) is missing
// from this rendering. weightFileNames is otherwise unused here, so it is presumably
// `init(weightFileNames);` as in the explicit-argument constructor -- confirm against the source.
18 }
19 
// Constructor from explicit arguments. The first signature line (original line 20) and one
// parameter line (original line 23) are missing from this rendering. mvaTag and
// variableDefinition are used below but not declared in the visible lines, so they are
// presumably those missing parameters -- confirm against the source.
// Throws cms::Exception if the number of category cut strings differs from getNCategories().
// Otherwise builds one category selector per cut string and then calls init(weightFileNames)
// to load the per-category GBRForests.
21  const std::string& mvaName,
22  int nCategories,
24  const std::vector<std::string>& categoryCutStrings,
25  const std::vector<std::string>& weightFileNames,
26  bool debug)
27  : AnyMVAEstimatorRun2Base(mvaName, mvaTag, nCategories, debug), mvaVarMngr_(variableDefinition) {
28  if ((int)(categoryCutStrings.size()) != getNCategories())
29  throw cms::Exception("MVA config failure: ")
30  << "wrong number of category cuts in " << getName() << getTag() << std::endl;
31 
// Category i is selected by categoryCutStrings[i]; the order defines the category index.
32  for (auto const& cut : categoryCutStrings)
33  categoryFunctions_.emplace_back(cut);
34  init(weightFileNames);
35 }
36 
37 void ElectronMVAEstimatorRun2::init(const std::vector<std::string>& weightFileNames) {
38  if (isDebug()) {
39  std::cout << " *** Inside ElectronMVAEstimatorRun2" << getTag() << std::endl;
40  }
41 
42  // Initialize GBRForests
43  if ((int)(weightFileNames.size()) != getNCategories())
44  throw cms::Exception("MVA config failure: ")
45  << "wrong number of weightfiles in ElectronMVAEstimatorRun2" << getTag() << std::endl;
46 
47  gbrForests_.clear();
48  // Create a TMVA reader object for each category
49  for (int i = 0; i < getNCategories(); i++) {
50  std::vector<int> variablesInCategory;
51 
52  // Use unique_ptr so that all readers are properly cleaned up
53  // when the vector clear() is called in the destructor
54 
55  std::vector<std::string> variableNamesInCategory;
56  gbrForests_.push_back(createGBRForest(weightFileNames[i], variableNamesInCategory));
57 
58  nVariables_.push_back(variableNamesInCategory.size());
59 
60  variables_.push_back(variablesInCategory);
61 
62  if (isDebug()) {
63  std::cout << " *** Inside ElectronMVAEstimatorRun2" << getTag() << std::endl;
64  std::cout << " category " << i << " with nVariables " << nVariables_[i] << std::endl;
65  }
66 
67  for (int j = 0; j < nVariables_[i]; ++j) {
68  int index = mvaVarMngr_.getVarIndex(variableNamesInCategory[j]);
69  if (index == -1) {
70  throw cms::Exception("MVA config failure: ")
71  << "Concerning ElectronMVAEstimatorRun2" << getTag() << std::endl
72  << "Variable " << variableNamesInCategory[j] << " not found in variable definition file!" << std::endl;
73  }
74  variables_[i].push_back(index);
75  }
76  }
77 }
78 
// Evaluate the BDT score for a candidate. The first signature line (original line 79) is
// missing from this rendering; per the member list the full signature is
// float mvaValue(const reco::Candidate *candidate, std::vector<float> const &auxVariables, int &iCategory) const.
//   candidate    - must be a reco::GsfElectron (or derived); otherwise cms::Exception is thrown.
//   auxVariables - extra per-event inputs forwarded to mvaVarMngr_.getValue().
//   iCategory    - out-param: the category index selected for this electron, or a negative
//                  value if no category cut matched.
// Returns -999 when no category matches; otherwise the response of that category's GBRForest.
80  const std::vector<float>& auxVariables,
81  int& iCategory) const {
82  const reco::GsfElectron* electron = dynamic_cast<const reco::GsfElectron*>(candidate);
83 
84  if (electron == nullptr) {
85  throw cms::Exception("MVA failure: ")
86  << " given particle is expected to be reco::GsfElectron or pat::Electron," << std::endl
87  << " but appears to be neither" << std::endl;
88  }
89 
// NOTE(review): this passes a pointer, so it resolves to the findCategory(const reco::Candidate*)
// overload, which repeats the dynamic_cast above. Calling findCategory(*electron) directly
// would avoid that -- confirm overload visibility before changing.
90  iCategory = findCategory(electron);
91 
92  if (iCategory < 0)
93  return -999;
94 
// Gather the input values in the exact order the category's forest expects (set up in init()).
95  std::vector<float> vars;
96 
97  for (int i = 0; i < nVariables_[iCategory]; ++i) {
98  vars.push_back(mvaVarMngr_.getValue(variables_[iCategory][i], *electron, auxVariables));
99  }
100 
101  if (isDebug()) {
102  std::cout << " *** Inside ElectronMVAEstimatorRun2" << getTag() << std::endl;
103  std::cout << " category " << iCategory << std::endl;
104  for (int i = 0; i < nVariables_[iCategory]; ++i) {
105  std::cout << " " << mvaVarMngr_.getName(variables_[iCategory][i]) << " " << vars[i] << std::endl;
106  }
107  }
108  const float response = gbrForests_.at(iCategory)->GetResponse(vars.data()); // The BDT score
109 
110  if (isDebug()) {
111  std::cout << " ### MVA " << response << std::endl << std::endl;
112  }
113 
114  return response;
115 }
116 
// Body of findCategory(const reco::Candidate *candidate) const (the signature line, original
// line 117, is missing from this rendering).
// Downcasts the candidate to reco::GsfElectron and delegates to the electron overload.
// Throws cms::Exception if the candidate is not a reco::GsfElectron (or derived).
// Returns the matching category index, or -1 if no category cut matches.
118  const reco::GsfElectron* electron = dynamic_cast<reco::GsfElectron const*>(candidate);
119 
120  if (electron == nullptr) {
121  throw cms::Exception("MVA failure: ")
122  << " given particle is expected to be reco::GsfElectron or pat::Electron," << std::endl
123  << " but appears to be neither" << std::endl;
124  }
125 
126  return findCategory(*electron);
127 }
128 
// Body of the electron overload of findCategory. The signature line (original line 129) is
// missing from this rendering; it presumably takes `const reco::GsfElectron& electron` -- confirm.
// Returns the index of the first category whose cut accepts the electron. The categories are
// tested in order, so earlier categories win on overlap.
// If none matches, it logs a warning with pt and supercluster eta and returns -1.
130  for (int i = 0; i < getNCategories(); ++i) {
131  if (categoryFunctions_[i](electron))
132  return i;
133  }
134 
135  edm::LogWarning("MVA warning") << "category not defined for particle with pt " << electron.pt() << " GeV, eta "
136  << electron.superCluster()->eta() << " in ElectronMVAEstimatorRun2" << getTag();
137 
138  return -1;
139 }
T getParameter(std::string const &) const
const std::string & getName(int index) const
MVAVariableManager< reco::GsfElectron > mvaVarMngr_
double pt() const final
transverse momentum
float getValue(int index, const ParticleType &particle, const std::vector< float > &auxVariables) const
const std::string & getName() const
const std::string & getTag() const
std::vector< ThreadSafeStringCut< StringCutObjectSelector< reco::GsfElectron >, reco::GsfElectron > > categoryFunctions_
float mvaValue(const reco::Candidate *candidate, std::vector< float > const &auxVariables, int &iCategory) const override
ElectronMVAEstimatorRun2(const edm::ParameterSet &conf)
void init(const std::vector< std::string > &weightFileNames)
std::vector< std::vector< int > > variables_
#define debug
Definition: HDRShower.cc:19
int getVarIndex(const std::string &name)
int findCategory(const reco::Candidate *candidate) const override
std::vector< std::unique_ptr< const GBRForest > > gbrForests_
SuperClusterRef superCluster() const override
reference to a SuperCluster
Definition: GsfElectron.h:155
vars
Definition: DeepTauId.cc:158
std::unique_ptr< const GBRForest > createGBRForest(const std::string &weightsFile)