CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
ProcTMVA.cc
Go to the documentation of this file.
1 // -*- C++ -*-
2 //
3 // Package: MVAComputer
4 // Class : ProcTMVA
5 //
6 
7 // Implementation:
8 // TMVA wrapper, needs n non-optional, non-multiple input variables
9 // and outputs one result variable. All TMVA algorithms can be used,
10 // calibration data is passed via stream and extracted from a zipped
11 // buffer.
12 //
13 // Author: Christophe Saout
14 // Created: Sat Apr 24 15:18 CEST 2007
15 // $Id: ProcTMVA.cc,v 1.7 2012/11/16 22:28:55 muzaffar Exp $
16 //
17 
#include <unistd.h>

#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// ROOT version magic to support TMVA interface changes in newer ROOT
#include <RVersion.h>

#include <TMVA/Types.h>
#include <TMVA/MethodBase.h>
#include "TMVA/Reader.h"

34 
38 
39 using namespace PhysicsTools;
40 
41 namespace { // anonymous
42 
// Variable processor wrapping a TMVA::Reader: it declares nVars plain inputs
// and one plain output (see configure()) and, per event, returns the value of
// the TMVA method booked from the calibration blob (see eval()).
43 class ProcTMVA : public VarProcessor {
44  public:
// NOTE(review): this listing is missing original line 46 (the end of the
// Registry typedef below) and line 49 (the constructor's calibration
// parameter, named `calib` in the constructor definition); the declaration
// is incomplete as shown.
45  typedef VarProcessor::Registry::Registry<ProcTMVA,
47 
// Reads the method name, the input variable names and the weights from the
// calibration data and books the TMVA method on `reader`.
48  ProcTMVA(const char *name,
50  const MVAComputer *computer);
51  virtual ~ProcTMVA() {}
52 
// Declares n plain (FLAG_NONE) inputs and one plain output, but only when
// n == nVars; otherwise it returns without declaring anything.
53  virtual void configure(ConfIterator iter, unsigned int n);
// Feeds the n input values to the booked TMVA method and writes its output.
54  virtual void eval(ValueIterator iter, unsigned int n) const;
55 
56  private:
// NOTE(review): std::auto_ptr is deprecated since C++11 and removed in C++17.
57  std::auto_ptr<TMVA::Reader> reader;
// NOTE(review): `method` holds the pointer returned by reader->BookMVA().
// If the TMVA::Reader in use also deletes its booked methods in its own
// destructor (confirm for the ROOT version in use), owning it here as well is
// a double delete: members are destroyed in reverse declaration order, so
// `method` is deleted before `reader`.
58  std::auto_ptr<TMVA::MethodBase> method;
// TMVA method name, read from the first line of the calibration blob.
59  std::string methodName;
// Number of input variables, read from the second line of the calibration blob.
60  unsigned int nVars;
61 
62  // FIXME: Gena
// TString copy of methodName, used as the method tag when booking legacy
// (plain-text) weights from a temporary file.
63  TString methodName_t;
64 };
65 
// Registry entry for this processor under the name "ProcTMVA"
// (presumably how the MVAComputer instantiates it by name — confirm).
66 static ProcTMVA::Registry registry("ProcTMVA");
67 
68 ProcTMVA::ProcTMVA(const char *name,
70  const MVAComputer *computer) :
71  VarProcessor(name, calib, computer)
72 {
73 
74  reader = std::auto_ptr<TMVA::Reader>(new TMVA::Reader( "!Color:Silent" ));
75 
76  ext::imemstream is(
77  reinterpret_cast<const char*>(&calib->store.front()),
78  calib->store.size());
79  ext::izstream izs(&is);
80 
81  std::getline(izs, methodName);
82 
83  std::string tmp;
84  std::getline(izs, tmp);
85  std::istringstream iss(tmp);
86  iss >> nVars;
87  for(unsigned int i = 0; i < nVars; i++) {
88  std::getline(izs, tmp);
89  reader->DataInfo().AddVariable(tmp.c_str());
90  }
91 
92  // The rest of the gzip blob is the weights file
93  std::string weight_text;
94  std::string line;
95  while (std::getline(izs, line)) {
96  weight_text += line;
97  weight_text += "\n";
98  }
99 
100 
101  // Build our reader
102  TMVA::Types::EMVA methodType =
103  TMVA::Types::Instance().GetMethodType(methodName);
104  // Check if xml format
105  if (weight_text.find("<?xml") != std::string::npos) {
106  method = std::auto_ptr<TMVA::MethodBase>
107  ( dynamic_cast<TMVA::MethodBase*>
108  ( reader->BookMVA( methodType, weight_text.c_str() ) ) );
109  } else {
110  // Write to a temporary file
111  TString weight_file_name(std::tmpnam(NULL));
112  std::ofstream weight_file;
113  weight_file.open(weight_file_name.Data());
114  weight_file << weight_text;
115  weight_file.close();
116  edm::LogInfo("LegacyMVA") << "Building legacy TMVA plugin - "
117  << "the weights are being stored in " << weight_file_name << std::endl;
118  methodName_t.Append(methodName.c_str());
119  method = std::auto_ptr<TMVA::MethodBase>
120  ( dynamic_cast<TMVA::MethodBase*>
121  ( reader->BookMVA( methodName_t, weight_file_name ) ) );
122  remove(weight_file_name.Data());
123  }
124 
125  /*
126  bool isXml = false; // weights in XML (TMVA 4) or plain text
127  bool isFirstPass = true;
128  TString weight_file_name(tmpnam(0));
129  std:: ofstream weight_file;
130  //
131 
132  std::string weights;
133  while (izs.good()) {
134  std::string tmp;
135 
136  if (isFirstPass){
137  std::getline(izs, tmp);
138  isFirstPass = false;
139  if ( tmp.find("<?xml") != std::string::npos ){ //xml
140  isXml = true;
141  weights += tmp + " ";
142  }
143  else{
144  std::cout << std::endl;
145  std::cout << "ProcTMVA::ProcTMVA(): *** WARNING! ***" << std::endl;
146  std::cout << "ProcTMVA::ProcTMVA(): Old pre-TMVA 4 plain text weights are being loaded" << std::endl;
147  std::cout << "ProcTMVA::ProcTMVA(): It may work but backwards compatibility is not guaranteed" << std::endl;
148  std::cout << "ProcTMVA::ProcTMVA(): TMVA 4 weight file format is XML" << std::endl;
149  std::cout << "ProcTMVA::ProcTMVA(): Retrain your networks as soon as possible!" << std::endl;
150  std::cout << "ProcTMVA::ProcTMVA(): Creating temporary weight file " << weight_file_name << std::endl;
151  weight_file.open(weight_file_name.Data());
152  weight_file << tmp << std::endl;
153  }
154  } // end first pass
155  else{
156  if (isXml){ // xml
157  izs >> tmp;
158  weights += tmp + " ";
159  }
160  else{ // plain text
161  weight_file << tmp << std::endl;
162  }
163  } // end not first pass
164 
165  }
166  if (weight_file.is_open()){
167  std::cout << "ProcTMVA::ProcTMVA(): Deleting temporary weight file " << weight_file_name << std::endl;
168  weight_file.close();
169  }
170 
171  TMVA::Types::EMVA methodType =
172  TMVA::Types::Instance().GetMethodType(methodName);
173 
174  if (isXml){
175  method = std::auto_ptr<TMVA::MethodBase>
176  ( dynamic_cast<TMVA::MethodBase*>
177  ( reader->BookMVA( methodType, weights.c_str() ) ) );
178  }
179  else{
180  methodName_t.Clear();
181  methodName_t.Append(methodName.c_str());
182  method = std::auto_ptr<TMVA::MethodBase>
183  ( dynamic_cast<TMVA::MethodBase*>
184  ( reader->BookMVA( methodName_t, weight_file_name ) ) );
185  }
186 
187  */
188 }
189 
190 void ProcTMVA::configure(ConfIterator iter, unsigned int n)
191 {
192  if (n != nVars)
193  return;
194 
195  for(unsigned int i = 0; i < n; i++)
196  iter++(Variable::FLAG_NONE);
197 
198  iter << Variable::FLAG_NONE;
199 }
200 
201 void ProcTMVA::eval(ValueIterator iter, unsigned int n) const
202 {
203  std::vector<Float_t> inputs;
204  inputs.reserve(n);
205  for(unsigned int i = 0; i < n; i++)
206  inputs.push_back(*iter++);
207  std::auto_ptr<TMVA::Event> evt(new TMVA::Event(inputs, 2));
208 
209  double result = method->GetMvaValue(evt.get());
210 
211  iter(result);
212 }
213 
214 } // anonymous namespace
int i
Definition: DBlmapReader.cc:9
#define NULL
Definition: scimark2.h:8
detail::ThreadSafeRegistry< ParameterSetID, ParameterSet, ProcessParameterSetIDCache > Registry
Definition: Registry.h:37
MVATrainerComputer * calib
Definition: MVATrainer.cc:64
Main interface class to the generic discriminator computer framework.
Definition: MVAComputer.h:40
tuple result
Definition: query.py:137
std::vector< std::vector< double > > tmp
Definition: MVATrainer.cc:100
template to generate a registry singleton for a type.
static Interceptor::Registry registry("Interceptor")
std::vector< unsigned char > store
Definition: MVAComputer.h:148
MVA_COMPUTER_DEFINE_PLUGIN(ProcTMVA)
Common base class for variable processors.
Definition: VarProcessor.h:39