CMS 3D CMS Logo

ElectronMVAEstimatorRun2.cc
Go to the documentation of this file.
3 
6  mvaVarMngr_(conf.getParameter<std::string>("variableDefinition"), MVAVariableHelper::indexMap()) {
7  const auto weightFileNames = conf.getParameter<std::vector<std::string> >("weightFileNames");
8  const auto categoryCutStrings = conf.getParameter<std::vector<std::string> >("categoryCuts");
9 
10  if ((int)(categoryCutStrings.size()) != getNCategories())
11  throw cms::Exception("MVA config failure: ")
12  << "wrong number of category cuts in ElectronMVAEstimatorRun2" << getTag() << std::endl;
13 
14  for (int i = 0; i < getNCategories(); ++i) {
15  categoryFunctions_.emplace_back(categoryCutStrings[i]);
16  }
17 
18  // Initialize GBRForests from weight files
20 }
21 
23  const std::string& mvaName,
24  int nCategories,
26  const std::vector<std::string>& categoryCutStrings,
27  const std::vector<std::string>& weightFileNames,
28  bool debug)
30  mvaVarMngr_(variableDefinition, MVAVariableHelper::indexMap()) {
31  if ((int)(categoryCutStrings.size()) != getNCategories())
32  throw cms::Exception("MVA config failure: ")
33  << "wrong number of category cuts in " << getName() << getTag() << std::endl;
34 
35  for (auto const& cut : categoryCutStrings)
36  categoryFunctions_.emplace_back(cut);
38 }
39 
40 void ElectronMVAEstimatorRun2::init(const std::vector<std::string>& weightFileNames) {
41  if (isDebug()) {
42  std::cout << " *** Inside ElectronMVAEstimatorRun2" << getTag() << std::endl;
43  }
44 
45  // Initialize GBRForests
46  if ((int)(weightFileNames.size()) != getNCategories())
47  throw cms::Exception("MVA config failure: ")
48  << "wrong number of weightfiles in ElectronMVAEstimatorRun2" << getTag() << std::endl;
49 
50  gbrForests_.clear();
51  // Create a TMVA reader object for each category
52  for (int i = 0; i < getNCategories(); i++) {
53  std::vector<int> variablesInCategory;
54 
55  // Use unique_ptr so that all readers are properly cleaned up
56  // when the vector clear() is called in the destructor
57 
58  std::vector<std::string> variableNamesInCategory;
59  gbrForests_.push_back(createGBRForest(weightFileNames[i], variableNamesInCategory));
60 
61  nVariables_.push_back(variableNamesInCategory.size());
62 
63  variables_.push_back(variablesInCategory);
64 
65  if (isDebug()) {
66  std::cout << " *** Inside ElectronMVAEstimatorRun2" << getTag() << std::endl;
67  std::cout << " category " << i << " with nVariables " << nVariables_[i] << std::endl;
68  }
69 
70  for (int j = 0; j < nVariables_[i]; ++j) {
71  int index = mvaVarMngr_.getVarIndex(variableNamesInCategory[j]);
72  if (index == -1) {
73  throw cms::Exception("MVA config failure: ")
74  << "Concerning ElectronMVAEstimatorRun2" << getTag() << std::endl
75  << "Variable " << variableNamesInCategory[j] << " not found in variable definition file!" << std::endl;
76  }
77  variables_[i].push_back(index);
78  }
79  }
80 }
81 
83  const std::vector<float>& auxVariables,
84  int& iCategory) const {
85  const reco::GsfElectron* electron = dynamic_cast<const reco::GsfElectron*>(candidate);
86 
87  if (electron == nullptr) {
88  throw cms::Exception("MVA failure: ")
89  << " given particle is expected to be reco::GsfElectron or pat::Electron," << std::endl
90  << " but appears to be neither" << std::endl;
91  }
92 
93  iCategory = findCategory(electron);
94 
95  if (iCategory < 0)
96  return -999;
97 
98  std::vector<float> vars;
99 
100  vars.reserve(nVariables_[iCategory]);
101  for (int i = 0; i < nVariables_[iCategory]; ++i) {
102  vars.push_back(mvaVarMngr_.getValue(variables_[iCategory][i], *electron, auxVariables));
103  }
104 
105  if (isDebug()) {
106  std::cout << " *** Inside ElectronMVAEstimatorRun2" << getTag() << std::endl;
107  std::cout << " category " << iCategory << std::endl;
108  for (int i = 0; i < nVariables_[iCategory]; ++i) {
109  std::cout << " " << mvaVarMngr_.getName(variables_[iCategory][i]) << " " << vars[i] << std::endl;
110  }
111  }
112  const float response = gbrForests_.at(iCategory)->GetResponse(vars.data()); // The BDT score
113 
114  if (isDebug()) {
115  std::cout << " ### MVA " << response << std::endl << std::endl;
116  }
117 
118  return response;
119 }
120 
122  const reco::GsfElectron* electron = dynamic_cast<reco::GsfElectron const*>(candidate);
123 
124  if (electron == nullptr) {
125  throw cms::Exception("MVA failure: ")
126  << " given particle is expected to be reco::GsfElectron or pat::Electron," << std::endl
127  << " but appears to be neither" << std::endl;
128  }
129 
130  return findCategory(*electron);
131 }
132 
134  for (int i = 0; i < getNCategories(); ++i) {
136  return i;
137  }
138 
139  edm::LogWarning("MVA warning") << "category not defined for particle with pt " << electron.pt() << " GeV, eta "
140  << electron.superCluster()->eta() << " in ElectronMVAEstimatorRun2" << getTag();
141 
142  return -1;
143 }
mvaElectronID_Fall17_iso_V1_cff.mvaTag
mvaTag
Definition: mvaElectronID_Fall17_iso_V1_cff.py:16
mvaElectronID_Fall17_iso_V1_cff.variableDefinition
variableDefinition
Definition: mvaElectronID_Fall17_iso_V1_cff.py:90
mps_fire.i
i
Definition: mps_fire.py:428
TkAlMuonSelectors_cfi.cut
cut
Definition: TkAlMuonSelectors_cfi.py:5
MVAVariableManager::getValue
float getValue(int index, const ParticleType &particle, const std::vector< float > &auxVariables) const
Definition: MVAVariableManager.h:50
AnyMVAEstimatorRun2Base
Definition: AnyMVAEstimatorRun2Base.h:11
gather_cfg.cout
cout
Definition: gather_cfg.py:144
createGBRForest
std::unique_ptr< const GBRForest > createGBRForest(const std::string &weightsFile)
Definition: GBRForestTools.cc:244
ElectronMVAEstimatorRun2::nVariables_
std::vector< int > nVariables_
Definition: ElectronMVAEstimatorRun2.h:47
edm::LogWarning
Log< level::Warning, false > LogWarning
Definition: MessageLogger.h:122
MVAVariableManager::getName
const std::string & getName(int index) const
Definition: MVAVariableManager.h:46
MVAVariableHelper.h
ElectronMVAEstimatorRun2::mvaValue
float mvaValue(const reco::Candidate *candidate, std::vector< float > const &auxVariables, int &iCategory) const override
Definition: ElectronMVAEstimatorRun2.cc:82
debug
#define debug
Definition: HDRShower.cc:19
vars
vars
Definition: DeepTauId.cc:163
metsig::electron
Definition: SignAlgoResolutions.h:48
taus_updatedMVAIds_cff.mvaName
mvaName
Definition: taus_updatedMVAIds_cff.py:18
ElectronMVAEstimatorRun2::variables_
std::vector< std::vector< int > > variables_
Definition: ElectronMVAEstimatorRun2.h:54
AnyMVAEstimatorRun2Base::isDebug
bool isDebug() const
Definition: AnyMVAEstimatorRun2Base.h:46
reco::GsfElectron
Definition: GsfElectron.h:35
ElectronMVAEstimatorRun2::categoryFunctions_
std::vector< ThreadSafeFunctor< StringCutObjectSelector< reco::GsfElectron > > > categoryFunctions_
Definition: ElectronMVAEstimatorRun2.h:46
AlCaHLTBitMon_QueryRunRegistry.string
string
Definition: AlCaHLTBitMon_QueryRunRegistry.py:256
MVAVariableManager::getVarIndex
int getVarIndex(const std::string &name)
Definition: MVAVariableManager.h:37
ElectronMVAEstimatorRun2::ElectronMVAEstimatorRun2
ElectronMVAEstimatorRun2(const edm::ParameterSet &conf)
Definition: ElectronMVAEstimatorRun2.cc:4
edm::ParameterSet
Definition: ParameterSet.h:47
ElectronMVAEstimatorRun2::gbrForests_
std::vector< std::unique_ptr< const GBRForest > > gbrForests_
Definition: ElectronMVAEstimatorRun2.h:50
AnyMVAEstimatorRun2Base::getNCategories
int getNCategories() const
Definition: AnyMVAEstimatorRun2Base.h:39
AnyMVAEstimatorRun2Base::getTag
const std::string & getTag() const
Definition: AnyMVAEstimatorRun2Base.h:44
ElectronMVAEstimatorRun2::findCategory
int findCategory(const reco::Candidate *candidate) const override
Definition: ElectronMVAEstimatorRun2.cc:121
ElectronMVAEstimatorRun2::init
void init(const std::vector< std::string > &weightFileNames)
Definition: ElectronMVAEstimatorRun2.cc:40
mvaElectronID_Fall17_iso_V1_cff.nCategories
nCategories
Definition: mvaElectronID_Fall17_iso_V1_cff.py:86
reco::Candidate
Definition: Candidate.h:27
ElectronMVAEstimatorRun2::mvaVarMngr_
MVAVariableManager< reco::GsfElectron > mvaVarMngr_
Definition: ElectronMVAEstimatorRun2.h:56
ElectronMVAEstimatorRun2.h
std
Definition: JetResolutionObject.h:76
mvaElectronID_Fall17_iso_V1_cff.weightFileNames
weightFileNames
Definition: mvaElectronID_Fall17_iso_V1_cff.py:89
Exception
Definition: hltDiff.cc:246
edm::ParameterSet::getParameter
T getParameter(std::string const &) const
Definition: ParameterSet.h:303
AlignmentPI::index
index
Definition: AlignmentPayloadInspectorHelper.h:46
cms::Exception
Definition: Exception.h:70
dqmiolumiharvest.j
j
Definition: dqmiolumiharvest.py:66
MVAVariableHelper
Definition: MVAVariableHelper.h:12
AnyMVAEstimatorRun2Base::getName
const std::string & getName() const
Definition: AnyMVAEstimatorRun2Base.h:40