CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
TMVAEvaluator.cc
Go to the documentation of this file.
2 
7 #include "TMVA/MethodBDT.h"
8 
9 
11  mIsInitialized(false), mUsingGBRForest(false), mUseAdaBoost(false)
12 {
13 }
14 
15 
16 void TMVAEvaluator::initialize(const std::string & options, const std::string & method, const std::string & weightFile,
17  const std::vector<std::string> & variables, const std::vector<std::string> & spectators, bool useGBRForest, bool useAdaBoost)
18 {
19  // initialize the TMVA reader
20  mReader.reset(new TMVA::Reader(options.c_str()));
21  mReader->SetVerbose(false);
22  mMethod = method;
23 
24  // add input variables
25  for(std::vector<std::string>::const_iterator it = variables.begin(); it!=variables.end(); ++it)
26  {
27  mVariables.insert( std::make_pair( *it, std::make_pair( it - variables.begin(), 0. ) ) );
28  mReader->AddVariable(it->c_str(), &(mVariables.at(*it).second));
29  }
30 
31  // add spectator variables
32  for(std::vector<std::string>::const_iterator it = spectators.begin(); it!=spectators.end(); ++it)
33  {
34  mSpectators.insert( std::make_pair( *it, std::make_pair( it - spectators.begin(), 0. ) ) );
35  mReader->AddSpectator(it->c_str(), &(mSpectators.at(*it).second));
36  }
37 
38  // load the TMVA weights
39  mIMethod = std::unique_ptr<TMVA::IMethod>( reco::details::loadTMVAWeights(mReader.get(), mMethod.c_str(), weightFile.c_str()) );
40 
41  if (useGBRForest)
42  {
43  mGBRForest.reset( new GBRForest( dynamic_cast<TMVA::MethodBDT*>( mReader->FindMVA(mMethod.c_str()) ) ) );
44 
45  // now can free some memory
46  mReader.reset(nullptr);
47  mIMethod.reset(nullptr);
48 
49  mUsingGBRForest = true;
50  mUseAdaBoost = useAdaBoost;
51  }
52 
53  mIsInitialized = true;
54 }
55 
56 
57 void TMVAEvaluator::initializeGBRForest(const GBRForest* gbrForest, const std::vector<std::string> & variables,
58  const std::vector<std::string> & spectators, bool useAdaBoost)
59 {
60  // add input variables
61  for(std::vector<std::string>::const_iterator it = variables.begin(); it!=variables.end(); ++it)
62  mVariables.insert( std::make_pair( *it, std::make_pair( it - variables.begin(), 0. ) ) );
63 
64  // add spectator variables
65  for(std::vector<std::string>::const_iterator it = spectators.begin(); it!=spectators.end(); ++it)
66  mSpectators.insert( std::make_pair( *it, std::make_pair( it - spectators.begin(), 0. ) ) );
67 
68  // do not take ownership if getting GBRForest from an external source
69  mGBRForest = std::shared_ptr<const GBRForest>(gbrForest, [](const GBRForest*) {} );
70 
71  mIsInitialized = true;
72  mUsingGBRForest = true;
73  mUseAdaBoost = useAdaBoost;
74 }
75 
76 
78  const std::vector<std::string> & variables, const std::vector<std::string> & spectators, bool useAdaBoost)
79 {
80  edm::ESHandle<GBRForest> gbrForestHandle;
81 
82  iSetup.get<GBRWrapperRcd>().get(label.c_str(), gbrForestHandle);
83 
84  initializeGBRForest(gbrForestHandle.product(), variables, spectators, useAdaBoost);
85 }
86 
87 
88 float TMVAEvaluator::evaluateTMVA(const std::map<std::string,float> & inputs, bool useSpectators) const
89 {
90  // default value
91  float value = -99.;
92 
93  // TMVA::Reader is not thread safe
94  std::lock_guard<std::mutex> lock(m_mutex);
95 
96  // set the input variable values
97  for(auto it = mVariables.begin(); it!=mVariables.end(); ++it)
98  {
99  if (inputs.count(it->first)>0)
100  it->second.second = inputs.at(it->first);
101  else
102  edm::LogError("MissingInputVariable") << "Input variable " << it->first << " is missing from the list of inputs. The returned discriminator value might not be sensible.";
103  }
104 
105  // if using spectator variables
106  if(useSpectators)
107  {
108  // set the spectator variable values
109  for(auto it = mSpectators.begin(); it!=mSpectators.end(); ++it)
110  {
111  if (inputs.count(it->first)>0)
112  it->second.second = inputs.at(it->first);
113  else
114  edm::LogError("MissingSpectatorVariable") << "Spectator variable " << it->first << " is missing from the list of inputs. The returned discriminator value might not be sensible.";
115  }
116  }
117 
118  // evaluate the MVA
119  value = mReader->EvaluateMVA(mMethod.c_str());
120 
121  return value;
122 }
123 
124 
125 float TMVAEvaluator::evaluateGBRForest(const std::map<std::string,float> & inputs) const
126 {
127  // default value
128  float value = -99.;
129 
130  std::unique_ptr<float[]> vars(new float[mVariables.size()]); // allocate n floats
131 
132  // set the input variable values
133  for(auto it = mVariables.begin(); it!=mVariables.end(); ++it)
134  {
135  if (inputs.count(it->first)>0)
136  vars[it->second.first] = inputs.at(it->first);
137  else
138  edm::LogError("MissingInputVariable") << "Input variable " << it->first << " is missing from the list of inputs. The returned discriminator value might not be sensible.";
139  }
140 
141  // evaluate the MVA
142  if (mUseAdaBoost)
143  value = mGBRForest->GetAdaBoostClassifier(vars.get());
144  else
145  value = mGBRForest->GetGradBoostClassifier(vars.get());
146 
147  return value;
148 }
149 
150 float TMVAEvaluator::evaluate(const std::map<std::string,float> & inputs, bool useSpectators) const
151 {
152  // default value
153  float value = -99.;
154 
155  if(!mIsInitialized)
156  {
157  edm::LogError("InitializationError") << "TMVAEvaluator not properly initialized.";
158  return value;
159  }
160 
161  if( useSpectators && inputs.size() < ( mVariables.size() + mSpectators.size() ) )
162  {
163  edm::LogError("MissingInputs") << "Too few inputs provided (" << inputs.size() << " provided but " << mVariables.size() << " input and " << mSpectators.size() << " spectator variables expected).";
164  return value;
165  }
166  else if( inputs.size() < mVariables.size() )
167  {
168  edm::LogError("MissingInputVariable(s)") << "Too few input variables provided (" << inputs.size() << " provided but " << mVariables.size() << " expected).";
169  return value;
170  }
171 
172  if (mUsingGBRForest)
173  {
174  if(useSpectators)
175  edm::LogWarning("UnsupportedFunctionality") << "Use of spectator variables with GBRForest is not supported. Spectator variables will be ignored.";
176  value = evaluateGBRForest(inputs);
177  }
178  else
179  value = evaluateTMVA(inputs, useSpectators);
180 
181  return value;
182 }
std::map< std::string, std::pair< size_t, float > > mVariables
Definition: TMVAEvaluator.h:42
void initializeGBRForest(const GBRForest *gbrForest, const std::vector< std::string > &variables, const std::vector< std::string > &spectators, bool useAdaBoost=false)
bool mUsingGBRForest
Definition: TMVAEvaluator.h:33
void initialize(const std::string &options, const std::string &method, const std::string &weightFile, const std::vector< std::string > &variables, const std::vector< std::string > &spectators, bool useGBRForest=false, bool useAdaBoost=false)
std::map< std::string, std::pair< size_t, float > > mSpectators
Definition: TMVAEvaluator.h:43
float evaluateTMVA(const std::map< std::string, float > &inputs, bool useSpectators) const
std::shared_ptr< const GBRForest > mGBRForest
Definition: TMVAEvaluator.h:40
std::string mMethod
Definition: TMVAEvaluator.h:36
const T & get() const
Definition: EventSetup.h:56
T const * product() const
Definition: ESHandle.h:86
std::mutex m_mutex
Definition: TMVAEvaluator.h:37
TMVA::IMethod * loadTMVAWeights(TMVA::Reader *reader, const std::string &method, const std::string &weightFile, bool verbose=false)
std::unique_ptr< TMVA::Reader > mReader
Definition: TMVAEvaluator.h:38
float evaluate(const std::map< std::string, float > &inputs, bool useSpectators=false) const
volatile std::atomic< bool > shutdown_flag false
float evaluateGBRForest(const std::map< std::string, float > &inputs) const
std::unique_ptr< TMVA::IMethod > mIMethod
Definition: TMVAEvaluator.h:39