CMS 3D CMS Logo

GBRForestTools.cc
Go to the documentation of this file.
2 
3 #include <iostream>
4 #include <fstream>
5 
namespace {

  // Walks through the first `nth` occurrences of `needle` in `haystack`.
  //
  // Returns the index of the character immediately AFTER the nth occurrence
  // (occurrence index + 1), 0 when `nth` is 0, or -1 when `haystack` holds
  // fewer than `nth` occurrences of `needle`.
  int strpos(const std::string &haystack, char needle, unsigned int nth)
  {
    std::size_t next = 0;
    for (unsigned int i = 0; i < nth; ++i) {
      const std::size_t hit = haystack.find(needle, next);
      if (hit == std::string::npos) return -1;
      // Resume the search just behind the match we have just consumed.
      next = hit + 1;
    }
    return static_cast<int>(next);
  }

  // Returns the text enclosed between the n1-th and the n2-th quotation mark
  // of `str` (both marks excluded), e.g. n1=1, n2=2 yields the content of the
  // first quoted attribute value.
  //
  // Returns an empty string when either quotation mark does not exist or when
  // the n2-th mark does not lie behind the n1-th one; without this guard the
  // -1 sentinel of strpos() would turn into an out-of-range offset or a
  // wrapped-around length.
  std::string get_quoted_substring(const std::string &str, int n1, int n2)
  {
    const int start = strpos(str, '"', n1);  // first character after the opening mark
    const int stop = strpos(str, '"', n2);   // first character after the closing mark
    if (start < 0 || stop <= start) return std::string();
    // `stop - 1` is the closing mark itself, so it is left out of the result.
    return str.substr(start, stop - start - 1);
  }

}  // namespace
29 
30 std::unique_ptr<const GBRForest> GBRForestTools::createGBRForest(const std::string &weightFile,
31  std::vector<std::string> &varNames){
32  edm::FileInPath weightFileEdm(weightFile);
33  return GBRForestTools::createGBRForest(weightFileEdm, varNames);
34 }
35 
36 // Creates a pointer to new GBRForest corresponding to a TMVA weights file
37 std::unique_ptr<const GBRForest> GBRForestTools::createGBRForest(const edm::FileInPath &weightFile,
38  std::vector<std::string> &varNames){
39 
41 
42  unsigned int NVar = 0;
43  unsigned int NSpec = 0;
44 
45  std::vector<float> dumbVars;
46  std::vector<float> dumbSpecs;
47 
48  varNames.clear();
49  std::vector<std::string> specNames;
50 
52  std::ifstream f;
53  std::string tmpstr;
54 
55  bool gzipped = false;
56 
57  //
58  // Set up the input buffers, for gzipped or raw xml file
59  //
60  if (reco::details::hasEnding(weightFile.fullPath(), ".xml")) {
61  f.open(weightFile.fullPath());
62  tmpstr = "";
63  } else if (reco::details::hasEnding(weightFile.fullPath(), ".gz") || reco::details::hasEnding(weightFile.fullPath(), ".gzip")) {
64  gzipped = true;
65  char *buffer = reco::details::readGzipFile(weightFile.fullPath());
66  tmpstr = std::string(buffer);
67  free(buffer);
68  }
69  std::stringstream is(tmpstr);
70 
71  bool isend;
72 
73  while(true) {
74 
75  if (gzipped) isend = !std::getline(is, line);
76  else isend = !std::getline(f, line);
77 
78  if (isend) break;
79 
80  // Terminate reading of weights file
81  if (line.find("<Weights ") != std::string::npos) break;
82 
83  // Method name
84  else if (line.find("<MethodSetup Method=") != std::string::npos) {
85  method = get_quoted_substring(line, 1, 2);
86  }
87 
88  // Number of variables
89  else if (line.find("<Variables NVar=") != std::string::npos) {
90  NVar = std::atoi(get_quoted_substring(line, 1, 2).c_str());
91  }
92 
93  // Number of spectators
94  else if (line.find("<Spectators NSpec=") != std::string::npos) {
95  NSpec = std::atoi(get_quoted_substring(line, 1, 2).c_str());
96  }
97 
98  // If variable
99  else if (line.find("<Variable ") != std::string::npos) {
100  unsigned int pos = line.find("Expression=");
101  varNames.push_back(get_quoted_substring(line.substr(pos, line.length() - pos), 1, 2));
102  dumbVars.push_back(0);
103  }
104 
105  // If spectator
106  else if (line.find("Spectator ") != std::string::npos) {
107  unsigned int pos = line.find("Expression=");
108  specNames.push_back(get_quoted_substring(line.substr(pos, line.length() - pos), 1, 2));
109  dumbSpecs.push_back(0);
110  }
111  }
112 
113  //
114  // Create the reader
115  //
116  TMVA::Reader* mvaReader = new TMVA::Reader("!Color:Silent:!Error");
117 
118  //
119  // Configure all variables and spectators. Note: the order and names
120  // must match what is found in the xml weights file!
121  //
122  for(size_t i = 0; i < NVar; ++i){
123  mvaReader->AddVariable(varNames[i], &dumbVars[i]);
124  }
125 
126  for(size_t i = 0; i < NSpec; ++i){
127  mvaReader->AddSpectator(specNames[i], &dumbSpecs[i]);
128  }
129 
130  //
131  // Book the method and set up the weights file
132  //
133 
134  reco::details::loadTMVAWeights(mvaReader, method, weightFile.fullPath());
135 
136  TMVA::MethodBDT* bdt = dynamic_cast<TMVA::MethodBDT*>( mvaReader->FindMVA(method) );
137  std::unique_ptr<const GBRForest> gbrForest = std::make_unique<const GBRForest>(GBRForest(bdt));
138  delete mvaReader;
139 
140  return gbrForest;
141 }
142 
143 std::unique_ptr<const GBRForest> GBRForestTools::createGBRForest(const std::string &weightFile){
144  std::vector<std::string> varNames;
145  return GBRForestTools::createGBRForest(weightFile, varNames);
146 }
147 
148 std::unique_ptr<const GBRForest> GBRForestTools::createGBRForest(const edm::FileInPath &weightFile){
149  std::vector<std::string> varNames;
150  return GBRForestTools::createGBRForest(weightFile, varNames);
151 }
bool hasEnding(std::string const &fullString, std::string const &ending)
static std::unique_ptr< const GBRForest > createGBRForest(const std::string &weightFile)
double f[11][100]
char * readGzipFile(const std::string &weightFile)
char const * varNames[]
TMVA::IMethod * loadTMVAWeights(TMVA::Reader *reader, const std::string &method, const std::string &weightFile, bool verbose=false)
std::string fullPath() const
Definition: FileInPath.cc:197
#define str(s)