CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
TMVAEvaluator.cc
Go to the documentation of this file.
2 
7 #include "TMVA/MethodBDT.h"
8 
9 
11  mIsInitialized(false), mUsingGBRForest(false), mUseAdaBoost(false), mReleaseAtEnd(false)
12 {
13 }
14 
15 
17 {
18  if (mReleaseAtEnd)
19  mGBRForest.release();
20 }
21 
22 
23 void TMVAEvaluator::initialize(const std::string & options, const std::string & method, const std::string & weightFile,
24  const std::vector<std::string> & variables, const std::vector<std::string> & spectators, bool useGBRForest, bool useAdaBoost)
25 {
26  // initialize the TMVA reader
27  mReader.reset(new TMVA::Reader(options.c_str()));
28  mReader->SetVerbose(false);
29  mMethod = method;
30 
31  // add input variables
32  for(std::vector<std::string>::const_iterator it = variables.begin(); it!=variables.end(); ++it)
33  {
34  mVariables.insert( std::make_pair( *it, std::make_pair( it - variables.begin(), 0. ) ) );
35  mReader->AddVariable(it->c_str(), &(mVariables.at(*it).second));
36  }
37 
38  // add spectator variables
39  for(std::vector<std::string>::const_iterator it = spectators.begin(); it!=spectators.end(); ++it)
40  {
41  mSpectators.insert( std::make_pair( *it, std::make_pair( it - spectators.begin(), 0. ) ) );
42  mReader->AddSpectator(it->c_str(), &(mSpectators.at(*it).second));
43  }
44 
45  // load the TMVA weights
46  mIMethod = std::unique_ptr<TMVA::IMethod>( reco::details::loadTMVAWeights(mReader.get(), mMethod.c_str(), weightFile.c_str()) );
47 
48  if (useGBRForest)
49  {
50  mGBRForest.reset( new GBRForest( dynamic_cast<TMVA::MethodBDT*>( mReader->FindMVA(mMethod.c_str()) ) ) );
51 
52  // now can free some memory
53  mReader.reset(nullptr);
54  mIMethod.reset(nullptr);
55 
56  mUsingGBRForest = true;
57  mUseAdaBoost = useAdaBoost;
58  }
59 
60  mIsInitialized = true;
61 }
62 
63 
64 void TMVAEvaluator::initializeGBRForest(const GBRForest* gbrForest, const std::vector<std::string> & variables,
65  const std::vector<std::string> & spectators, bool useAdaBoost)
66 {
67  // add input variables
68  for(std::vector<std::string>::const_iterator it = variables.begin(); it!=variables.end(); ++it)
69  mVariables.insert( std::make_pair( *it, std::make_pair( it - variables.begin(), 0. ) ) );
70 
71  // add spectator variables
72  for(std::vector<std::string>::const_iterator it = spectators.begin(); it!=spectators.end(); ++it)
73  mSpectators.insert( std::make_pair( *it, std::make_pair( it - spectators.begin(), 0. ) ) );
74 
75  mGBRForest.reset( gbrForest );
76 
77  mIsInitialized = true;
78  mUsingGBRForest = true;
79  mUseAdaBoost = useAdaBoost;
80  mReleaseAtEnd = true; // need to release ownership at the end if getting GBRForest from an external source
81 }
82 
83 
85  const std::vector<std::string> & variables, const std::vector<std::string> & spectators, bool useAdaBoost)
86 {
87  edm::ESHandle<GBRForest> gbrForestHandle;
88 
89  iSetup.get<GBRWrapperRcd>().get(label.c_str(), gbrForestHandle);
90 
91  initializeGBRForest(gbrForestHandle.product(), variables, spectators, useAdaBoost);
92 }
93 
94 
95 float TMVAEvaluator::evaluateTMVA(const std::map<std::string,float> & inputs, bool useSpectators) const
96 {
97  // default value
98  float value = -99.;
99 
100  // TMVA::Reader is not thread safe
101  std::lock_guard<std::mutex> lock(m_mutex);
102 
103  // set the input variable values
104  for(auto it = mVariables.begin(); it!=mVariables.end(); ++it)
105  {
106  if (inputs.count(it->first)>0)
107  it->second.second = inputs.at(it->first);
108  else
109  edm::LogError("MissingInputVariable") << "Input variable " << it->first << " is missing from the list of inputs. The returned discriminator value might not be sensible.";
110  }
111 
112  // if using spectator variables
113  if(useSpectators)
114  {
115  // set the spectator variable values
116  for(auto it = mSpectators.begin(); it!=mSpectators.end(); ++it)
117  {
118  if (inputs.count(it->first)>0)
119  it->second.second = inputs.at(it->first);
120  else
121  edm::LogError("MissingSpectatorVariable") << "Spectator variable " << it->first << " is missing from the list of inputs. The returned discriminator value might not be sensible.";
122  }
123  }
124 
125  // evaluate the MVA
126  value = mReader->EvaluateMVA(mMethod.c_str());
127 
128  return value;
129 }
130 
131 
132 float TMVAEvaluator::evaluateGBRForest(const std::map<std::string,float> & inputs) const
133 {
134  // default value
135  float value = -99.;
136 
137  std::unique_ptr<float[]> vars(new float[mVariables.size()]); // allocate n floats
138 
139  // set the input variable values
140  for(auto it = mVariables.begin(); it!=mVariables.end(); ++it)
141  {
142  if (inputs.count(it->first)>0)
143  vars[it->second.first] = inputs.at(it->first);
144  else
145  edm::LogError("MissingInputVariable") << "Input variable " << it->first << " is missing from the list of inputs. The returned discriminator value might not be sensible.";
146  }
147 
148  // evaluate the MVA
149  if (mUseAdaBoost)
150  value = mGBRForest->GetAdaBoostClassifier(vars.get());
151  else
152  value = mGBRForest->GetGradBoostClassifier(vars.get());
153 
154  return value;
155 }
156 
157 float TMVAEvaluator::evaluate(const std::map<std::string,float> & inputs, bool useSpectators) const
158 {
159  // default value
160  float value = -99.;
161 
162  if(!mIsInitialized)
163  {
164  edm::LogError("InitializationError") << "TMVAEvaluator not properly initialized.";
165  return value;
166  }
167 
168  if( useSpectators && inputs.size() < ( mVariables.size() + mSpectators.size() ) )
169  {
170  edm::LogError("MissingInputs") << "Too few inputs provided (" << inputs.size() << " provided but " << mVariables.size() << " input and " << mSpectators.size() << " spectator variables expected).";
171  return value;
172  }
173  else if( inputs.size() < mVariables.size() )
174  {
175  edm::LogError("MissingInputVariable(s)") << "Too few input variables provided (" << inputs.size() << " provided but " << mVariables.size() << " expected).";
176  return value;
177  }
178 
179  if (mUsingGBRForest)
180  {
181  if(useSpectators)
182  edm::LogWarning("UnsupportedFunctionality") << "Use of spectator variables with GBRForest is not supported. Spectator variables will be ignored.";
183  value = evaluateGBRForest(inputs);
184  }
185  else
186  value = evaluateTMVA(inputs, useSpectators);
187 
188  return value;
189 }
std::map< std::string, std::pair< size_t, float > > mVariables
Definition: TMVAEvaluator.h:44
void initializeGBRForest(const GBRForest *gbrForest, const std::vector< std::string > &variables, const std::vector< std::string > &spectators, bool useAdaBoost=false)
bool mUsingGBRForest
Definition: TMVAEvaluator.h:34
void initialize(const std::string &options, const std::string &method, const std::string &weightFile, const std::vector< std::string > &variables, const std::vector< std::string > &spectators, bool useGBRForest=false, bool useAdaBoost=false)
std::map< std::string, std::pair< size_t, float > > mSpectators
Definition: TMVAEvaluator.h:45
float evaluateTMVA(const std::map< std::string, float > &inputs, bool useSpectators) const
std::string mMethod
Definition: TMVAEvaluator.h:38
const T & get() const
Definition: EventSetup.h:55
T const * product() const
Definition: ESHandle.h:86
std::mutex m_mutex
Definition: TMVAEvaluator.h:39
TMVA::IMethod * loadTMVAWeights(TMVA::Reader *reader, const std::string &method, const std::string &weightFile, bool verbose=false)
std::unique_ptr< TMVA::Reader > mReader
Definition: TMVAEvaluator.h:40
std::unique_ptr< const GBRForest > mGBRForest
Definition: TMVAEvaluator.h:42
float evaluate(const std::map< std::string, float > &inputs, bool useSpectators=false) const
volatile std::atomic< bool > shutdown_flag false
float evaluateGBRForest(const std::map< std::string, float > &inputs) const
std::unique_ptr< TMVA::IMethod > mIMethod
Definition: TMVAEvaluator.h:41