CMS 3D CMS Logo

ProcTMVA.cc
Go to the documentation of this file.
1 // -*- C++ -*-
2 //
3 // Package: MVAComputer
4 // Class : ProcTMVA
5 //
6 
7 // Implementation:
8 // TMVA wrapper, needs n non-optional, non-multiple input variables
9 // and outputs one result variable. All TMVA algorithms can be used,
10 // calibration data is passed via stream and extracted from a zipped
11 // buffer.
12 //
13 // Author: Christophe Saout
14 // Created: Sat Apr 24 15:18 CEST 2007
15 //
16 
#include <cstdio>
#include <fstream>
#include <iostream>
#include <memory>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
25 
26 // ROOT version magic to support TMVA interface changes in newer ROOT
27 #include <RVersion.h>
28 
29 #include <TMVA/Types.h>
30 #include <TMVA/MethodBase.h>
31 #include "TMVA/Reader.h"
32 
35 
40 
41 #include <boost/filesystem.hpp>
42 
43 using namespace PhysicsTools;
44 
45 namespace { // anonymous
46 
47  class ProcTMVA : public VarProcessor {
48  public:
50 
51  ProcTMVA(const char *name, const Calibration::ProcExternal *calib, const MVAComputer *computer);
52  ~ProcTMVA() override {}
53 
54  void configure(ConfIterator iter, unsigned int n) override;
55  void eval(ValueIterator iter, unsigned int n) const override;
56 
57  private:
58  std::unique_ptr<TMVA::Reader> reader;
59  TMVA::MethodBase *method;
60  std::string methodName;
61  unsigned int nVars;
62 
63  // FIXME: Gena
64  TString methodName_t;
65  };
66 
67  ProcTMVA::Registry registry("ProcTMVA");
68 
69  ProcTMVA::ProcTMVA(const char *name, const Calibration::ProcExternal *calib, const MVAComputer *computer)
71  reader = std::make_unique<TMVA::Reader>("!Color:Silent");
72 
73  ext::imemstream is(reinterpret_cast<const char *>(&calib->store.front()), calib->store.size());
74  ext::izstream izs(&is);
75 
76  std::getline(izs, methodName);
77 
79  std::getline(izs, tmp);
80  std::istringstream iss(tmp);
81  iss >> nVars;
82  for (unsigned int i = 0; i < nVars; i++) {
83  std::getline(izs, tmp);
84  reader->DataInfo().AddVariable(tmp.c_str());
85  }
86 
87  // The rest of the gzip blob is the weights file
88  std::string weight_text;
90  while (std::getline(izs, line)) {
91  weight_text += line;
92  weight_text += "\n";
93  }
94 
95  // Build our reader
96  TMVA::Types::EMVA methodType = TMVA::Types::Instance().GetMethodType(methodName);
97  // Check if xml format
98  if (weight_text.find("<?xml") != std::string::npos) {
99  method = dynamic_cast<TMVA::MethodBase *>(reader->BookMVA(methodType, weight_text.c_str()));
100  } else {
101  // Write to a temporary file
102  TString weight_file_name(boost::filesystem::unique_path().c_str());
103  std::ofstream weight_file;
104  weight_file.open(weight_file_name.Data());
105  weight_file << weight_text;
106  weight_file.close();
107  edm::LogInfo("LegacyMVA") << "Building legacy TMVA plugin - "
108  << "the weights are being stored in " << weight_file_name << std::endl;
109  methodName_t.Append(methodName.c_str());
110  method = dynamic_cast<TMVA::MethodBase *>(reader->BookMVA(methodName_t, weight_file_name));
111  remove(weight_file_name.Data());
112  }
113 
114  /*
115  bool isXml = false; // weights in XML (TMVA 4) or plain text
116  bool isFirstPass = true;
117  TString weight_file_name(tmpnam(0));
118  std:: ofstream weight_file;
119  //
120 
121  std::string weights;
122  while (izs.good()) {
123  std::string tmp;
124 
125  if (isFirstPass){
126  std::getline(izs, tmp);
127  isFirstPass = false;
128  if ( tmp.find("<?xml") != std::string::npos ){ //xml
129  isXml = true;
130  weights += tmp + " ";
131  }
132  else{
133  std::cout << std::endl;
134  std::cout << "ProcTMVA::ProcTMVA(): *** WARNING! ***" << std::endl;
135  std::cout << "ProcTMVA::ProcTMVA(): Old pre-TMVA 4 plain text weights are being loaded" << std::endl;
136  std::cout << "ProcTMVA::ProcTMVA(): It may work but backwards compatibility is not guaranteed" << std::endl;
137  std::cout << "ProcTMVA::ProcTMVA(): TMVA 4 weight file format is XML" << std::endl;
138  std::cout << "ProcTMVA::ProcTMVA(): Retrain your networks as soon as possible!" << std::endl;
139  std::cout << "ProcTMVA::ProcTMVA(): Creating temporary weight file " << weight_file_name << std::endl;
140  weight_file.open(weight_file_name.Data());
141  weight_file << tmp << std::endl;
142  }
143  } // end first pass
144  else{
145  if (isXml){ // xml
146  izs >> tmp;
147  weights += tmp + " ";
148  }
149  else{ // plain text
150  weight_file << tmp << std::endl;
151  }
152  } // end not first pass
153 
154  }
155  if (weight_file.is_open()){
156  std::cout << "ProcTMVA::ProcTMVA(): Deleting temporary weight file " << weight_file_name << std::endl;
157  weight_file.close();
158  }
159 
160  TMVA::Types::EMVA methodType =
161  TMVA::Types::Instance().GetMethodType(methodName);
162 
163  if (isXml){
164  method = std::unique_ptr<TMVA::MethodBase>
165  ( dynamic_cast<TMVA::MethodBase*>
166  ( reader->BookMVA( methodType, weights.c_str() ) ) );
167  }
168  else{
169  methodName_t.Clear();
170  methodName_t.Append(methodName.c_str());
171  method = std::unique_ptr<TMVA::MethodBase>
172  ( dynamic_cast<TMVA::MethodBase*>
173  ( reader->BookMVA( methodName_t, weight_file_name ) ) );
174  }
175 
176  */
177  }
178 
179  void ProcTMVA::configure(ConfIterator iter, unsigned int n) {
180  if (n != nVars)
181  return;
182 
183  for (unsigned int i = 0; i < n; i++)
184  iter++(Variable::FLAG_NONE);
185 
186  iter << Variable::FLAG_NONE;
187  }
188 
189  void ProcTMVA::eval(ValueIterator iter, unsigned int n) const {
190  std::vector<Float_t> inputs;
191  inputs.reserve(n);
192  for (unsigned int i = 0; i < n; i++)
193  inputs.push_back(*iter++);
194  std::unique_ptr<TMVA::Event> evt(new TMVA::Event(inputs, 2));
195 
196  double result = method->GetMvaValue(evt.get());
197 
198  iter(result);
199  }
200 
201 } // anonymous namespace
template to generate a registry singleton for a type.
reader
Definition: DQM.py:105
#define MVA_COMPUTER_DEFINE_PLUGIN(T)
constexpr int nVars
Main interface class to the generic discriminator computer framework.
Definition: MVAComputer.h:39
Log< level::Info, false > LogInfo
tmp
align.sh
Definition: createJobs.py:716
Common base class for variable processors.
Definition: VarProcessor.h:36