CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
GBRForestTools.cc
Go to the documentation of this file.
5 
6 #include "TFile.h"
7 
8 #include <cstdio>
9 #include <cstdlib>
10 #include <RVersion.h>
11 #include <cmath>
12 #include <tinyxml2.h>
13 #include <filesystem>
14 
15 namespace {
16 
17  size_t readVariables(tinyxml2::XMLElement* root, const char* key, std::vector<std::string>& names) {
18  size_t n = 0;
19  names.clear();
20 
21  if (root != nullptr) {
22  for (tinyxml2::XMLElement* e = root->FirstChildElement(key); e != nullptr; e = e->NextSiblingElement(key)) {
23  names.push_back(e->Attribute("Expression"));
24  ++n;
25  }
26  }
27 
28  return n;
29  }
30 
31  bool isTerminal(tinyxml2::XMLElement* node) {
32  bool is = true;
33  for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
34  is = false;
35  }
36  return is;
37  }
38 
39  unsigned int countIntermediateNodes(tinyxml2::XMLElement* node) {
40  unsigned int count = 0;
41  for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
42  count += countIntermediateNodes(e);
43  }
44  return count > 0 ? count + 1 : 0;
45  }
46 
47  unsigned int countTerminalNodes(tinyxml2::XMLElement* node) {
48  unsigned int count = 0;
49  for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
50  count += countTerminalNodes(e);
51  }
52  return count > 0 ? count : 1;
53  }
54 
55  void addNode(GBRTree& tree,
56  tinyxml2::XMLElement* node,
57  double scale,
58  bool isRegression,
59  bool useYesNoLeaf,
60  bool adjustboundary,
61  bool isAdaClassifier) {
62  bool nodeIsTerminal = isTerminal(node);
63  if (nodeIsTerminal) {
64  double response = 0.;
65  if (isRegression) {
66  node->QueryDoubleAttribute("res", &response);
67  } else {
68  if (useYesNoLeaf) {
69  node->QueryDoubleAttribute("nType", &response);
70  } else {
71  if (isAdaClassifier) {
72  node->QueryDoubleAttribute("purity", &response);
73  } else {
74  node->QueryDoubleAttribute("res", &response);
75  }
76  }
77  }
78  response *= scale;
79  tree.Responses().push_back(response);
80  } else {
81  int thisidx = tree.CutIndices().size();
82 
83  int selector = 0;
84  float cutval = 0.;
85  bool ctype = false;
86 
87  node->QueryIntAttribute("IVar", &selector);
88  node->QueryFloatAttribute("Cut", &cutval);
89  node->QueryBoolAttribute("cType", &ctype);
90 
91  tree.CutIndices().push_back(static_cast<unsigned char>(selector));
92 
93  //newer tmva versions use >= instead of > in decision tree splits, so adjust cut value
94  //to reproduce the correct behaviour
95  if (adjustboundary) {
96  cutval = std::nextafter(cutval, std::numeric_limits<float>::lowest());
97  }
98  tree.CutVals().push_back(cutval);
99  tree.LeftIndices().push_back(0);
100  tree.RightIndices().push_back(0);
101 
102  tinyxml2::XMLElement* left = nullptr;
103  tinyxml2::XMLElement* right = nullptr;
104  for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
105  if (*(e->Attribute("pos")) == 'l')
106  left = e;
107  else if (*(e->Attribute("pos")) == 'r')
108  right = e;
109  }
110  if (!ctype) {
111  std::swap(left, right);
112  }
113 
114  tree.LeftIndices()[thisidx] = isTerminal(left) ? -tree.Responses().size() : tree.CutIndices().size();
115  addNode(tree, left, scale, isRegression, useYesNoLeaf, adjustboundary, isAdaClassifier);
116 
117  tree.RightIndices()[thisidx] = isTerminal(right) ? -tree.Responses().size() : tree.CutIndices().size();
118  addNode(tree, right, scale, isRegression, useYesNoLeaf, adjustboundary, isAdaClassifier);
119  }
120  }
121 
122  std::unique_ptr<GBRForest> init(const std::string& weightsFileFullPath, std::vector<std::string>& varNames) {
123  //
124  // Load weights file, for ROOT file
125  //
126  if (reco::details::hasEnding(weightsFileFullPath, ".root")) {
127  TFile gbrForestFile(weightsFileFullPath.c_str());
128  std::unique_ptr<GBRForest> up(gbrForestFile.Get<GBRForest>("gbrForest"));
129  gbrForestFile.Close("nodelete");
130  return up;
131  }
132 
133  //
134  // Load weights file, for gzipped or raw xml file
135  //
136  tinyxml2::XMLDocument xmlDoc;
137 
138  using namespace reco::details;
139 
140  if (hasEnding(weightsFileFullPath, ".xml")) {
141  xmlDoc.LoadFile(weightsFileFullPath.c_str());
142  } else if (hasEnding(weightsFileFullPath, ".gz") || hasEnding(weightsFileFullPath, ".gzip")) {
143  char* buffer = readGzipFile(weightsFileFullPath);
144  xmlDoc.Parse(buffer);
145  free(buffer);
146  }
147 
148  tinyxml2::XMLElement* root = xmlDoc.FirstChildElement("MethodSetup");
149  readVariables(root->FirstChildElement("Variables"), "Variable", varNames);
150 
151  // Read in the TMVA general info
152  std::map<std::string, std::string> info;
153  tinyxml2::XMLElement* infoElem = xmlDoc.FirstChildElement("MethodSetup")->FirstChildElement("GeneralInfo");
154  if (infoElem == nullptr) {
155  throw cms::Exception("XMLError") << "No GeneralInfo found in " << weightsFileFullPath << " !!\n";
156  }
157  for (tinyxml2::XMLElement* e = infoElem->FirstChildElement("Info"); e != nullptr;
158  e = e->NextSiblingElement("Info")) {
159  const char* name;
160  const char* value;
161  if (tinyxml2::XML_SUCCESS != e->QueryStringAttribute("name", &name)) {
162  throw cms::Exception("XMLERROR") << "no 'name' attribute found in 'Info' element in " << weightsFileFullPath;
163  }
164  if (tinyxml2::XML_SUCCESS != e->QueryStringAttribute("value", &value)) {
165  throw cms::Exception("XMLERROR") << "no 'value' attribute found in 'Info' element in " << weightsFileFullPath;
166  }
167  info[name] = value;
168  }
169 
170  // Read in the TMVA options
171  std::map<std::string, std::string> options;
172  tinyxml2::XMLElement* optionsElem = xmlDoc.FirstChildElement("MethodSetup")->FirstChildElement("Options");
173  if (optionsElem == nullptr) {
174  throw cms::Exception("XMLError") << "No Options found in " << weightsFileFullPath << " !!\n";
175  }
176  for (tinyxml2::XMLElement* e = optionsElem->FirstChildElement("Option"); e != nullptr;
177  e = e->NextSiblingElement("Option")) {
178  const char* name;
179  if (tinyxml2::XML_SUCCESS != e->QueryStringAttribute("name", &name)) {
180  throw cms::Exception("XMLERROR") << "no 'name' attribute found in 'Option' element in " << weightsFileFullPath;
181  }
182  options[name] = e->GetText();
183  }
184 
185  // Get root version number if available
186  int rootTrainingVersion(0);
187  if (info.find("ROOT Release") != info.end()) {
188  std::string s = info["ROOT Release"];
189  rootTrainingVersion = std::stoi(s.substr(s.find('[') + 1, s.find(']') - s.find('[') - 1));
190  }
191 
192  // Get the boosting weights
193  std::vector<double> boostWeights;
194  tinyxml2::XMLElement* weightsElem = xmlDoc.FirstChildElement("MethodSetup")->FirstChildElement("Weights");
195  if (weightsElem == nullptr) {
196  throw cms::Exception("XMLError") << "No Weights found in " << weightsFileFullPath << " !!\n";
197  }
198  bool hasTrees = false;
199  for (tinyxml2::XMLElement* e = weightsElem->FirstChildElement("BinaryTree"); e != nullptr;
200  e = e->NextSiblingElement("BinaryTree")) {
201  hasTrees = true;
202  double w;
203  if (tinyxml2::XML_SUCCESS != e->QueryDoubleAttribute("boostWeight", &w)) {
204  throw cms::Exception("XMLERROR") << "problem with 'boostWeight' attribute found in 'BinaryTree' element in "
205  << weightsFileFullPath;
206  }
207  boostWeights.push_back(w);
208  }
209  if (!hasTrees) {
210  throw cms::Exception("XMLError") << "No BinaryTrees found in " << weightsFileFullPath << " !!\n";
211  }
212 
213  bool isRegression = info["AnalysisType"] == "Regression";
214 
215  //special handling for non-gradient-boosted (ie ADABoost) classifiers, where tree responses
216  //need to be renormalized after the training for evaluation purposes
217  bool isAdaClassifier = !isRegression && options["BoostType"] != "Grad";
218  bool useYesNoLeaf = isAdaClassifier && options["UseYesNoLeaf"] == "True";
219 
220  //newer tmva versions use >= instead of > in decision tree splits, so adjust cut value
221  //to reproduce the correct behaviour
222  bool adjustBoundaries =
223  (rootTrainingVersion >= ROOT_VERSION(5, 34, 20) && rootTrainingVersion < ROOT_VERSION(6, 0, 0)) ||
224  rootTrainingVersion >= ROOT_VERSION(6, 2, 0);
225 
226  auto forest = std::make_unique<GBRForest>();
227  forest->SetInitialResponse(isRegression ? boostWeights[0] : 0.);
228 
229  double norm = 0;
230  if (isAdaClassifier) {
231  for (double w : boostWeights) {
232  norm += w;
233  }
234  }
235 
236  forest->Trees().reserve(boostWeights.size());
237  size_t itree = 0;
238  // Loop over tree estimators
239  for (tinyxml2::XMLElement* e = weightsElem->FirstChildElement("BinaryTree"); e != nullptr;
240  e = e->NextSiblingElement("BinaryTree")) {
241  double scale = isAdaClassifier ? boostWeights[itree] / norm : 1.0;
242 
243  tinyxml2::XMLElement* root = e->FirstChildElement("Node");
244  forest->Trees().push_back(GBRTree(countIntermediateNodes(root), countTerminalNodes(root)));
245  auto& tree = forest->Trees().back();
246 
247  addNode(tree, root, scale, isRegression, useYesNoLeaf, adjustBoundaries, isAdaClassifier);
248 
249  //special case, root node is terminal, create fake intermediate node at root
250  if (tree.CutIndices().empty()) {
251  tree.CutIndices().push_back(0);
252  tree.CutVals().push_back(0);
253  tree.LeftIndices().push_back(0);
254  tree.RightIndices().push_back(0);
255  }
256 
257  ++itree;
258  }
259 
260  return forest;
261  }
262 
263 } // namespace
264 
265 // Create a GBRForest from an XML weight file
266 std::unique_ptr<const GBRForest> createGBRForest(const std::string& weightsFile) {
267  std::vector<std::string> varNames;
268  return createGBRForest(weightsFile, varNames);
269 }
270 
271 std::unique_ptr<const GBRForest> createGBRForest(const edm::FileInPath& weightsFile) {
272  std::vector<std::string> varNames;
273  return createGBRForest(weightsFile.fullPath(), varNames);
274 }
275 
276 // Overloaded versions which are taking string vectors by reference to store the variable names in
277 std::unique_ptr<const GBRForest> createGBRForest(const std::string& weightsFile, std::vector<std::string>& varNames) {
278  std::unique_ptr<GBRForest> gbrForest;
279 
280  if (weightsFile[0] == '/') {
281  gbrForest = init(weightsFile, varNames);
282  } else {
283  edm::FileInPath weightsFileEdm(weightsFile);
284  gbrForest = init(weightsFileEdm.fullPath(), varNames);
285  }
286  return gbrForest;
287 }
288 
289 std::unique_ptr<const GBRForest> createGBRForest(const edm::FileInPath& weightsFile,
290  std::vector<std::string>& varNames) {
291  return createGBRForest(weightsFile.fullPath(), varNames);
292 }
Definition: BitonicSort.h:7
static const TGPicture * info(bool iBackgroundIsBlack)
std::vector< float > & Responses()
Definition: GBRTree.h:36
const double w
Definition: UKUtility.cc:23
bool hasEnding(std::string const &fullString, std::string const &ending)
int init
Definition: HydjetWrapper.h:64
constexpr char const * varNames[]
const std::string names[nVars_]
std::vector< float > & CutVals()
Definition: GBRTree.h:42
void swap(edm::DataFrameContainer &lhs, edm::DataFrameContainer &rhs)
tuple key
prepare the HTCondor submission files and eventually submit them
std::vector< int > & LeftIndices()
Definition: GBRTree.h:45
char * readGzipFile(const std::string &weightFile)
std::string fullPath() const
Definition: FileInPath.cc:161
std::vector< int > & RightIndices()
Definition: GBRTree.h:48
std::unique_ptr< const GBRForest > createGBRForest(const std::string &weightsFile)
std::vector< unsigned char > & CutIndices()
Definition: GBRTree.h:39