CMS 3D CMS Logo

GBRForestTools.cc
Go to the documentation of this file.
5 
6 #include "TFile.h"
7 
8 #include <cstdio>
9 #include <cstdlib>
10 #include <RVersion.h>
11 #include <cmath>
12 #include <tinyxml2.h>
13 #include <filesystem>
14 
15 namespace {
16 
17  size_t readVariables(tinyxml2::XMLElement* root, const char* key, std::vector<std::string>& names) {
18  size_t n = 0;
19  names.clear();
20 
21  if (root != nullptr) {
22  for (tinyxml2::XMLElement* e = root->FirstChildElement(key); e != nullptr; e = e->NextSiblingElement(key)) {
23  names.push_back(e->Attribute("Expression"));
24  ++n;
25  }
26  }
27 
28  return n;
29  }
30 
31  bool isTerminal(tinyxml2::XMLElement* node) {
32  bool is = true;
33  for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
34  is = false;
35  }
36  return is;
37  }
38 
39  unsigned int countIntermediateNodes(tinyxml2::XMLElement* node) {
40  unsigned int count = 0;
41  for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
42  count += countIntermediateNodes(e);
43  }
44  return count > 0 ? count + 1 : 0;
45  }
46 
47  unsigned int countTerminalNodes(tinyxml2::XMLElement* node) {
48  unsigned int count = 0;
49  for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
50  count += countTerminalNodes(e);
51  }
52  return count > 0 ? count : 1;
53  }
54 
55  void addNode(GBRTree& tree,
56  tinyxml2::XMLElement* node,
57  double scale,
58  bool isRegression,
59  bool useYesNoLeaf,
60  bool adjustboundary,
61  bool isAdaClassifier) {
62  bool nodeIsTerminal = isTerminal(node);
63  if (nodeIsTerminal) {
64  double response = 0.;
65  if (isRegression) {
66  node->QueryDoubleAttribute("res", &response);
67  } else {
68  if (useYesNoLeaf) {
69  node->QueryDoubleAttribute("nType", &response);
70  } else {
71  if (isAdaClassifier) {
72  node->QueryDoubleAttribute("purity", &response);
73  } else {
74  node->QueryDoubleAttribute("res", &response);
75  }
76  }
77  }
78  response *= scale;
79  tree.Responses().push_back(response);
80  } else {
81  int thisidx = tree.CutIndices().size();
82 
83  int selector;
84  float cutval;
85  bool ctype;
86 
87  node->QueryIntAttribute("IVar", &selector);
88  node->QueryFloatAttribute("Cut", &cutval);
89  node->QueryBoolAttribute("cType", &ctype);
90 
91  tree.CutIndices().push_back(static_cast<unsigned char>(selector));
92 
93  //newer tmva versions use >= instead of > in decision tree splits, so adjust cut value
94  //to reproduce the correct behaviour
95  if (adjustboundary) {
96  cutval = std::nextafter(cutval, std::numeric_limits<float>::lowest());
97  }
98  tree.CutVals().push_back(cutval);
99  tree.LeftIndices().push_back(0);
100  tree.RightIndices().push_back(0);
101 
102  tinyxml2::XMLElement* left = nullptr;
103  tinyxml2::XMLElement* right = nullptr;
104  for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
105  if (*(e->Attribute("pos")) == 'l')
106  left = e;
107  else if (*(e->Attribute("pos")) == 'r')
108  right = e;
109  }
110  if (!ctype) {
111  std::swap(left, right);
112  }
113 
114  tree.LeftIndices()[thisidx] = isTerminal(left) ? -tree.Responses().size() : tree.CutIndices().size();
115  addNode(tree, left, scale, isRegression, useYesNoLeaf, adjustboundary, isAdaClassifier);
116 
117  tree.RightIndices()[thisidx] = isTerminal(right) ? -tree.Responses().size() : tree.CutIndices().size();
118  addNode(tree, right, scale, isRegression, useYesNoLeaf, adjustboundary, isAdaClassifier);
119  }
120  }
121 
122  std::unique_ptr<GBRForest> init(const std::string& weightsFileFullPath, std::vector<std::string>& varNames) {
123  //
124  // Load weights file, for ROOT file
125  //
126  if (reco::details::hasEnding(weightsFileFullPath, ".root")) {
127  TFile gbrForestFile(weightsFileFullPath.c_str());
128  std::unique_ptr<GBRForest> up(gbrForestFile.Get<GBRForest>("gbrForest"));
129  gbrForestFile.Close("nodelete");
130  return up;
131  }
132 
133  //
134  // Load weights file, for gzipped or raw xml file
135  //
136  tinyxml2::XMLDocument xmlDoc;
137 
138  using namespace reco::details;
139 
140  if (hasEnding(weightsFileFullPath, ".xml")) {
141  xmlDoc.LoadFile(weightsFileFullPath.c_str());
142  } else if (hasEnding(weightsFileFullPath, ".gz") || hasEnding(weightsFileFullPath, ".gzip")) {
143  char* buffer = readGzipFile(weightsFileFullPath);
144  xmlDoc.Parse(buffer);
145  free(buffer);
146  }
147 
148  tinyxml2::XMLElement* root = xmlDoc.FirstChildElement("MethodSetup");
149  readVariables(root->FirstChildElement("Variables"), "Variable", varNames);
150 
151  // Read in the TMVA general info
152  std::map<std::string, std::string> info;
153  tinyxml2::XMLElement* infoElem = xmlDoc.FirstChildElement("MethodSetup")->FirstChildElement("GeneralInfo");
154  if (infoElem == nullptr) {
155  throw cms::Exception("XMLError") << "No GeneralInfo found in " << weightsFileFullPath << " !!\n";
156  }
157  for (tinyxml2::XMLElement* e = infoElem->FirstChildElement("Info"); e != nullptr;
158  e = e->NextSiblingElement("Info")) {
159  const char* name;
160  const char* value;
161  e->QueryStringAttribute("name", &name);
162  e->QueryStringAttribute("value", &value);
163  info[name] = value;
164  }
165 
166  // Read in the TMVA options
167  std::map<std::string, std::string> options;
168  tinyxml2::XMLElement* optionsElem = xmlDoc.FirstChildElement("MethodSetup")->FirstChildElement("Options");
169  if (optionsElem == nullptr) {
170  throw cms::Exception("XMLError") << "No Options found in " << weightsFileFullPath << " !!\n";
171  }
172  for (tinyxml2::XMLElement* e = optionsElem->FirstChildElement("Option"); e != nullptr;
173  e = e->NextSiblingElement("Option")) {
174  const char* name;
175  e->QueryStringAttribute("name", &name);
176  options[name] = e->GetText();
177  }
178 
179  // Get root version number if available
180  int rootTrainingVersion(0);
181  if (info.find("ROOT Release") != info.end()) {
182  std::string s = info["ROOT Release"];
183  rootTrainingVersion = std::stoi(s.substr(s.find('[') + 1, s.find(']') - s.find('[') - 1));
184  }
185 
186  // Get the boosting weights
187  std::vector<double> boostWeights;
188  tinyxml2::XMLElement* weightsElem = xmlDoc.FirstChildElement("MethodSetup")->FirstChildElement("Weights");
189  if (weightsElem == nullptr) {
190  throw cms::Exception("XMLError") << "No Weights found in " << weightsFileFullPath << " !!\n";
191  }
192  bool hasTrees = false;
193  for (tinyxml2::XMLElement* e = weightsElem->FirstChildElement("BinaryTree"); e != nullptr;
194  e = e->NextSiblingElement("BinaryTree")) {
195  hasTrees = true;
196  double w;
197  e->QueryDoubleAttribute("boostWeight", &w);
198  boostWeights.push_back(w);
199  }
200  if (!hasTrees) {
201  throw cms::Exception("XMLError") << "No BinaryTrees found in " << weightsFileFullPath << " !!\n";
202  }
203 
204  bool isRegression = info["AnalysisType"] == "Regression";
205 
206  //special handling for non-gradient-boosted (ie ADABoost) classifiers, where tree responses
207  //need to be renormalized after the training for evaluation purposes
208  bool isAdaClassifier = !isRegression && options["BoostType"] != "Grad";
209  bool useYesNoLeaf = isAdaClassifier && options["UseYesNoLeaf"] == "True";
210 
211  //newer tmva versions use >= instead of > in decision tree splits, so adjust cut value
212  //to reproduce the correct behaviour
213  bool adjustBoundaries =
214  (rootTrainingVersion >= ROOT_VERSION(5, 34, 20) && rootTrainingVersion < ROOT_VERSION(6, 0, 0)) ||
215  rootTrainingVersion >= ROOT_VERSION(6, 2, 0);
216 
217  auto forest = std::make_unique<GBRForest>();
218  forest->SetInitialResponse(isRegression ? boostWeights[0] : 0.);
219 
220  double norm = 0;
221  if (isAdaClassifier) {
222  for (double w : boostWeights) {
223  norm += w;
224  }
225  }
226 
227  forest->Trees().reserve(boostWeights.size());
228  size_t itree = 0;
229  // Loop over tree estimators
230  for (tinyxml2::XMLElement* e = weightsElem->FirstChildElement("BinaryTree"); e != nullptr;
231  e = e->NextSiblingElement("BinaryTree")) {
232  double scale = isAdaClassifier ? boostWeights[itree] / norm : 1.0;
233 
234  tinyxml2::XMLElement* root = e->FirstChildElement("Node");
235  forest->Trees().push_back(GBRTree(countIntermediateNodes(root), countTerminalNodes(root)));
236  auto& tree = forest->Trees().back();
237 
238  addNode(tree, root, scale, isRegression, useYesNoLeaf, adjustBoundaries, isAdaClassifier);
239 
240  //special case, root node is terminal, create fake intermediate node at root
241  if (tree.CutIndices().empty()) {
242  tree.CutIndices().push_back(0);
243  tree.CutVals().push_back(0);
244  tree.LeftIndices().push_back(0);
245  tree.RightIndices().push_back(0);
246  }
247 
248  ++itree;
249  }
250 
251  return forest;
252  }
253 
254 } // namespace
255 
256 // Create a GBRForest from an XML weight file
257 std::unique_ptr<const GBRForest> createGBRForest(const std::string& weightsFile) {
258  std::vector<std::string> varNames;
260 }
261 
262 std::unique_ptr<const GBRForest> createGBRForest(const edm::FileInPath& weightsFile) {
263  std::vector<std::string> varNames;
264  return createGBRForest(weightsFile.fullPath(), varNames);
265 }
266 
// Overloads that also take a string vector by reference, into which the input variable names are written
268 std::unique_ptr<const GBRForest> createGBRForest(const std::string& weightsFile, std::vector<std::string>& varNames) {
269  std::unique_ptr<GBRForest> gbrForest;
270 
271  if (weightsFile[0] == '/') {
272  gbrForest = init(weightsFile, varNames);
273  } else {
274  edm::FileInPath weightsFileEdm(weightsFile);
275  gbrForest = init(weightsFileEdm.fullPath(), varNames);
276  }
277  return gbrForest;
278 }
279 
280 std::unique_ptr<const GBRForest> createGBRForest(const edm::FileInPath& weightsFile,
281  std::vector<std::string>& varNames) {
282  return createGBRForest(weightsFile.fullPath(), varNames);
283 }
init
int init
Definition: HydjetWrapper.h:64
dqmiodumpmetadata.n
n
Definition: dqmiodumpmetadata.py:28
L1EGammaCrystalsEmulatorProducer_cfi.scale
scale
Definition: L1EGammaCrystalsEmulatorProducer_cfi.py:10
reco::details::readGzipFile
char * readGzipFile(const std::string &weightFile)
Definition: TMVAZipReader.cc:19
GBRForestTools.h
tree
Definition: tree.py:1
GBRForest
Definition: GBRForest.h:24
info
static const TGPicture * info(bool iBackgroundIsBlack)
Definition: FWCollectionSummaryWidget.cc:153
pfClustersFromHGC3DClusters_cfi.weightsFile
weightsFile
Definition: pfClustersFromHGC3DClusters_cfi.py:19
edmScanValgrind.buffer
buffer
Definition: edmScanValgrind.py:171
options
Definition: options.py:1
edm::FileInPath
Definition: FileInPath.h:61
alignCSCRings.s
s
Definition: alignCSCRings.py:92
std::swap
void swap(edm::DataFrameContainer &lhs, edm::DataFrameContainer &rhs)
Definition: DataFrameContainer.h:209
names
const std::string names[nVars_]
Definition: PhotonIDValueMapProducer.cc:124
w
const double w
Definition: UKUtility.cc:23
submitPVResolutionJobs.count
count
Definition: submitPVResolutionJobs.py:352
reco::details
Definition: TMVAZipReader.h:30
FileInPath.h
value
Definition: value.py:1
root
Definition: RooFitFunction.h:10
AlCaHLTBitMon_QueryRunRegistry.string
string string
Definition: AlCaHLTBitMon_QueryRunRegistry.py:256
GBRTree
Definition: GBRTree.h:29
relativeConstraints.value
value
Definition: relativeConstraints.py:53
Exception
Definition: hltDiff.cc:245
AlcaSiPixelAliHarvester0T_cff.options
options
Definition: AlcaSiPixelAliHarvester0T_cff.py:42
Skims_PA_cff.name
name
Definition: Skims_PA_cff.py:17
TMVAZipReader.h
Exception.h
varNames
constexpr const char * varNames[]
Definition: PulseShapeFitOOTPileupCorrection.cc:110
crabWrapper.key
key
Definition: crabWrapper.py:19
up
Definition: BitonicSort.h:7
reco::details::hasEnding
bool hasEnding(std::string const &fullString, std::string const &ending)
Definition: TMVAZipReader.cc:11
edm::FileInPath::fullPath
std::string fullPath() const
Definition: FileInPath.cc:161
createGBRForest
std::unique_ptr< const GBRForest > createGBRForest(const std::string &weightsFile)
Definition: GBRForestTools.cc:257
MillePedeFileConverter_cfg.e
e
Definition: MillePedeFileConverter_cfg.py:37