// Appends the "Expression" attribute of every child element of `root` whose tag
// is `key` to `names`, in document order. A null `root` is tolerated: the loop
// only runs inside the `root != nullptr` check.
// NOTE(review): the return statement is not visible in this excerpt; the size_t
// result is presumably the number of names read — confirm.
// NOTE(review): whether `names` is cleared before appending is not visible here.
14 size_t readVariables(tinyxml2::XMLElement*
root,
const char *
key, std::vector<std::string>&
names)
19 if (root !=
nullptr) {
20 for(tinyxml2::XMLElement*
e = root->FirstChildElement(key);
21 e !=
nullptr;
e =
e->NextSiblingElement(key))
// NOTE(review): XMLElement::Attribute() returns nullptr when the attribute is
// missing, and constructing std::string from a null char* is undefined
// behaviour. Consider checking the pointer before push_back.
23 names.push_back(
e->Attribute(
"Expression"));
// Iterates the child elements of `node` tagged "Node"; addNode() uses the result
// to decide whether `node` is a terminal (leaf) node.
// NOTE(review): the loop body and the return statement are not visible in this
// excerpt; presumably the result is true iff no child "Node" exists — confirm.
31 bool isTerminal(tinyxml2::XMLElement* node)
34 for(tinyxml2::XMLElement*
e = node->FirstChildElement(
"Node");
35 e !=
nullptr;
e =
e->NextSiblingElement(
"Node")) {
41 unsigned int countIntermediateNodes(tinyxml2::XMLElement* node)
44 unsigned int count = 0;
45 for(tinyxml2::XMLElement*
e = node->FirstChildElement(
"Node");
46 e !=
nullptr;
e =
e->NextSiblingElement(
"Node")) {
47 count += countIntermediateNodes(
e);
49 return count > 0 ? count + 1 : 0;
53 unsigned int countTerminalNodes(tinyxml2::XMLElement* node)
56 unsigned int count = 0;
57 for(tinyxml2::XMLElement*
e = node->FirstChildElement(
"Node");
58 e !=
nullptr;
e =
e->NextSiblingElement(
"Node")) {
59 count += countTerminalNodes(
e);
61 return count > 0 ? count : 1;
// Recursively appends the TMVA decision-tree element `node` (and its subtree)
// to `tree`.
//  - Terminal nodes: a response is read into `response` from the "res",
//    "nType" or "purity" attribute. Which one depends on the flags; the full
//    branch structure is only partly visible in this excerpt.
//  - Intermediate nodes: the cut variable index ("IVar"), threshold ("Cut") and
//    cut type ("cType") are read. The index and threshold are appended to
//    tree.CutIndices() / tree.CutVals(). The children tagged pos="l" and
//    pos="r" are then added by recursive calls, left before right.
// Parameters:
//  tree            tree being filled. Cut data are appended with push_back, so
//                  the recursion order determines the storage layout.
//  node            current <Node> element.
//  scale           forwarded unchanged to the recursive calls. Where it is
//                  applied is not visible here (presumably terminal responses).
//  isRegression, useYesNoLeaf, isAdaClassifier
//                  select which attribute holds a terminal node's response.
//  adjustboundary  if true, each cut value is replaced by the next smaller
//                  representable float (std::nextafter toward lowest()).
65 void addNode(
GBRTree&
tree, tinyxml2::XMLElement* node,
66 double scale,
bool isRegression,
bool useYesNoLeaf,
67 bool adjustboundary,
bool isAdaClassifier)
70 bool nodeIsTerminal = isTerminal(node);
// Terminal node: read the leaf response. The conditions choosing between the
// reads below are only partly visible in this excerpt. Presumably "res" is used
// for regression and "nType" when useYesNoLeaf is set; otherwise "purity" is
// used for AdaBoost and "res" for gradient boosting — confirm against the
// full source.
74 node->QueryDoubleAttribute(
"res", &response);
78 node->QueryDoubleAttribute(
"nType", &response);
81 if (isAdaClassifier) {
82 node->QueryDoubleAttribute(
"purity", &response);
84 node->QueryDoubleAttribute(
"res", &response);
// Intermediate node: read the cut definition. tinyxml2's Query*Attribute
// leaves the target untouched when the attribute is absent, so the initial
// values of selector/cutval/ctype (declared on lines not visible here) act as
// defaults.
99 node->QueryIntAttribute(
"IVar", &selector);
100 node->QueryFloatAttribute(
"Cut", &cutval);
101 node->QueryBoolAttribute(
"cType", &ctype);
// NOTE(review): static_cast<unsigned char> silently truncates variable indices
// above 255 (and wraps negative ones); consider validating `selector` first.
103 tree.
CutIndices().push_back(static_cast<unsigned char>(selector));
// Lower the threshold to the next representable float. init() enables this only
// for files trained with particular ROOT releases. Presumably it compensates
// for a change in how those releases compare values against the cut — confirm
// against TMVA.
107 if (adjustboundary) {
108 cutval = std::nextafter(cutval,std::numeric_limits<float>::lowest());
110 tree.
CutVals().push_back(cutval);
// Locate the two children by their "pos" attribute ('l' = left, 'r' = right).
114 tinyxml2::XMLElement* left =
nullptr;
115 tinyxml2::XMLElement* right =
nullptr;
116 for(tinyxml2::XMLElement*
e = node->FirstChildElement(
"Node");
117 e !=
nullptr;
e =
e->NextSiblingElement(
"Node")) {
// NOTE(review): Attribute("pos") returns nullptr when the attribute is missing,
// so these dereferences are undefined behaviour on malformed input.
118 if (*(
e->Attribute(
"pos")) ==
'l') left =
e;
119 else if (*(
e->Attribute(
"pos")) ==
'r') right =
e;
// Emit the left subtree first, then the right one.
// NOTE(review): whether left/right are checked for nullptr before these calls
// is not visible in this excerpt; a null child would reach isTerminal(nullptr).
126 addNode(tree, left, scale, isRegression, useYesNoLeaf, adjustboundary,isAdaClassifier);
129 addNode(tree, right, scale, isRegression, useYesNoLeaf, adjustboundary,isAdaClassifier);
// Builds a GBRForest from a TMVA BDT weights file. Visible steps:
//  1. Load the XML: ".xml" files via LoadFile; ".gz"/".gzip" files by parsing
//     an in-memory buffer obtained on lines not visible here.
//  2. Read variable names (<MethodSetup>/<Variables>/<Variable Expression=...>)
//     into varNames.
//  3. Collect <GeneralInfo>/<Info name=... value=...> into `info`, and
//     <Options>/<Option name=...>text</Option> into `options`.
//  4. Derive the training ROOT version and the regression / AdaBoost /
//     YesNoLeaf / boundary-adjustment flags from them.
//  5. Read each <BinaryTree>'s boostWeight and convert every tree with addNode.
// NOTE(review): the parameter list is truncated in this excerpt. `varNames` is
// presumably the second parameter (cf. the call init(weightsFile, varNames)).
std::unique_ptr<GBRForest>
init(
const std::string& weightsFileFullPath,
142 tinyxml2::XMLDocument xmlDoc;
146 if (
hasEnding(weightsFileFullPath,
".xml")) {
147 xmlDoc.LoadFile(weightsFileFullPath.c_str());
// NOTE(review): the XMLError returned by LoadFile (and by Parse below) is
// ignored. A missing or malformed file only surfaces later through null
// elements.
148 }
else if (
hasEnding(weightsFileFullPath,
".gz") ||
149 hasEnding(weightsFileFullPath,
".gzip")) {
// NOTE(review): `buffer` is produced on lines not visible here (presumably by
// readGzipFile(), which returns a char*); make sure it is released after Parse.
151 xmlDoc.Parse(buffer);
// NOTE(review): `root` is dereferenced without a null check. If the document
// has no <MethodSetup> element (e.g. loading failed), this is a null-pointer
// dereference. The same applies to the chained
// xmlDoc.FirstChildElement("MethodSetup")->FirstChildElement(...) calls below.
155 tinyxml2::XMLElement* root = xmlDoc.FirstChildElement(
"MethodSetup");
156 readVariables(root->FirstChildElement(
"Variables"),
"Variable",
varNames);
159 std::map <std::string, std::string>
info;
160 tinyxml2::XMLElement* infoElem = xmlDoc.FirstChildElement(
"MethodSetup")->FirstChildElement(
"GeneralInfo");
// The statement that the message below belongs to starts on a line not visible
// here (presumably a throw).
161 if (infoElem ==
nullptr) {
163 <<
"No GeneralInfo found in " << weightsFileFullPath <<
" !!\n";
165 for(tinyxml2::XMLElement*
e = infoElem->FirstChildElement(
"Info");
166 e !=
nullptr;
e =
e->NextSiblingElement(
"Info"))
// `name` and `value` are declared on lines not visible here; presumably the
// pair is then stored as info[name] = value.
170 e->QueryStringAttribute(
"name", &name);
171 e->QueryStringAttribute(
"value", &value);
176 std::map <std::string, std::string>
options;
177 tinyxml2::XMLElement* optionsElem = xmlDoc.FirstChildElement(
"MethodSetup")->FirstChildElement(
"Options");
178 if (optionsElem ==
nullptr) {
180 <<
"No Options found in " << weightsFileFullPath <<
" !!\n";
182 for(tinyxml2::XMLElement*
e = optionsElem->FirstChildElement(
"Option");
183 e !=
nullptr;
e =
e->NextSiblingElement(
"Option"))
186 e->QueryStringAttribute(
"name", &name);
// NOTE(review): GetText() returns nullptr for an <Option> with no text content,
// and assigning a null char* to std::string is undefined behaviour.
187 options[
name] =
e->GetText();
// Training ROOT version: the integer between '[' and ']' of the "ROOT Release"
// info entry. `s` is defined on a line not visible here and presumably holds
// that entry. NOTE(review): std::stoi throws std::invalid_argument /
// std::out_of_range if the bracketed text is not a valid int.
191 int rootTrainingVersion(0);
192 if (info.find(
"ROOT Release") != info.end()) {
194 rootTrainingVersion = std::stoi(s.substr(s.find(
"[")+1,s.find(
"]")-s.find(
"[")-1));
198 std::vector<double> boostWeights;
199 tinyxml2::XMLElement* weightsElem = xmlDoc.FirstChildElement(
"MethodSetup")->FirstChildElement(
"Weights");
200 if (weightsElem ==
nullptr) {
202 <<
"No Weights found in " << weightsFileFullPath <<
" !!\n";
// Collect the boost weight of every <BinaryTree>. The "No BinaryTrees" message
// below presumably fires when `hasTrees` stays false.
204 bool hasTrees =
false;
205 for(tinyxml2::XMLElement*
e = weightsElem->FirstChildElement(
"BinaryTree");
206 e !=
nullptr;
e =
e->NextSiblingElement(
"BinaryTree"))
210 e->QueryDoubleAttribute(
"boostWeight", &w);
211 boostWeights.push_back(w);
215 <<
"No BinaryTrees found in " << weightsFileFullPath <<
" !!\n";
// Flags derived from the file metadata. std::map::operator[] inserts an empty
// value for a missing key, so absent entries compare as empty strings.
218 bool isRegression = info[
"AnalysisType"] ==
"Regression";
222 bool isAdaClassifier = !isRegression && options[
"BoostType"] !=
"Grad";
223 bool useYesNoLeaf = isAdaClassifier && options[
"UseYesNoLeaf"] ==
"True";
// Cut boundaries are adjusted (see addNode) for files trained with ROOT in
// [5.34.20, 6.0.0) or >= 6.2.0.
227 bool adjustBoundaries = (rootTrainingVersion>=ROOT_VERSION(5,34,20) &&
228 rootTrainingVersion<ROOT_VERSION(6,0,0)) || rootTrainingVersion>=ROOT_VERSION(6,2,0);
// Regression forests start from the first boost weight; classifiers start
// from 0. NOTE(review): boostWeights[0] presumably relies on the "No
// BinaryTrees" check above having stopped execution for empty files.
230 auto forest = std::make_unique<GBRForest>();
231 forest->SetInitialResponse(isRegression ? boostWeights[0] : 0.);
// For AdaBoost classifiers the loop below (body not visible here) presumably
// accumulates `norm`, which normalises each tree's weight into `scale`.
234 if (isAdaClassifier) {
235 for (
double w : boostWeights) {
240 forest->Trees().reserve(boostWeights.size());
// One GBRTree per <BinaryTree>. It is sized with the node counts of its root
// <Node> (apparently capacity hints, since addNode fills it with push_back) and
// then filled recursively. `itree` is advanced on a line not visible here.
// NOTE(review): the loop-local `root` shadows the outer `root` declared above.
243 for(tinyxml2::XMLElement*
e = weightsElem->FirstChildElement(
"BinaryTree");
244 e !=
nullptr;
e =
e->NextSiblingElement(
"BinaryTree")) {
245 double scale = isAdaClassifier ? boostWeights[itree]/norm : 1.0;
247 tinyxml2::XMLElement* root =
e->FirstChildElement(
"Node");
248 forest->Trees().push_back(
GBRTree(countIntermediateNodes(root), countTerminalNodes(root)));
249 auto & tree = forest->Trees().back();
251 addNode(tree, root, scale, isRegression, useYesNoLeaf, adjustBoundaries, isAdaClassifier);
// Return types of several GBRForest factory overloads. Their names and most of
// their bodies lie outside this excerpt.
270 std::unique_ptr<const GBRForest>
277 std::unique_ptr<const GBRForest>
285 std::unique_ptr<const GBRForest>
288 std::unique_ptr<GBRForest> gbrForest;
// A weights path starting with '/' is treated as absolute and passed straight
// to init(); the handling of other paths is not visible in this excerpt.
// NOTE(review): if weightsFile is a std::string, weightsFile[0] on an empty
// string yields '\0' (well-defined since C++11), so an empty path skips this
// branch.
290 if(weightsFile[0] ==
'/') {
291 gbrForest =
init(weightsFile, varNames);
300 std::unique_ptr<const GBRForest>
std::vector< float > & Responses()
bool hasEnding(std::string const &fullString, std::string const &ending)
const std::string names[nVars_]
std::vector< float > & CutVals()
void swap(edm::DataFrameContainer &lhs, edm::DataFrameContainer &rhs)
std::vector< int > & LeftIndices()
char * readGzipFile(const std::string &weightFile)
std::string fullPath() const
std::vector< int > & RightIndices()
std::vector< unsigned char > & CutIndices()