CMS 3D CMS Logo

ProcTMVA.cc
Go to the documentation of this file.
1 // -*- C++ -*-
2 //
3 // Package: MVAComputer
4 // Class : ProcTMVA
5 //
6 
7 // Implementation:
8 // TMVA wrapper, needs n non-optional, non-multiple input variables
9 // and outputs one result variable. All TMVA algorithms can be used,
10 // calibration data is passed via stream and extracted from a zipped
11 // buffer.
12 //
13 // Author: Christophe Saout
14 // Created: Sat Apr 24 15:18 CEST 2007
15 //
16 
#include <cstdio>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
23 
24 // ROOT version magic to support TMVA interface changes in newer ROOT
25 #include <RVersion.h>
26 
27 #include <TMVA/Types.h>
28 #include <TMVA/MethodBase.h>
29 #include "TMVA/Reader.h"
30 
33 
38 
39 #include <boost/filesystem.hpp>
40 
41 using namespace PhysicsTools;
42 
43 namespace { // anonymous
44 
45 class ProcTMVA : public VarProcessor {
46  public:
47  typedef VarProcessor::Registry::Registry<ProcTMVA,
48  Calibration::ProcExternal> Registry;
49 
50  ProcTMVA(const char *name,
52  const MVAComputer *computer);
53  ~ProcTMVA() override {}
54 
55  void configure(ConfIterator iter, unsigned int n) override;
56  void eval(ValueIterator iter, unsigned int n) const override;
57 
58  private:
59  std::unique_ptr<TMVA::Reader> reader;
60  TMVA::MethodBase* method;
61  std::string methodName;
62  unsigned int nVars;
63 
64  // FIXME: Gena
65  TString methodName_t;
66 };
67 
68 ProcTMVA::Registry registry("ProcTMVA");
69 
70 ProcTMVA::ProcTMVA(const char *name,
72  const MVAComputer *computer) :
73  VarProcessor(name, calib, computer)
74 {
75 
76  reader = std::unique_ptr<TMVA::Reader>(new TMVA::Reader( "!Color:Silent" ));
77 
78  ext::imemstream is(
79  reinterpret_cast<const char*>(&calib->store.front()),
80  calib->store.size());
81  ext::izstream izs(&is);
82 
83  std::getline(izs, methodName);
84 
86  std::getline(izs, tmp);
87  std::istringstream iss(tmp);
88  iss >> nVars;
89  for(unsigned int i = 0; i < nVars; i++) {
90  std::getline(izs, tmp);
91  reader->DataInfo().AddVariable(tmp.c_str());
92  }
93 
94  // The rest of the gzip blob is the weights file
95  std::string weight_text;
97  while (std::getline(izs, line)) {
98  weight_text += line;
99  weight_text += "\n";
100  }
101 
102 
103  // Build our reader
104  TMVA::Types::EMVA methodType =
105  TMVA::Types::Instance().GetMethodType(methodName);
106  // Check if xml format
107  if (weight_text.find("<?xml") != std::string::npos) {
108  method = dynamic_cast<TMVA::MethodBase*>( reader->BookMVA( methodType, weight_text.c_str() ) );
109  } else {
110  // Write to a temporary file
111  TString weight_file_name(boost::filesystem::unique_path().c_str());
112  std::ofstream weight_file;
113  weight_file.open(weight_file_name.Data());
114  weight_file << weight_text;
115  weight_file.close();
116  edm::LogInfo("LegacyMVA") << "Building legacy TMVA plugin - "
117  << "the weights are being stored in " << weight_file_name << std::endl;
118  methodName_t.Append(methodName.c_str());
119  method = dynamic_cast<TMVA::MethodBase*>( reader->BookMVA( methodName_t, weight_file_name ) );
120  remove(weight_file_name.Data());
121  }
122 
123  /*
124  bool isXml = false; // weights in XML (TMVA 4) or plain text
125  bool isFirstPass = true;
126  TString weight_file_name(tmpnam(0));
127  std:: ofstream weight_file;
128  //
129 
130  std::string weights;
131  while (izs.good()) {
132  std::string tmp;
133 
134  if (isFirstPass){
135  std::getline(izs, tmp);
136  isFirstPass = false;
137  if ( tmp.find("<?xml") != std::string::npos ){ //xml
138  isXml = true;
139  weights += tmp + " ";
140  }
141  else{
142  std::cout << std::endl;
143  std::cout << "ProcTMVA::ProcTMVA(): *** WARNING! ***" << std::endl;
144  std::cout << "ProcTMVA::ProcTMVA(): Old pre-TMVA 4 plain text weights are being loaded" << std::endl;
145  std::cout << "ProcTMVA::ProcTMVA(): It may work but backwards compatibility is not guaranteed" << std::endl;
146  std::cout << "ProcTMVA::ProcTMVA(): TMVA 4 weight file format is XML" << std::endl;
147  std::cout << "ProcTMVA::ProcTMVA(): Retrain your networks as soon as possible!" << std::endl;
148  std::cout << "ProcTMVA::ProcTMVA(): Creating temporary weight file " << weight_file_name << std::endl;
149  weight_file.open(weight_file_name.Data());
150  weight_file << tmp << std::endl;
151  }
152  } // end first pass
153  else{
154  if (isXml){ // xml
155  izs >> tmp;
156  weights += tmp + " ";
157  }
158  else{ // plain text
159  weight_file << tmp << std::endl;
160  }
161  } // end not first pass
162 
163  }
164  if (weight_file.is_open()){
165  std::cout << "ProcTMVA::ProcTMVA(): Deleting temporary weight file " << weight_file_name << std::endl;
166  weight_file.close();
167  }
168 
169  TMVA::Types::EMVA methodType =
170  TMVA::Types::Instance().GetMethodType(methodName);
171 
172  if (isXml){
173  method = std::unique_ptr<TMVA::MethodBase>
174  ( dynamic_cast<TMVA::MethodBase*>
175  ( reader->BookMVA( methodType, weights.c_str() ) ) );
176  }
177  else{
178  methodName_t.Clear();
179  methodName_t.Append(methodName.c_str());
180  method = std::unique_ptr<TMVA::MethodBase>
181  ( dynamic_cast<TMVA::MethodBase*>
182  ( reader->BookMVA( methodName_t, weight_file_name ) ) );
183  }
184 
185  */
186 }
187 
188 void ProcTMVA::configure(ConfIterator iter, unsigned int n)
189 {
190  if (n != nVars)
191  return;
192 
193  for(unsigned int i = 0; i < n; i++)
194  iter++(Variable::FLAG_NONE);
195 
196  iter << Variable::FLAG_NONE;
197 }
198 
199 void ProcTMVA::eval(ValueIterator iter, unsigned int n) const
200 {
201  std::vector<Float_t> inputs;
202  inputs.reserve(n);
203  for(unsigned int i = 0; i < n; i++)
204  inputs.push_back(*iter++);
205  std::unique_ptr<TMVA::Event> evt(new TMVA::Event(inputs, 2));
206 
207  double result = method->GetMvaValue(evt.get());
208 
209  iter(result);
210 }
211 
212 } // anonymous namespace
template to generate a registry singleton for a type.
#define MVA_COMPUTER_DEFINE_PLUGIN(T)
Main interface class to the generic discriminator computer framework.
Definition: MVAComputer.h:39
std::vector< std::vector< double > > tmp
Definition: MVATrainer.cc:100
static Interceptor::Registry registry("Interceptor")
std::vector< unsigned char > store
Definition: MVAComputer.h:200
Common base class for variable processors.
Definition: VarProcessor.h:36