CMS 3D CMS Logo

GBRForestTools.cc
Go to the documentation of this file.
5 
6 #include "TFile.h"
7 
8 #include <cstdio>
9 #include <cstdlib>
10 #include <RVersion.h>
11 #include <cmath>
12 #include <tinyxml2.h>
13 
14 namespace {
15 
16  size_t readVariables(tinyxml2::XMLElement* root, const char* key, std::vector<std::string>& names) {
17  size_t n = 0;
18  names.clear();
19 
20  if (root != nullptr) {
21  for (tinyxml2::XMLElement* e = root->FirstChildElement(key); e != nullptr; e = e->NextSiblingElement(key)) {
22  names.push_back(e->Attribute("Expression"));
23  ++n;
24  }
25  }
26 
27  return n;
28  }
29 
30  bool isTerminal(tinyxml2::XMLElement* node) {
31  bool is = true;
32  for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
33  is = false;
34  }
35  return is;
36  }
37 
38  unsigned int countIntermediateNodes(tinyxml2::XMLElement* node) {
39  unsigned int count = 0;
40  for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
41  count += countIntermediateNodes(e);
42  }
43  return count > 0 ? count + 1 : 0;
44  }
45 
46  unsigned int countTerminalNodes(tinyxml2::XMLElement* node) {
47  unsigned int count = 0;
48  for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
49  count += countTerminalNodes(e);
50  }
51  return count > 0 ? count : 1;
52  }
53 
54  void addNode(GBRTree& tree,
55  tinyxml2::XMLElement* node,
56  double scale,
57  bool isRegression,
58  bool useYesNoLeaf,
59  bool adjustboundary,
60  bool isAdaClassifier) {
61  bool nodeIsTerminal = isTerminal(node);
62  if (nodeIsTerminal) {
63  double response = 0.;
64  if (isRegression) {
65  node->QueryDoubleAttribute("res", &response);
66  } else {
67  if (useYesNoLeaf) {
68  node->QueryDoubleAttribute("nType", &response);
69  } else {
70  if (isAdaClassifier) {
71  node->QueryDoubleAttribute("purity", &response);
72  } else {
73  node->QueryDoubleAttribute("res", &response);
74  }
75  }
76  }
77  response *= scale;
78  tree.Responses().push_back(response);
79  } else {
80  int thisidx = tree.CutIndices().size();
81 
82  int selector;
83  float cutval;
84  bool ctype;
85 
86  node->QueryIntAttribute("IVar", &selector);
87  node->QueryFloatAttribute("Cut", &cutval);
88  node->QueryBoolAttribute("cType", &ctype);
89 
90  tree.CutIndices().push_back(static_cast<unsigned char>(selector));
91 
92  //newer tmva versions use >= instead of > in decision tree splits, so adjust cut value
93  //to reproduce the correct behaviour
94  if (adjustboundary) {
95  cutval = std::nextafter(cutval, std::numeric_limits<float>::lowest());
96  }
97  tree.CutVals().push_back(cutval);
98  tree.LeftIndices().push_back(0);
99  tree.RightIndices().push_back(0);
100 
101  tinyxml2::XMLElement* left = nullptr;
102  tinyxml2::XMLElement* right = nullptr;
103  for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
104  if (*(e->Attribute("pos")) == 'l')
105  left = e;
106  else if (*(e->Attribute("pos")) == 'r')
107  right = e;
108  }
109  if (!ctype) {
110  std::swap(left, right);
111  }
112 
113  tree.LeftIndices()[thisidx] = isTerminal(left) ? -tree.Responses().size() : tree.CutIndices().size();
114  addNode(tree, left, scale, isRegression, useYesNoLeaf, adjustboundary, isAdaClassifier);
115 
116  tree.RightIndices()[thisidx] = isTerminal(right) ? -tree.Responses().size() : tree.CutIndices().size();
117  addNode(tree, right, scale, isRegression, useYesNoLeaf, adjustboundary, isAdaClassifier);
118  }
119  }
120 
121  std::unique_ptr<GBRForest> init(const std::string& weightsFileFullPath, std::vector<std::string>& varNames) {
122  //
123  // Load weights file, for ROOT file
124  //
125  if (reco::details::hasEnding(weightsFileFullPath, ".root")) {
126  TFile gbrForestFile(weightsFileFullPath.c_str());
127  std::unique_ptr<GBRForest> up(reinterpret_cast<GBRForest*>(gbrForestFile.Get("gbrForest")));
128  gbrForestFile.Close("nodelete");
129  return up;
130  }
131 
132  //
133  // Load weights file, for gzipped or raw xml file
134  //
135  tinyxml2::XMLDocument xmlDoc;
136 
137  using namespace reco::details;
138 
139  if (hasEnding(weightsFileFullPath, ".xml")) {
140  xmlDoc.LoadFile(weightsFileFullPath.c_str());
141  } else if (hasEnding(weightsFileFullPath, ".gz") || hasEnding(weightsFileFullPath, ".gzip")) {
142  char* buffer = readGzipFile(weightsFileFullPath);
143  xmlDoc.Parse(buffer);
144  free(buffer);
145  }
146 
147  tinyxml2::XMLElement* root = xmlDoc.FirstChildElement("MethodSetup");
148  readVariables(root->FirstChildElement("Variables"), "Variable", varNames);
149 
150  // Read in the TMVA general info
151  std::map<std::string, std::string> info;
152  tinyxml2::XMLElement* infoElem = xmlDoc.FirstChildElement("MethodSetup")->FirstChildElement("GeneralInfo");
153  if (infoElem == nullptr) {
154  throw cms::Exception("XMLError") << "No GeneralInfo found in " << weightsFileFullPath << " !!\n";
155  }
156  for (tinyxml2::XMLElement* e = infoElem->FirstChildElement("Info"); e != nullptr;
157  e = e->NextSiblingElement("Info")) {
158  const char* name;
159  const char* value;
160  e->QueryStringAttribute("name", &name);
161  e->QueryStringAttribute("value", &value);
162  info[name] = value;
163  }
164 
165  // Read in the TMVA options
166  std::map<std::string, std::string> options;
167  tinyxml2::XMLElement* optionsElem = xmlDoc.FirstChildElement("MethodSetup")->FirstChildElement("Options");
168  if (optionsElem == nullptr) {
169  throw cms::Exception("XMLError") << "No Options found in " << weightsFileFullPath << " !!\n";
170  }
171  for (tinyxml2::XMLElement* e = optionsElem->FirstChildElement("Option"); e != nullptr;
172  e = e->NextSiblingElement("Option")) {
173  const char* name;
174  e->QueryStringAttribute("name", &name);
175  options[name] = e->GetText();
176  }
177 
178  // Get root version number if available
179  int rootTrainingVersion(0);
180  if (info.find("ROOT Release") != info.end()) {
181  std::string s = info["ROOT Release"];
182  rootTrainingVersion = std::stoi(s.substr(s.find("[") + 1, s.find("]") - s.find("[") - 1));
183  }
184 
185  // Get the boosting weights
186  std::vector<double> boostWeights;
187  tinyxml2::XMLElement* weightsElem = xmlDoc.FirstChildElement("MethodSetup")->FirstChildElement("Weights");
188  if (weightsElem == nullptr) {
189  throw cms::Exception("XMLError") << "No Weights found in " << weightsFileFullPath << " !!\n";
190  }
191  bool hasTrees = false;
192  for (tinyxml2::XMLElement* e = weightsElem->FirstChildElement("BinaryTree"); e != nullptr;
193  e = e->NextSiblingElement("BinaryTree")) {
194  hasTrees = true;
195  double w;
196  e->QueryDoubleAttribute("boostWeight", &w);
197  boostWeights.push_back(w);
198  }
199  if (!hasTrees) {
200  throw cms::Exception("XMLError") << "No BinaryTrees found in " << weightsFileFullPath << " !!\n";
201  }
202 
203  bool isRegression = info["AnalysisType"] == "Regression";
204 
205  //special handling for non-gradient-boosted (ie ADABoost) classifiers, where tree responses
206  //need to be renormalized after the training for evaluation purposes
207  bool isAdaClassifier = !isRegression && options["BoostType"] != "Grad";
208  bool useYesNoLeaf = isAdaClassifier && options["UseYesNoLeaf"] == "True";
209 
210  //newer tmva versions use >= instead of > in decision tree splits, so adjust cut value
211  //to reproduce the correct behaviour
212  bool adjustBoundaries =
213  (rootTrainingVersion >= ROOT_VERSION(5, 34, 20) && rootTrainingVersion < ROOT_VERSION(6, 0, 0)) ||
214  rootTrainingVersion >= ROOT_VERSION(6, 2, 0);
215 
216  auto forest = std::make_unique<GBRForest>();
217  forest->SetInitialResponse(isRegression ? boostWeights[0] : 0.);
218 
219  double norm = 0;
220  if (isAdaClassifier) {
221  for (double w : boostWeights) {
222  norm += w;
223  }
224  }
225 
226  forest->Trees().reserve(boostWeights.size());
227  size_t itree = 0;
228  // Loop over tree estimators
229  for (tinyxml2::XMLElement* e = weightsElem->FirstChildElement("BinaryTree"); e != nullptr;
230  e = e->NextSiblingElement("BinaryTree")) {
231  double scale = isAdaClassifier ? boostWeights[itree] / norm : 1.0;
232 
233  tinyxml2::XMLElement* root = e->FirstChildElement("Node");
234  forest->Trees().push_back(GBRTree(countIntermediateNodes(root), countTerminalNodes(root)));
235  auto& tree = forest->Trees().back();
236 
237  addNode(tree, root, scale, isRegression, useYesNoLeaf, adjustBoundaries, isAdaClassifier);
238 
239  //special case, root node is terminal, create fake intermediate node at root
240  if (tree.CutIndices().empty()) {
241  tree.CutIndices().push_back(0);
242  tree.CutVals().push_back(0);
243  tree.LeftIndices().push_back(0);
244  tree.RightIndices().push_back(0);
245  }
246 
247  ++itree;
248  }
249 
250  return forest;
251  }
252 
253 } // namespace
254 
255 // Create a GBRForest from an XML weight file
256 std::unique_ptr<const GBRForest> createGBRForest(const std::string& weightsFile) {
257  std::vector<std::string> varNames;
258  return createGBRForest(weightsFile, varNames);
259 }
260 
261 std::unique_ptr<const GBRForest> createGBRForest(const edm::FileInPath& weightsFile) {
262  std::vector<std::string> varNames;
263  return createGBRForest(weightsFile.fullPath(), varNames);
264 }
265 
266 // Overloaded versions which are taking string vectors by reference to store the variable names in
267 std::unique_ptr<const GBRForest> createGBRForest(const std::string& weightsFile, std::vector<std::string>& varNames) {
268  std::unique_ptr<GBRForest> gbrForest;
269 
270  if (weightsFile[0] == '/') {
271  gbrForest = init(weightsFile, varNames);
272  } else {
273  edm::FileInPath weightsFileEdm(weightsFile);
274  gbrForest = init(weightsFileEdm.fullPath(), varNames);
275  }
276  return gbrForest;
277 }
278 
279 std::unique_ptr<const GBRForest> createGBRForest(const edm::FileInPath& weightsFile,
280  std::vector<std::string>& varNames) {
281  return createGBRForest(weightsFile.fullPath(), varNames);
282 }
Definition: BitonicSort.h:8
static const TGPicture * info(bool iBackgroundIsBlack)
std::vector< float > & Responses()
Definition: GBRTree.h:39
const double w
Definition: UKUtility.cc:23
bool hasEnding(std::string const &fullString, std::string const &ending)
int init
Definition: HydjetWrapper.h:67
const std::string names[nVars_]
std::vector< float > & CutVals()
Definition: GBRTree.h:45
void swap(edm::DataFrameContainer &lhs, edm::DataFrameContainer &rhs)
std::unique_ptr< const GBRForest > createGBRForest(const std::string &weightsFile)
std::vector< int > & LeftIndices()
Definition: GBRTree.h:48
char * readGzipFile(const std::string &weightFile)
char const * varNames[]
std::string fullPath() const
Definition: FileInPath.cc:163
Definition: tree.py:1
std::vector< int > & RightIndices()
Definition: GBRTree.h:51
std::vector< unsigned char > & CutIndices()
Definition: GBRTree.h:42