CMS 3D CMS Logo

TMVAEvaluator.cc
Go to the documentation of this file.
1 #include <memory>
2 
6 
10 
11 TMVAEvaluator::TMVAEvaluator() : mIsInitialized(false), mUsingGBRForest(false), mUseAdaBoost(false) {}
12 
14  const std::string& method,
15  const std::string& weightFile,
16  const std::vector<std::string>& variables,
17  const std::vector<std::string>& spectators,
18  bool useGBRForest,
19  bool useAdaBoost) {
20  // initialize the TMVA reader
21  mReader = std::make_unique<TMVA::Reader>(options.c_str());
22  mReader->SetVerbose(false);
23  mMethod = method;
24 
25  // add input variables
26  for (std::vector<std::string>::const_iterator it = variables.begin(); it != variables.end(); ++it) {
27  mVariables.insert(std::make_pair(*it, std::make_pair(it - variables.begin(), 0.)));
28  mReader->AddVariable(it->c_str(), &(mVariables.at(*it).second));
29  }
30 
31  // add spectator variables
32  for (std::vector<std::string>::const_iterator it = spectators.begin(); it != spectators.end(); ++it) {
33  mSpectators.insert(std::make_pair(*it, std::make_pair(it - spectators.begin(), 0.)));
34  mReader->AddSpectator(it->c_str(), &(mSpectators.at(*it).second));
35  }
36 
37  // load the TMVA weights
39 
40  if (useGBRForest) {
42 
43  // now can free some memory
44  mReader.reset(nullptr);
45 
46  mUsingGBRForest = true;
48  }
49 
50  mIsInitialized = true;
51 }
52 
54  const std::vector<std::string>& variables,
55  const std::vector<std::string>& spectators,
56  bool useAdaBoost) {
57  // add input variables
58  for (std::vector<std::string>::const_iterator it = variables.begin(); it != variables.end(); ++it)
59  mVariables.insert(std::make_pair(*it, std::make_pair(it - variables.begin(), 0.)));
60 
61  // add spectator variables
62  for (std::vector<std::string>::const_iterator it = spectators.begin(); it != spectators.end(); ++it)
63  mSpectators.insert(std::make_pair(*it, std::make_pair(it - spectators.begin(), 0.)));
64 
65  // do not take ownership if getting GBRForest from an external source
66  mGBRForest = std::shared_ptr<const GBRForest>(gbrForest, [](const GBRForest*) {});
67 
68  mIsInitialized = true;
69  mUsingGBRForest = true;
71 }
72 
74  const std::string& label,
75  const std::vector<std::string>& variables,
76  const std::vector<std::string>& spectators,
77  bool useAdaBoost) {
78  edm::ESHandle<GBRForest> gbrForestHandle;
79 
80  iSetup.get<GBRWrapperRcd>().get(label.c_str(), gbrForestHandle);
81 
83 }
84 
85 float TMVAEvaluator::evaluateTMVA(const std::map<std::string, float>& inputs, bool useSpectators) const {
86  // default value
87  float value = -99.;
88 
89  // TMVA::Reader is not thread safe
90  std::lock_guard<std::mutex> lock(m_mutex);
91 
92  // set the input variable values
93  for (auto it = mVariables.begin(); it != mVariables.end(); ++it) {
94  if (inputs.count(it->first) > 0)
95  it->second.second = inputs.at(it->first);
96  else
97  edm::LogError("MissingInputVariable")
98  << "Input variable " << it->first
99  << " is missing from the list of inputs. The returned discriminator value might not be sensible.";
100  }
101 
102  // if using spectator variables
103  if (useSpectators) {
104  // set the spectator variable values
105  for (auto it = mSpectators.begin(); it != mSpectators.end(); ++it) {
106  if (inputs.count(it->first) > 0)
107  it->second.second = inputs.at(it->first);
108  else
109  edm::LogError("MissingSpectatorVariable")
110  << "Spectator variable " << it->first
111  << " is missing from the list of inputs. The returned discriminator value might not be sensible.";
112  }
113  }
114 
115  // evaluate the MVA
116  value = mReader->EvaluateMVA(mMethod.c_str());
117 
118  return value;
119 }
120 
121 float TMVAEvaluator::evaluateGBRForest(const std::map<std::string, float>& inputs) const {
122  // default value
123  float value = -99.;
124 
125  std::unique_ptr<float[]> vars(new float[mVariables.size()]); // allocate n floats
126 
127  // set the input variable values
128  for (auto it = mVariables.begin(); it != mVariables.end(); ++it) {
129  if (inputs.count(it->first) > 0)
130  vars[it->second.first] = inputs.at(it->first);
131  else
132  edm::LogError("MissingInputVariable")
133  << "Input variable " << it->first
134  << " is missing from the list of inputs. The returned discriminator value might not be sensible.";
135  }
136 
137  // evaluate the MVA
138  if (mUseAdaBoost)
139  value = mGBRForest->GetAdaBoostClassifier(vars.get());
140  else
141  value = mGBRForest->GetGradBoostClassifier(vars.get());
142 
143  return value;
144 }
145 
146 float TMVAEvaluator::evaluate(const std::map<std::string, float>& inputs, bool useSpectators) const {
147  // default value
148  float value = -99.;
149 
150  if (!mIsInitialized) {
151  edm::LogError("InitializationError") << "TMVAEvaluator not properly initialized.";
152  return value;
153  }
154 
155  if (useSpectators && inputs.size() < (mVariables.size() + mSpectators.size())) {
156  edm::LogError("MissingInputs") << "Too few inputs provided (" << inputs.size() << " provided but "
157  << mVariables.size() << " input and " << mSpectators.size()
158  << " spectator variables expected).";
159  return value;
160  } else if (inputs.size() < mVariables.size()) {
161  edm::LogError("MissingInputVariable(s)") << "Too few input variables provided (" << inputs.size()
162  << " provided but " << mVariables.size() << " expected).";
163  return value;
164  }
165 
166  if (mUsingGBRForest) {
167  if (useSpectators)
168  edm::LogWarning("UnsupportedFunctionality")
169  << "Use of spectator variables with GBRForest is not supported. Spectator variables will be ignored.";
171  } else
172  value = evaluateTMVA(inputs, useSpectators);
173 
174  return value;
175 }
TMVAEvaluator::evaluate
float evaluate(const std::map< std::string, float > &inputs, bool useSpectators=false) const
Definition: TMVAEvaluator.cc:146
TMVAEvaluator::mReader
std::unique_ptr< TMVA::Reader > mReader
Definition: TMVAEvaluator.h:50
edm::ESHandle::product
T const * product() const
Definition: ESHandle.h:86
TMVAEvaluator::mMethod
std::string mMethod
Definition: TMVAEvaluator.h:48
MessageLogger.h
funct::false
false
Definition: Factorize.h:29
ESHandle.h
L1TEGammaDiff_cfi.variables
variables
Definition: L1TEGammaDiff_cfi.py:5
TMVAEvaluator::mUsingGBRForest
bool mUsingGBRForest
Definition: TMVAEvaluator.h:45
candidateCombinedMVAV2Computer_cfi.spectators
spectators
Definition: candidateCombinedMVAV2Computer_cfi.py:15
GBRWrapperRcd.h
GBRForestTools.h
GBRForest
Definition: GBRForest.h:25
AlcaSiPixelAliHarvester0T_cff.method
method
Definition: AlcaSiPixelAliHarvester0T_cff.py:41
HLT_FULL_cff.weightFile
weightFile
Definition: HLT_FULL_cff.py:6777
createGBRForest
std::unique_ptr< const GBRForest > createGBRForest(const std::string &weightsFile)
Definition: GBRForestTools.cc:244
TMVAEvaluator::mIsInitialized
bool mIsInitialized
Definition: TMVAEvaluator.h:44
edm::LogWarning
Log< level::Warning, false > LogWarning
Definition: MessageLogger.h:122
options
Definition: options.py:1
TMVAEvaluator::evaluateGBRForest
float evaluateGBRForest(const std::map< std::string, float > &inputs) const
Definition: TMVAEvaluator.cc:121
edm::EventSetup::get
T get() const
Definition: EventSetup.h:80
vars
vars
Definition: DeepTauId.cc:163
HLT_FULL_cff.useAdaBoost
useAdaBoost
Definition: HLT_FULL_cff.py:6780
TMVAEvaluator::TMVAEvaluator
TMVAEvaluator()
Definition: TMVAEvaluator.cc:11
edm::ESHandle< GBRForest >
AlCaHLTBitMon_QueryRunRegistry.string
string
Definition: AlCaHLTBitMon_QueryRunRegistry.py:256
TMVAEvaluator::mVariables
std::map< std::string, std::pair< size_t, float > > mVariables
Definition: TMVAEvaluator.h:53
TMVAEvaluator.h
CommonMethods.lock
def lock()
Definition: CommonMethods.py:82
TMVAEvaluator::m_mutex
std::mutex m_mutex
Definition: TMVAEvaluator.h:49
TMVAEvaluator::mGBRForest
std::shared_ptr< const GBRForest > mGBRForest
Definition: TMVAEvaluator.h:51
value
Definition: value.py:1
PixelMapPlotter.inputs
inputs
Definition: PixelMapPlotter.py:490
edm::EventSetup
Definition: EventSetup.h:57
TMVAEvaluator::mUseAdaBoost
bool mUseAdaBoost
Definition: TMVAEvaluator.h:46
edm::LogError
Log< level::Error, false > LogError
Definition: MessageLogger.h:123
get
#define get
TMVAEvaluator::initializeGBRForest
void initializeGBRForest(const GBRForest *gbrForest, const std::vector< std::string > &variables, const std::vector< std::string > &spectators, bool useAdaBoost=false)
Definition: TMVAEvaluator.cc:53
TMVAEvaluator::initialize
void initialize(const std::string &options, const std::string &method, const std::string &weightFile, const std::vector< std::string > &variables, const std::vector< std::string > &spectators, bool useGBRForest=false, bool useAdaBoost=false)
Definition: TMVAEvaluator.cc:13
relativeConstraints.value
value
Definition: relativeConstraints.py:53
reco::details::loadTMVAWeights
TMVA::IMethod * loadTMVAWeights(TMVA::Reader *reader, const std::string &method, const std::string &weightFile, bool verbose=false)
Definition: TMVAZipReader.cc:52
GBRWrapperRcd
Definition: GBRWrapperRcd.h:24
TMVAZipReader.h
TMVAEvaluator::evaluateTMVA
float evaluateTMVA(const std::map< std::string, float > &inputs, bool useSpectators) const
Definition: TMVAEvaluator.cc:85
HLT_FULL_cff.useGBRForest
useGBRForest
Definition: HLT_FULL_cff.py:6779
TMVAEvaluator::mSpectators
std::map< std::string, std::pair< size_t, float > > mSpectators
Definition: TMVAEvaluator.h:54
label
const char * label
Definition: PFTauDecayModeTools.cc:11