CMS 3D CMS Logo

GBRForestTools.cc
Go to the documentation of this file.
5 
6 #include <cstdio>
7 #include <cstdlib>
8 #include <RVersion.h>
9 #include <cmath>
10 #include <tinyxml2.h>
11 
12 namespace {
13 
14  size_t readVariables(tinyxml2::XMLElement* root, const char * key, std::vector<std::string>& names)
15  {
16  size_t n = 0;
17  names.clear();
18 
19  if (root != nullptr) {
20  for(tinyxml2::XMLElement* e = root->FirstChildElement(key);
21  e != nullptr; e = e->NextSiblingElement(key))
22  {
23  names.push_back(e->Attribute("Expression"));
24  ++n;
25  }
26  }
27 
28  return n;
29  }
30 
31  bool isTerminal(tinyxml2::XMLElement* node)
32  {
33  bool is = true;
34  for(tinyxml2::XMLElement* e = node->FirstChildElement("Node");
35  e != nullptr; e = e->NextSiblingElement("Node")) {
36  is = false;
37  }
38  return is;
39  }
40 
41  unsigned int countIntermediateNodes(tinyxml2::XMLElement* node)
42  {
43 
44  unsigned int count = 0;
45  for(tinyxml2::XMLElement* e = node->FirstChildElement("Node");
46  e != nullptr; e = e->NextSiblingElement("Node")) {
47  count += countIntermediateNodes(e);
48  }
49  return count > 0 ? count + 1 : 0;
50 
51  }
52 
53  unsigned int countTerminalNodes(tinyxml2::XMLElement* node)
54  {
55 
56  unsigned int count = 0;
57  for(tinyxml2::XMLElement* e = node->FirstChildElement("Node");
58  e != nullptr; e = e->NextSiblingElement("Node")) {
59  count += countTerminalNodes(e);
60  }
61  return count > 0 ? count : 1;
62 
63  }
64 
65  void addNode(GBRTree& tree, tinyxml2::XMLElement* node,
66  double scale, bool isRegression, bool useYesNoLeaf,
67  bool adjustboundary, bool isAdaClassifier)
68  {
69 
70  bool nodeIsTerminal = isTerminal(node);
71  if (nodeIsTerminal) {
72  double response = 0.;
73  if (isRegression) {
74  node->QueryDoubleAttribute("res", &response);
75  }
76  else {
77  if (useYesNoLeaf) {
78  node->QueryDoubleAttribute("nType", &response);
79  }
80  else {
81  if (isAdaClassifier) {
82  node->QueryDoubleAttribute("purity", &response);
83  } else {
84  node->QueryDoubleAttribute("res", &response);
85  }
86  }
87  }
88  response *= scale;
89  tree.Responses().push_back(response);
90  }
91  else {
92 
93  int thisidx = tree.CutIndices().size();
94 
95  int selector;
96  float cutval;
97  bool ctype;
98 
99  node->QueryIntAttribute("IVar", &selector);
100  node->QueryFloatAttribute("Cut", &cutval);
101  node->QueryBoolAttribute("cType", &ctype);
102 
103  tree.CutIndices().push_back(static_cast<unsigned char>(selector));
104 
105  //newer tmva versions use >= instead of > in decision tree splits, so adjust cut value
106  //to reproduce the correct behaviour
107  if (adjustboundary) {
108  cutval = std::nextafter(cutval,std::numeric_limits<float>::lowest());
109  }
110  tree.CutVals().push_back(cutval);
111  tree.LeftIndices().push_back(0);
112  tree.RightIndices().push_back(0);
113 
114  tinyxml2::XMLElement* left = nullptr;
115  tinyxml2::XMLElement* right = nullptr;
116  for(tinyxml2::XMLElement* e = node->FirstChildElement("Node");
117  e != nullptr; e = e->NextSiblingElement("Node")) {
118  if (*(e->Attribute("pos")) == 'l') left = e;
119  else if (*(e->Attribute("pos")) == 'r') right = e;
120  }
121  if (!ctype) {
122  std::swap(left, right);
123  }
124 
125  tree.LeftIndices()[thisidx] = isTerminal(left) ? -tree.Responses().size() : tree.CutIndices().size() ;
126  addNode(tree, left, scale, isRegression, useYesNoLeaf, adjustboundary,isAdaClassifier);
127 
128  tree.RightIndices()[thisidx] = isTerminal(right) ? -tree.Responses().size() : tree.CutIndices().size() ;
129  addNode(tree, right, scale, isRegression, useYesNoLeaf, adjustboundary,isAdaClassifier);
130 
131  }
132 
133  }
134 
135  std::unique_ptr<GBRForest> init(const std::string& weightsFileFullPath,
136  std::vector<std::string>& varNames)
137  {
138 
139  //
140  // Load weights file, for gzipped or raw xml file
141  //
142  tinyxml2::XMLDocument xmlDoc;
143 
144  using namespace reco::details;
145 
146  if (hasEnding(weightsFileFullPath, ".xml")) {
147  xmlDoc.LoadFile(weightsFileFullPath.c_str());
148  } else if (hasEnding(weightsFileFullPath, ".gz") ||
149  hasEnding(weightsFileFullPath, ".gzip")) {
150  char * buffer = readGzipFile(weightsFileFullPath);
151  xmlDoc.Parse(buffer);
152  free(buffer);
153  }
154 
155  tinyxml2::XMLElement* root = xmlDoc.FirstChildElement("MethodSetup");
156  readVariables(root->FirstChildElement("Variables"), "Variable", varNames);
157 
158  // Read in the TMVA general info
159  std::map <std::string, std::string> info;
160  tinyxml2::XMLElement* infoElem = xmlDoc.FirstChildElement("MethodSetup")->FirstChildElement("GeneralInfo");
161  if (infoElem == nullptr) {
162  throw cms::Exception("XMLError")
163  << "No GeneralInfo found in " << weightsFileFullPath << " !!\n";
164  }
165  for(tinyxml2::XMLElement* e = infoElem->FirstChildElement("Info");
166  e != nullptr; e = e->NextSiblingElement("Info"))
167  {
168  const char * name;
169  const char * value;
170  e->QueryStringAttribute("name", &name);
171  e->QueryStringAttribute("value", &value);
172  info[name] = value;
173  }
174 
175  // Read in the TMVA options
176  std::map <std::string, std::string> options;
177  tinyxml2::XMLElement* optionsElem = xmlDoc.FirstChildElement("MethodSetup")->FirstChildElement("Options");
178  if (optionsElem == nullptr) {
179  throw cms::Exception("XMLError")
180  << "No Options found in " << weightsFileFullPath << " !!\n";
181  }
182  for(tinyxml2::XMLElement* e = optionsElem->FirstChildElement("Option");
183  e != nullptr; e = e->NextSiblingElement("Option"))
184  {
185  const char * name;
186  e->QueryStringAttribute("name", &name);
187  options[name] = e->GetText();
188  }
189 
190  // Get root version number if available
191  int rootTrainingVersion(0);
192  if (info.find("ROOT Release") != info.end()) {
193  std::string s = info["ROOT Release"];
194  rootTrainingVersion = std::stoi(s.substr(s.find("[")+1,s.find("]")-s.find("[")-1));
195  }
196 
197  // Get the boosting weights
198  std::vector<double> boostWeights;
199  tinyxml2::XMLElement* weightsElem = xmlDoc.FirstChildElement("MethodSetup")->FirstChildElement("Weights");
200  if (weightsElem == nullptr) {
201  throw cms::Exception("XMLError")
202  << "No Weights found in " << weightsFileFullPath << " !!\n";
203  }
204  bool hasTrees = false;
205  for(tinyxml2::XMLElement* e = weightsElem->FirstChildElement("BinaryTree");
206  e != nullptr; e = e->NextSiblingElement("BinaryTree"))
207  {
208  hasTrees = true;
209  double w;
210  e->QueryDoubleAttribute("boostWeight", &w);
211  boostWeights.push_back(w);
212  }
213  if (!hasTrees) {
214  throw cms::Exception("XMLError")
215  << "No BinaryTrees found in " << weightsFileFullPath << " !!\n";
216  }
217 
218  bool isRegression = info["AnalysisType"] == "Regression";
219 
220  //special handling for non-gradient-boosted (ie ADABoost) classifiers, where tree responses
221  //need to be renormalized after the training for evaluation purposes
222  bool isAdaClassifier = !isRegression && options["BoostType"] != "Grad";
223  bool useYesNoLeaf = isAdaClassifier && options["UseYesNoLeaf"] == "True";
224 
225  //newer tmva versions use >= instead of > in decision tree splits, so adjust cut value
226  //to reproduce the correct behaviour
227  bool adjustBoundaries = (rootTrainingVersion>=ROOT_VERSION(5,34,20) &&
228  rootTrainingVersion<ROOT_VERSION(6,0,0)) || rootTrainingVersion>=ROOT_VERSION(6,2,0);
229 
230  auto forest = std::make_unique<GBRForest>();
231  forest->SetInitialResponse(isRegression ? boostWeights[0] : 0.);
232 
233  double norm = 0;
234  if (isAdaClassifier) {
235  for (double w : boostWeights) {
236  norm += w;
237  }
238  }
239 
240  forest->Trees().reserve(boostWeights.size());
241  size_t itree = 0;
242  // Loop over tree estimators
243  for(tinyxml2::XMLElement* e = weightsElem->FirstChildElement("BinaryTree");
244  e != nullptr; e = e->NextSiblingElement("BinaryTree")) {
245  double scale = isAdaClassifier ? boostWeights[itree]/norm : 1.0;
246 
247  tinyxml2::XMLElement* root = e->FirstChildElement("Node");
248  forest->Trees().push_back(GBRTree(countIntermediateNodes(root), countTerminalNodes(root)));
249  auto & tree = forest->Trees().back();
250 
251  addNode(tree, root, scale, isRegression, useYesNoLeaf, adjustBoundaries, isAdaClassifier);
252 
253  //special case, root node is terminal, create fake intermediate node at root
254  if (tree.CutIndices().empty()) {
255  tree.CutIndices().push_back(0);
256  tree.CutVals().push_back(0);
257  tree.LeftIndices().push_back(0);
258  tree.RightIndices().push_back(0);
259  }
260 
261  ++itree;
262  }
263 
264  return forest;
265  }
266 
267 }
268 
269 // Create a GBRForest from an XML weight file
270 std::unique_ptr<const GBRForest>
271 createGBRForest(const std::string &weightsFile)
272 {
273  std::vector<std::string> varNames;
274  return createGBRForest(weightsFile, varNames);
275 }
276 
277 std::unique_ptr<const GBRForest>
278 createGBRForest(const edm::FileInPath &weightsFile)
279 {
280  std::vector<std::string> varNames;
281  return createGBRForest(weightsFile.fullPath(), varNames);
282 }
283 
284 // Overloaded versions which are taking string vectors by reference to store the variable names in
285 std::unique_ptr<const GBRForest>
286 createGBRForest(const std::string &weightsFile, std::vector<std::string> &varNames)
287 {
288  std::unique_ptr<GBRForest> gbrForest;
289 
290  if(weightsFile[0] == '/') {
291  gbrForest = init(weightsFile, varNames);
292  }
293  else {
294  edm::FileInPath weightsFileEdm(weightsFile);
295  gbrForest = init( weightsFileEdm.fullPath(), varNames);
296  }
297  return gbrForest;
298 }
299 
300 std::unique_ptr<const GBRForest>
301 createGBRForest(const edm::FileInPath &weightsFile, std::vector<std::string> &varNames)
302 {
303  return createGBRForest(weightsFile.fullPath(), varNames);
304 }
static const TGPicture * info(bool iBackgroundIsBlack)
std::vector< float > & Responses()
Definition: GBRTree.h:39
const double w
Definition: UKUtility.cc:23
bool hasEnding(std::string const &fullString, std::string const &ending)
int init
Definition: HydjetWrapper.h:67
const std::string names[nVars_]
std::vector< float > & CutVals()
Definition: GBRTree.h:45
void swap(edm::DataFrameContainer &lhs, edm::DataFrameContainer &rhs)
std::unique_ptr< const GBRForest > createGBRForest(const std::string &weightsFile)
std::vector< int > & LeftIndices()
Definition: GBRTree.h:48
char * readGzipFile(const std::string &weightFile)
char const * varNames[]
std::string fullPath() const
Definition: FileInPath.cc:197
Definition: tree.py:1
std::vector< int > & RightIndices()
Definition: GBRTree.h:51
std::vector< unsigned char > & CutIndices()
Definition: GBRTree.h:42