CMS 3D CMS Logo

GBRForestTools.cc
Go to the documentation of this file.
5 
6 #include "TFile.h"
7 
8 #include <cstdio>
9 #include <cstdlib>
10 #include <RVersion.h>
11 #include <cmath>
12 #include <tinyxml2.h>
13 #include <filesystem>
14 
15 namespace {
16 
17  size_t readVariables(tinyxml2::XMLElement* root, const char* key, std::vector<std::string>& names) {
18  size_t n = 0;
19  names.clear();
20 
21  if (root != nullptr) {
22  for (tinyxml2::XMLElement* e = root->FirstChildElement(key); e != nullptr; e = e->NextSiblingElement(key)) {
23  names.push_back(e->Attribute("Expression"));
24  ++n;
25  }
26  }
27 
28  return n;
29  }
30 
31  bool isTerminal(tinyxml2::XMLElement* node) {
32  bool is = true;
33  for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
34  is = false;
35  }
36  return is;
37  }
38 
39  unsigned int countIntermediateNodes(tinyxml2::XMLElement* node) {
40  unsigned int count = 0;
41  for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
42  count += countIntermediateNodes(e);
43  }
44  return count > 0 ? count + 1 : 0;
45  }
46 
47  unsigned int countTerminalNodes(tinyxml2::XMLElement* node) {
48  unsigned int count = 0;
49  for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
50  count += countTerminalNodes(e);
51  }
52  return count > 0 ? count : 1;
53  }
54 
55  void addNode(GBRTree& tree,
56  tinyxml2::XMLElement* node,
57  double scale,
58  bool isRegression,
59  bool useYesNoLeaf,
60  bool adjustboundary,
61  bool isAdaClassifier) {
62  bool nodeIsTerminal = isTerminal(node);
63  if (nodeIsTerminal) {
64  double response = 0.;
65  if (isRegression) {
66  node->QueryDoubleAttribute("res", &response);
67  } else {
68  if (useYesNoLeaf) {
69  node->QueryDoubleAttribute("nType", &response);
70  } else {
71  if (isAdaClassifier) {
72  node->QueryDoubleAttribute("purity", &response);
73  } else {
74  node->QueryDoubleAttribute("res", &response);
75  }
76  }
77  }
78  response *= scale;
79  tree.Responses().push_back(response);
80  } else {
81  int thisidx = tree.CutIndices().size();
82 
83  int selector = 0;
84  float cutval = 0.;
85  bool ctype = false;
86 
87  node->QueryIntAttribute("IVar", &selector);
88  node->QueryFloatAttribute("Cut", &cutval);
89  node->QueryBoolAttribute("cType", &ctype);
90 
91  tree.CutIndices().push_back(static_cast<unsigned char>(selector));
92 
93  //newer tmva versions use >= instead of > in decision tree splits, so adjust cut value
94  //to reproduce the correct behaviour
95  if (adjustboundary) {
96  cutval = std::nextafter(cutval, std::numeric_limits<float>::lowest());
97  }
98  tree.CutVals().push_back(cutval);
99  tree.LeftIndices().push_back(0);
100  tree.RightIndices().push_back(0);
101 
102  tinyxml2::XMLElement* left = nullptr;
103  tinyxml2::XMLElement* right = nullptr;
104  for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
105  if (*(e->Attribute("pos")) == 'l')
106  left = e;
107  else if (*(e->Attribute("pos")) == 'r')
108  right = e;
109  }
110  if (!ctype) {
111  std::swap(left, right);
112  }
113 
114  tree.LeftIndices()[thisidx] = isTerminal(left) ? -tree.Responses().size() : tree.CutIndices().size();
115  addNode(tree, left, scale, isRegression, useYesNoLeaf, adjustboundary, isAdaClassifier);
116 
117  tree.RightIndices()[thisidx] = isTerminal(right) ? -tree.Responses().size() : tree.CutIndices().size();
118  addNode(tree, right, scale, isRegression, useYesNoLeaf, adjustboundary, isAdaClassifier);
119  }
120  }
121 
122  std::unique_ptr<GBRForest> init(const std::string& weightsFileFullPath, std::vector<std::string>& varNames) {
123  //
124  // Load weights file, for ROOT file
125  //
126  if (reco::details::hasEnding(weightsFileFullPath, ".root")) {
127  TFile gbrForestFile(weightsFileFullPath.c_str());
128  std::unique_ptr<GBRForest> up(gbrForestFile.Get<GBRForest>("gbrForest"));
129  std::unique_ptr<std::vector<std::string>> vars(gbrForestFile.Get<std::vector<std::string>>("variableNames"));
130  gbrForestFile.Close("nodelete");
131  if (vars) {
132  varNames = std::move(*vars);
133  }
134  return up;
135  }
136 
137  //
138  // Load weights file, for gzipped or raw xml file
139  //
140  tinyxml2::XMLDocument xmlDoc;
141 
142  using namespace reco::details;
143 
144  if (hasEnding(weightsFileFullPath, ".xml")) {
145  xmlDoc.LoadFile(weightsFileFullPath.c_str());
146  } else if (hasEnding(weightsFileFullPath, ".gz") || hasEnding(weightsFileFullPath, ".gzip")) {
147  char* buffer = readGzipFile(weightsFileFullPath);
148  xmlDoc.Parse(buffer);
149  free(buffer);
150  }
151 
152  tinyxml2::XMLElement* root = xmlDoc.FirstChildElement("MethodSetup");
153  readVariables(root->FirstChildElement("Variables"), "Variable", varNames);
154 
155  // Read in the TMVA general info
156  std::map<std::string, std::string> info;
157  tinyxml2::XMLElement* infoElem = xmlDoc.FirstChildElement("MethodSetup")->FirstChildElement("GeneralInfo");
158  if (infoElem == nullptr) {
159  throw cms::Exception("XMLError") << "No GeneralInfo found in " << weightsFileFullPath << " !!\n";
160  }
161  for (tinyxml2::XMLElement* e = infoElem->FirstChildElement("Info"); e != nullptr;
162  e = e->NextSiblingElement("Info")) {
163  const char* name;
164  const char* value;
165  if (tinyxml2::XML_SUCCESS != e->QueryStringAttribute("name", &name)) {
166  throw cms::Exception("XMLERROR") << "no 'name' attribute found in 'Info' element in " << weightsFileFullPath;
167  }
168  if (tinyxml2::XML_SUCCESS != e->QueryStringAttribute("value", &value)) {
169  throw cms::Exception("XMLERROR") << "no 'value' attribute found in 'Info' element in " << weightsFileFullPath;
170  }
171  info[name] = value;
172  }
173 
174  // Read in the TMVA options
175  std::map<std::string, std::string> options;
176  tinyxml2::XMLElement* optionsElem = xmlDoc.FirstChildElement("MethodSetup")->FirstChildElement("Options");
177  if (optionsElem == nullptr) {
178  throw cms::Exception("XMLError") << "No Options found in " << weightsFileFullPath << " !!\n";
179  }
180  for (tinyxml2::XMLElement* e = optionsElem->FirstChildElement("Option"); e != nullptr;
181  e = e->NextSiblingElement("Option")) {
182  const char* name;
183  if (tinyxml2::XML_SUCCESS != e->QueryStringAttribute("name", &name)) {
184  throw cms::Exception("XMLERROR") << "no 'name' attribute found in 'Option' element in " << weightsFileFullPath;
185  }
186  options[name] = e->GetText();
187  }
188 
189  // Get root version number if available
190  int rootTrainingVersion(0);
191  if (info.find("ROOT Release") != info.end()) {
192  std::string s = info["ROOT Release"];
193  rootTrainingVersion = std::stoi(s.substr(s.find('[') + 1, s.find(']') - s.find('[') - 1));
194  }
195 
196  // Get the boosting weights
197  std::vector<double> boostWeights;
198  tinyxml2::XMLElement* weightsElem = xmlDoc.FirstChildElement("MethodSetup")->FirstChildElement("Weights");
199  if (weightsElem == nullptr) {
200  throw cms::Exception("XMLError") << "No Weights found in " << weightsFileFullPath << " !!\n";
201  }
202  bool hasTrees = false;
203  for (tinyxml2::XMLElement* e = weightsElem->FirstChildElement("BinaryTree"); e != nullptr;
204  e = e->NextSiblingElement("BinaryTree")) {
205  hasTrees = true;
206  double w;
207  if (tinyxml2::XML_SUCCESS != e->QueryDoubleAttribute("boostWeight", &w)) {
208  throw cms::Exception("XMLERROR") << "problem with 'boostWeight' attribute found in 'BinaryTree' element in "
209  << weightsFileFullPath;
210  }
211  boostWeights.push_back(w);
212  }
213  if (!hasTrees) {
214  throw cms::Exception("XMLError") << "No BinaryTrees found in " << weightsFileFullPath << " !!\n";
215  }
216 
217  bool isRegression = info["AnalysisType"] == "Regression";
218 
219  //special handling for non-gradient-boosted (ie ADABoost) classifiers, where tree responses
220  //need to be renormalized after the training for evaluation purposes
221  bool isAdaClassifier = !isRegression && options["BoostType"] != "Grad";
222  bool useYesNoLeaf = isAdaClassifier && options["UseYesNoLeaf"] == "True";
223 
224  //newer tmva versions use >= instead of > in decision tree splits, so adjust cut value
225  //to reproduce the correct behaviour
226  bool adjustBoundaries =
227  (rootTrainingVersion >= ROOT_VERSION(5, 34, 20) && rootTrainingVersion < ROOT_VERSION(6, 0, 0)) ||
228  rootTrainingVersion >= ROOT_VERSION(6, 2, 0);
229 
230  auto forest = std::make_unique<GBRForest>();
231  forest->SetInitialResponse(isRegression ? boostWeights[0] : 0.);
232 
233  double norm = 0;
234  if (isAdaClassifier) {
235  for (double w : boostWeights) {
236  norm += w;
237  }
238  }
239 
240  forest->Trees().reserve(boostWeights.size());
241  size_t itree = 0;
242  // Loop over tree estimators
243  for (tinyxml2::XMLElement* e = weightsElem->FirstChildElement("BinaryTree"); e != nullptr;
244  e = e->NextSiblingElement("BinaryTree")) {
245  double scale = isAdaClassifier ? boostWeights[itree] / norm : 1.0;
246 
247  tinyxml2::XMLElement* root = e->FirstChildElement("Node");
248  forest->Trees().push_back(GBRTree(countIntermediateNodes(root), countTerminalNodes(root)));
249  auto& tree = forest->Trees().back();
250 
251  addNode(tree, root, scale, isRegression, useYesNoLeaf, adjustBoundaries, isAdaClassifier);
252 
253  //special case, root node is terminal, create fake intermediate node at root
254  if (tree.CutIndices().empty()) {
255  tree.CutIndices().push_back(0);
256  tree.CutVals().push_back(0);
257  tree.LeftIndices().push_back(0);
258  tree.RightIndices().push_back(0);
259  }
260 
261  ++itree;
262  }
263 
264  return forest;
265  }
266 
267 } // namespace
268 
269 // Create a GBRForest from a TMVA weights file (.root, .xml or gzipped XML)
270 std::unique_ptr<const GBRForest> createGBRForest(const std::string& weightsFile) {
271  std::vector<std::string> varNames;
273 }
274 
275 std::unique_ptr<const GBRForest> createGBRForest(const edm::FileInPath& weightsFile) {
276  std::vector<std::string> varNames;
277  return createGBRForest(weightsFile.fullPath(), varNames);
278 }
279 
280 // Overloaded versions that take a string vector by reference and store the variable names in it
281 std::unique_ptr<const GBRForest> createGBRForest(const std::string& weightsFile, std::vector<std::string>& varNames) {
282  std::unique_ptr<GBRForest> gbrForest;
283 
284  if (weightsFile[0] == '/') {
285  gbrForest = init(weightsFile, varNames);
286  } else {
287  edm::FileInPath weightsFileEdm(weightsFile);
288  gbrForest = init(weightsFileEdm.fullPath(), varNames);
289  }
290  return gbrForest;
291 }
292 
293 std::unique_ptr<const GBRForest> createGBRForest(const edm::FileInPath& weightsFile,
294  std::vector<std::string>& varNames) {
295  return createGBRForest(weightsFile.fullPath(), varNames);
296 }
Definition: BitonicSort.h:7
static const TGPicture * info(bool iBackgroundIsBlack)
bool hasEnding(std::string const &fullString, std::string const &ending)
std::string fullPath() const
Definition: FileInPath.cc:161
T w() const
int init
Definition: HydjetWrapper.h:64
constexpr char const * varNames[]
const std::string names[nVars_]
void swap(edm::DataFrameContainer &lhs, edm::DataFrameContainer &rhs)
std::unique_ptr< const GBRForest > createGBRForest(const std::string &weightsFile)
Definition: value.py:1
char * readGzipFile(const std::string &weightFile)
Definition: tree.py:1
vars
Definition: DeepTauId.cc:30
def move(src, dest)
Definition: eostools.py:511