CMS 3D CMS Logo

GBRForestTools.cc
Go to the documentation of this file.
5 
6 #include <cstdio>
7 #include <cstdlib>
8 #include <RVersion.h>
9 #include <cmath>
10 #include <tinyxml2.h>
11 
12 namespace {
13 
14  size_t readVariables(tinyxml2::XMLElement* root, const char* key, std::vector<std::string>& names) {
15  size_t n = 0;
16  names.clear();
17 
18  if (root != nullptr) {
19  for (tinyxml2::XMLElement* e = root->FirstChildElement(key); e != nullptr; e = e->NextSiblingElement(key)) {
20  names.push_back(e->Attribute("Expression"));
21  ++n;
22  }
23  }
24 
25  return n;
26  }
27 
28  bool isTerminal(tinyxml2::XMLElement* node) {
29  bool is = true;
30  for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
31  is = false;
32  }
33  return is;
34  }
35 
36  unsigned int countIntermediateNodes(tinyxml2::XMLElement* node) {
37  unsigned int count = 0;
38  for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
39  count += countIntermediateNodes(e);
40  }
41  return count > 0 ? count + 1 : 0;
42  }
43 
44  unsigned int countTerminalNodes(tinyxml2::XMLElement* node) {
45  unsigned int count = 0;
46  for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
47  count += countTerminalNodes(e);
48  }
49  return count > 0 ? count : 1;
50  }
51 
52  void addNode(GBRTree& tree,
53  tinyxml2::XMLElement* node,
54  double scale,
55  bool isRegression,
56  bool useYesNoLeaf,
57  bool adjustboundary,
58  bool isAdaClassifier) {
59  bool nodeIsTerminal = isTerminal(node);
60  if (nodeIsTerminal) {
61  double response = 0.;
62  if (isRegression) {
63  node->QueryDoubleAttribute("res", &response);
64  } else {
65  if (useYesNoLeaf) {
66  node->QueryDoubleAttribute("nType", &response);
67  } else {
68  if (isAdaClassifier) {
69  node->QueryDoubleAttribute("purity", &response);
70  } else {
71  node->QueryDoubleAttribute("res", &response);
72  }
73  }
74  }
75  response *= scale;
76  tree.Responses().push_back(response);
77  } else {
78  int thisidx = tree.CutIndices().size();
79 
80  int selector;
81  float cutval;
82  bool ctype;
83 
84  node->QueryIntAttribute("IVar", &selector);
85  node->QueryFloatAttribute("Cut", &cutval);
86  node->QueryBoolAttribute("cType", &ctype);
87 
88  tree.CutIndices().push_back(static_cast<unsigned char>(selector));
89 
90  //newer tmva versions use >= instead of > in decision tree splits, so adjust cut value
91  //to reproduce the correct behaviour
92  if (adjustboundary) {
93  cutval = std::nextafter(cutval, std::numeric_limits<float>::lowest());
94  }
95  tree.CutVals().push_back(cutval);
96  tree.LeftIndices().push_back(0);
97  tree.RightIndices().push_back(0);
98 
99  tinyxml2::XMLElement* left = nullptr;
100  tinyxml2::XMLElement* right = nullptr;
101  for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
102  if (*(e->Attribute("pos")) == 'l')
103  left = e;
104  else if (*(e->Attribute("pos")) == 'r')
105  right = e;
106  }
107  if (!ctype) {
108  std::swap(left, right);
109  }
110 
111  tree.LeftIndices()[thisidx] = isTerminal(left) ? -tree.Responses().size() : tree.CutIndices().size();
112  addNode(tree, left, scale, isRegression, useYesNoLeaf, adjustboundary, isAdaClassifier);
113 
114  tree.RightIndices()[thisidx] = isTerminal(right) ? -tree.Responses().size() : tree.CutIndices().size();
115  addNode(tree, right, scale, isRegression, useYesNoLeaf, adjustboundary, isAdaClassifier);
116  }
117  }
118 
119  std::unique_ptr<GBRForest> init(const std::string& weightsFileFullPath, std::vector<std::string>& varNames) {
120  //
121  // Load weights file, for gzipped or raw xml file
122  //
123  tinyxml2::XMLDocument xmlDoc;
124 
125  using namespace reco::details;
126 
127  if (hasEnding(weightsFileFullPath, ".xml")) {
128  xmlDoc.LoadFile(weightsFileFullPath.c_str());
129  } else if (hasEnding(weightsFileFullPath, ".gz") || hasEnding(weightsFileFullPath, ".gzip")) {
130  char* buffer = readGzipFile(weightsFileFullPath);
131  xmlDoc.Parse(buffer);
132  free(buffer);
133  }
134 
135  tinyxml2::XMLElement* root = xmlDoc.FirstChildElement("MethodSetup");
136  readVariables(root->FirstChildElement("Variables"), "Variable", varNames);
137 
138  // Read in the TMVA general info
139  std::map<std::string, std::string> info;
140  tinyxml2::XMLElement* infoElem = xmlDoc.FirstChildElement("MethodSetup")->FirstChildElement("GeneralInfo");
141  if (infoElem == nullptr) {
142  throw cms::Exception("XMLError") << "No GeneralInfo found in " << weightsFileFullPath << " !!\n";
143  }
144  for (tinyxml2::XMLElement* e = infoElem->FirstChildElement("Info"); e != nullptr;
145  e = e->NextSiblingElement("Info")) {
146  const char* name;
147  const char* value;
148  e->QueryStringAttribute("name", &name);
149  e->QueryStringAttribute("value", &value);
150  info[name] = value;
151  }
152 
153  // Read in the TMVA options
154  std::map<std::string, std::string> options;
155  tinyxml2::XMLElement* optionsElem = xmlDoc.FirstChildElement("MethodSetup")->FirstChildElement("Options");
156  if (optionsElem == nullptr) {
157  throw cms::Exception("XMLError") << "No Options found in " << weightsFileFullPath << " !!\n";
158  }
159  for (tinyxml2::XMLElement* e = optionsElem->FirstChildElement("Option"); e != nullptr;
160  e = e->NextSiblingElement("Option")) {
161  const char* name;
162  e->QueryStringAttribute("name", &name);
163  options[name] = e->GetText();
164  }
165 
166  // Get root version number if available
167  int rootTrainingVersion(0);
168  if (info.find("ROOT Release") != info.end()) {
169  std::string s = info["ROOT Release"];
170  rootTrainingVersion = std::stoi(s.substr(s.find("[") + 1, s.find("]") - s.find("[") - 1));
171  }
172 
173  // Get the boosting weights
174  std::vector<double> boostWeights;
175  tinyxml2::XMLElement* weightsElem = xmlDoc.FirstChildElement("MethodSetup")->FirstChildElement("Weights");
176  if (weightsElem == nullptr) {
177  throw cms::Exception("XMLError") << "No Weights found in " << weightsFileFullPath << " !!\n";
178  }
179  bool hasTrees = false;
180  for (tinyxml2::XMLElement* e = weightsElem->FirstChildElement("BinaryTree"); e != nullptr;
181  e = e->NextSiblingElement("BinaryTree")) {
182  hasTrees = true;
183  double w;
184  e->QueryDoubleAttribute("boostWeight", &w);
185  boostWeights.push_back(w);
186  }
187  if (!hasTrees) {
188  throw cms::Exception("XMLError") << "No BinaryTrees found in " << weightsFileFullPath << " !!\n";
189  }
190 
191  bool isRegression = info["AnalysisType"] == "Regression";
192 
193  //special handling for non-gradient-boosted (ie ADABoost) classifiers, where tree responses
194  //need to be renormalized after the training for evaluation purposes
195  bool isAdaClassifier = !isRegression && options["BoostType"] != "Grad";
196  bool useYesNoLeaf = isAdaClassifier && options["UseYesNoLeaf"] == "True";
197 
198  //newer tmva versions use >= instead of > in decision tree splits, so adjust cut value
199  //to reproduce the correct behaviour
200  bool adjustBoundaries =
201  (rootTrainingVersion >= ROOT_VERSION(5, 34, 20) && rootTrainingVersion < ROOT_VERSION(6, 0, 0)) ||
202  rootTrainingVersion >= ROOT_VERSION(6, 2, 0);
203 
204  auto forest = std::make_unique<GBRForest>();
205  forest->SetInitialResponse(isRegression ? boostWeights[0] : 0.);
206 
207  double norm = 0;
208  if (isAdaClassifier) {
209  for (double w : boostWeights) {
210  norm += w;
211  }
212  }
213 
214  forest->Trees().reserve(boostWeights.size());
215  size_t itree = 0;
216  // Loop over tree estimators
217  for (tinyxml2::XMLElement* e = weightsElem->FirstChildElement("BinaryTree"); e != nullptr;
218  e = e->NextSiblingElement("BinaryTree")) {
219  double scale = isAdaClassifier ? boostWeights[itree] / norm : 1.0;
220 
221  tinyxml2::XMLElement* root = e->FirstChildElement("Node");
222  forest->Trees().push_back(GBRTree(countIntermediateNodes(root), countTerminalNodes(root)));
223  auto& tree = forest->Trees().back();
224 
225  addNode(tree, root, scale, isRegression, useYesNoLeaf, adjustBoundaries, isAdaClassifier);
226 
227  //special case, root node is terminal, create fake intermediate node at root
228  if (tree.CutIndices().empty()) {
229  tree.CutIndices().push_back(0);
230  tree.CutVals().push_back(0);
231  tree.LeftIndices().push_back(0);
232  tree.RightIndices().push_back(0);
233  }
234 
235  ++itree;
236  }
237 
238  return forest;
239  }
240 
241 } // namespace
242 
243 // Create a GBRForest from an XML weight file
244 std::unique_ptr<const GBRForest> createGBRForest(const std::string& weightsFile) {
245  std::vector<std::string> varNames;
246  return createGBRForest(weightsFile, varNames);
247 }
248 
249 std::unique_ptr<const GBRForest> createGBRForest(const edm::FileInPath& weightsFile) {
250  std::vector<std::string> varNames;
251  return createGBRForest(weightsFile.fullPath(), varNames);
252 }
253 
254 // Overloaded versions which are taking string vectors by reference to store the variable names in
255 std::unique_ptr<const GBRForest> createGBRForest(const std::string& weightsFile, std::vector<std::string>& varNames) {
256  std::unique_ptr<GBRForest> gbrForest;
257 
258  if (weightsFile[0] == '/') {
259  gbrForest = init(weightsFile, varNames);
260  } else {
261  edm::FileInPath weightsFileEdm(weightsFile);
262  gbrForest = init(weightsFileEdm.fullPath(), varNames);
263  }
264  return gbrForest;
265 }
266 
267 std::unique_ptr<const GBRForest> createGBRForest(const edm::FileInPath& weightsFile,
268  std::vector<std::string>& varNames) {
269  return createGBRForest(weightsFile.fullPath(), varNames);
270 }
static const TGPicture * info(bool iBackgroundIsBlack)
std::vector< float > & Responses()
Definition: GBRTree.h:39
const double w
Definition: UKUtility.cc:23
bool hasEnding(std::string const &fullString, std::string const &ending)
int init
Definition: HydjetWrapper.h:67
const std::string names[nVars_]
std::vector< float > & CutVals()
Definition: GBRTree.h:45
void swap(edm::DataFrameContainer &lhs, edm::DataFrameContainer &rhs)
std::unique_ptr< const GBRForest > createGBRForest(const std::string &weightsFile)
std::vector< int > & LeftIndices()
Definition: GBRTree.h:48
char * readGzipFile(const std::string &weightFile)
char const * varNames[]
std::string fullPath() const
Definition: FileInPath.cc:163
Definition: tree.py:1
std::vector< int > & RightIndices()
Definition: GBRTree.h:51
std::vector< unsigned char > & CutIndices()
Definition: GBRTree.h:42