CMS 3D CMS Logo

TMVAEvaluator.cc
Go to the documentation of this file.
2 
7 #include "TMVA/MethodBDT.h"
8 
9 
11  mIsInitialized(false), mUsingGBRForest(false), mUseAdaBoost(false)
12 {
13 }
14 
16  const std::vector<std::string> & variables, const std::vector<std::string> & spectators, bool useGBRForest, bool useAdaBoost)
17 {
18  // initialize the TMVA reader
19  mReader.reset(new TMVA::Reader(options.c_str()));
20  mReader->SetVerbose(false);
21  mMethod = method;
22 
23  // add input variables
24  for(std::vector<std::string>::const_iterator it = variables.begin(); it!=variables.end(); ++it)
25  {
26  mVariables.insert( std::make_pair( *it, std::make_pair( it - variables.begin(), 0. ) ) );
27  mReader->AddVariable(it->c_str(), &(mVariables.at(*it).second));
28  }
29 
30  // add spectator variables
31  for(std::vector<std::string>::const_iterator it = spectators.begin(); it!=spectators.end(); ++it)
32  {
33  mSpectators.insert( std::make_pair( *it, std::make_pair( it - spectators.begin(), 0. ) ) );
34  mReader->AddSpectator(it->c_str(), &(mSpectators.at(*it).second));
35  }
36 
37  // load the TMVA weights
39 
40  if (useGBRForest)
41  {
42  mGBRForest.reset( new GBRForest( dynamic_cast<TMVA::MethodBDT*>( mReader->FindMVA(mMethod.c_str()) ) ) );
43 
44  // now can free some memory
45  mReader.reset(nullptr);
46 
47  mUsingGBRForest = true;
49  }
50 
51  mIsInitialized = true;
52 }
53 
54 
55 void TMVAEvaluator::initializeGBRForest(const GBRForest* gbrForest, const std::vector<std::string> & variables,
56  const std::vector<std::string> & spectators, bool useAdaBoost)
57 {
58  // add input variables
59  for(std::vector<std::string>::const_iterator it = variables.begin(); it!=variables.end(); ++it)
60  mVariables.insert( std::make_pair( *it, std::make_pair( it - variables.begin(), 0. ) ) );
61 
62  // add spectator variables
63  for(std::vector<std::string>::const_iterator it = spectators.begin(); it!=spectators.end(); ++it)
64  mSpectators.insert( std::make_pair( *it, std::make_pair( it - spectators.begin(), 0. ) ) );
65 
66  // do not take ownership if getting GBRForest from an external source
67  mGBRForest = std::shared_ptr<const GBRForest>(gbrForest, [](const GBRForest*) {} );
68 
69  mIsInitialized = true;
70  mUsingGBRForest = true;
72 }
73 
74 
76  const std::vector<std::string> & variables, const std::vector<std::string> & spectators, bool useAdaBoost)
77 {
78  edm::ESHandle<GBRForest> gbrForestHandle;
79 
80  iSetup.get<GBRWrapperRcd>().get(label.c_str(), gbrForestHandle);
81 
83 }
84 
85 
86 float TMVAEvaluator::evaluateTMVA(const std::map<std::string,float> & inputs, bool useSpectators) const
87 {
88  // default value
89  float value = -99.;
90 
91  // TMVA::Reader is not thread safe
92  std::lock_guard<std::mutex> lock(m_mutex);
93 
94  // set the input variable values
95  for(auto it = mVariables.begin(); it!=mVariables.end(); ++it)
96  {
97  if (inputs.count(it->first)>0)
98  it->second.second = inputs.at(it->first);
99  else
100  edm::LogError("MissingInputVariable") << "Input variable " << it->first << " is missing from the list of inputs. The returned discriminator value might not be sensible.";
101  }
102 
103  // if using spectator variables
104  if(useSpectators)
105  {
106  // set the spectator variable values
107  for(auto it = mSpectators.begin(); it!=mSpectators.end(); ++it)
108  {
109  if (inputs.count(it->first)>0)
110  it->second.second = inputs.at(it->first);
111  else
112  edm::LogError("MissingSpectatorVariable") << "Spectator variable " << it->first << " is missing from the list of inputs. The returned discriminator value might not be sensible.";
113  }
114  }
115 
116  // evaluate the MVA
117  value = mReader->EvaluateMVA(mMethod.c_str());
118 
119  return value;
120 }
121 
122 
123 float TMVAEvaluator::evaluateGBRForest(const std::map<std::string,float> & inputs) const
124 {
125  // default value
126  float value = -99.;
127 
128  std::unique_ptr<float[]> vars(new float[mVariables.size()]); // allocate n floats
129 
130  // set the input variable values
131  for(auto it = mVariables.begin(); it!=mVariables.end(); ++it)
132  {
133  if (inputs.count(it->first)>0)
134  vars[it->second.first] = inputs.at(it->first);
135  else
136  edm::LogError("MissingInputVariable") << "Input variable " << it->first << " is missing from the list of inputs. The returned discriminator value might not be sensible.";
137  }
138 
139  // evaluate the MVA
140  if (mUseAdaBoost)
141  value = mGBRForest->GetAdaBoostClassifier(vars.get());
142  else
143  value = mGBRForest->GetGradBoostClassifier(vars.get());
144 
145  return value;
146 }
147 
148 float TMVAEvaluator::evaluate(const std::map<std::string,float> & inputs, bool useSpectators) const
149 {
150  // default value
151  float value = -99.;
152 
153  if(!mIsInitialized)
154  {
155  edm::LogError("InitializationError") << "TMVAEvaluator not properly initialized.";
156  return value;
157  }
158 
159  if( useSpectators && inputs.size() < ( mVariables.size() + mSpectators.size() ) )
160  {
161  edm::LogError("MissingInputs") << "Too few inputs provided (" << inputs.size() << " provided but " << mVariables.size() << " input and " << mSpectators.size() << " spectator variables expected).";
162  return value;
163  }
164  else if( inputs.size() < mVariables.size() )
165  {
166  edm::LogError("MissingInputVariable(s)") << "Too few input variables provided (" << inputs.size() << " provided but " << mVariables.size() << " expected).";
167  return value;
168  }
169 
170  if (mUsingGBRForest)
171  {
172  if(useSpectators)
173  edm::LogWarning("UnsupportedFunctionality") << "Use of spectator variables with GBRForest is not supported. Spectator variables will be ignored.";
174  value = evaluateGBRForest(inputs);
175  }
176  else
177  value = evaluateTMVA(inputs, useSpectators);
178 
179  return value;
180 }
std::map< std::string, std::pair< size_t, float > > mVariables
Definition: TMVAEvaluator.h:42
void initializeGBRForest(const GBRForest *gbrForest, const std::vector< std::string > &variables, const std::vector< std::string > &spectators, bool useAdaBoost=false)
bool mUsingGBRForest
Definition: TMVAEvaluator.h:34
void initialize(const std::string &options, const std::string &method, const std::string &weightFile, const std::vector< std::string > &variables, const std::vector< std::string > &spectators, bool useGBRForest=false, bool useAdaBoost=false)
std::map< std::string, std::pair< size_t, float > > mSpectators
Definition: TMVAEvaluator.h:43
float evaluateTMVA(const std::map< std::string, float > &inputs, bool useSpectators) const
std::shared_ptr< const GBRForest > mGBRForest
Definition: TMVAEvaluator.h:40
Definition: value.py:1
std::string mMethod
Definition: TMVAEvaluator.h:37
const T & get() const
Definition: EventSetup.h:58
std::mutex m_mutex
Definition: TMVAEvaluator.h:38
TMVA::IMethod * loadTMVAWeights(TMVA::Reader *reader, const std::string &method, const std::string &weightFile, bool verbose=false)
std::unique_ptr< TMVA::Reader > mReader
Definition: TMVAEvaluator.h:39
float evaluate(const std::map< std::string, float > &inputs, bool useSpectators=false) const
float evaluateGBRForest(const std::map< std::string, float > &inputs) const
T const * product() const
Definition: ESHandle.h:86