16 size_t readVariables(tinyxml2::XMLElement*
root,
const char*
key, std::vector<std::string>&
names) {
20 if (root !=
nullptr) {
21 for (tinyxml2::XMLElement*
e = root->FirstChildElement(key);
e !=
nullptr;
e =
e->NextSiblingElement(key)) {
22 names.push_back(
e->Attribute(
"Expression"));
30 bool isTerminal(tinyxml2::XMLElement* node) {
32 for (tinyxml2::XMLElement*
e = node->FirstChildElement(
"Node");
e !=
nullptr;
e =
e->NextSiblingElement(
"Node")) {
38 unsigned int countIntermediateNodes(tinyxml2::XMLElement* node) {
39 unsigned int count = 0;
40 for (tinyxml2::XMLElement*
e = node->FirstChildElement(
"Node");
e !=
nullptr;
e =
e->NextSiblingElement(
"Node")) {
41 count += countIntermediateNodes(
e);
43 return count > 0 ? count + 1 : 0;
46 unsigned int countTerminalNodes(tinyxml2::XMLElement* node) {
47 unsigned int count = 0;
48 for (tinyxml2::XMLElement*
e = node->FirstChildElement(
"Node");
e !=
nullptr;
e =
e->NextSiblingElement(
"Node")) {
49 count += countTerminalNodes(
e);
51 return count > 0 ? count : 1;
55 tinyxml2::XMLElement* node,
60 bool isAdaClassifier) {
61 bool nodeIsTerminal = isTerminal(node);
65 node->QueryDoubleAttribute(
"res", &response);
68 node->QueryDoubleAttribute(
"nType", &response);
70 if (isAdaClassifier) {
71 node->QueryDoubleAttribute(
"purity", &response);
73 node->QueryDoubleAttribute(
"res", &response);
86 node->QueryIntAttribute(
"IVar", &selector);
87 node->QueryFloatAttribute(
"Cut", &cutval);
88 node->QueryBoolAttribute(
"cType", &ctype);
90 tree.
CutIndices().push_back(static_cast<unsigned char>(selector));
95 cutval = std::nextafter(cutval, std::numeric_limits<float>::lowest());
97 tree.
CutVals().push_back(cutval);
101 tinyxml2::XMLElement* left =
nullptr;
102 tinyxml2::XMLElement* right =
nullptr;
103 for (tinyxml2::XMLElement*
e = node->FirstChildElement(
"Node");
e !=
nullptr;
e =
e->NextSiblingElement(
"Node")) {
104 if (*(
e->Attribute(
"pos")) ==
'l')
106 else if (*(
e->Attribute(
"pos")) ==
'r')
114 addNode(tree, left, scale, isRegression, useYesNoLeaf, adjustboundary, isAdaClassifier);
117 addNode(tree, right, scale, isRegression, useYesNoLeaf, adjustboundary, isAdaClassifier);
126 TFile gbrForestFile(weightsFileFullPath.c_str());
127 std::unique_ptr<GBRForest>
up(reinterpret_cast<GBRForest*>(gbrForestFile.Get(
"gbrForest")));
128 gbrForestFile.Close(
"nodelete");
135 tinyxml2::XMLDocument xmlDoc;
139 if (
hasEnding(weightsFileFullPath,
".xml")) {
140 xmlDoc.LoadFile(weightsFileFullPath.c_str());
141 }
else if (
hasEnding(weightsFileFullPath,
".gz") ||
hasEnding(weightsFileFullPath,
".gzip")) {
143 xmlDoc.Parse(buffer);
147 tinyxml2::XMLElement* root = xmlDoc.FirstChildElement(
"MethodSetup");
148 readVariables(root->FirstChildElement(
"Variables"),
"Variable",
varNames);
151 std::map<std::string, std::string>
info;
152 tinyxml2::XMLElement* infoElem = xmlDoc.FirstChildElement(
"MethodSetup")->FirstChildElement(
"GeneralInfo");
153 if (infoElem ==
nullptr) {
154 throw cms::Exception(
"XMLError") <<
"No GeneralInfo found in " << weightsFileFullPath <<
" !!\n";
156 for (tinyxml2::XMLElement*
e = infoElem->FirstChildElement(
"Info");
e !=
nullptr;
157 e =
e->NextSiblingElement(
"Info")) {
160 e->QueryStringAttribute(
"name", &name);
161 e->QueryStringAttribute(
"value", &value);
166 std::map<std::string, std::string>
options;
167 tinyxml2::XMLElement* optionsElem = xmlDoc.FirstChildElement(
"MethodSetup")->FirstChildElement(
"Options");
168 if (optionsElem ==
nullptr) {
169 throw cms::Exception(
"XMLError") <<
"No Options found in " << weightsFileFullPath <<
" !!\n";
171 for (tinyxml2::XMLElement*
e = optionsElem->FirstChildElement(
"Option");
e !=
nullptr;
172 e =
e->NextSiblingElement(
"Option")) {
174 e->QueryStringAttribute(
"name", &name);
175 options[
name] =
e->GetText();
179 int rootTrainingVersion(0);
180 if (info.find(
"ROOT Release") != info.end()) {
182 rootTrainingVersion = std::stoi(s.substr(s.find(
"[") + 1, s.find(
"]") - s.find(
"[") - 1));
186 std::vector<double> boostWeights;
187 tinyxml2::XMLElement* weightsElem = xmlDoc.FirstChildElement(
"MethodSetup")->FirstChildElement(
"Weights");
188 if (weightsElem ==
nullptr) {
189 throw cms::Exception(
"XMLError") <<
"No Weights found in " << weightsFileFullPath <<
" !!\n";
191 bool hasTrees =
false;
192 for (tinyxml2::XMLElement*
e = weightsElem->FirstChildElement(
"BinaryTree");
e !=
nullptr;
193 e =
e->NextSiblingElement(
"BinaryTree")) {
196 e->QueryDoubleAttribute(
"boostWeight", &w);
197 boostWeights.push_back(w);
200 throw cms::Exception(
"XMLError") <<
"No BinaryTrees found in " << weightsFileFullPath <<
" !!\n";
203 bool isRegression = info[
"AnalysisType"] ==
"Regression";
207 bool isAdaClassifier = !isRegression && options[
"BoostType"] !=
"Grad";
208 bool useYesNoLeaf = isAdaClassifier && options[
"UseYesNoLeaf"] ==
"True";
212 bool adjustBoundaries =
213 (rootTrainingVersion >= ROOT_VERSION(5, 34, 20) && rootTrainingVersion < ROOT_VERSION(6, 0, 0)) ||
214 rootTrainingVersion >= ROOT_VERSION(6, 2, 0);
216 auto forest = std::make_unique<GBRForest>();
217 forest->SetInitialResponse(isRegression ? boostWeights[0] : 0.);
220 if (isAdaClassifier) {
221 for (
double w : boostWeights) {
226 forest->Trees().reserve(boostWeights.size());
229 for (tinyxml2::XMLElement*
e = weightsElem->FirstChildElement(
"BinaryTree");
e !=
nullptr;
230 e =
e->NextSiblingElement(
"BinaryTree")) {
231 double scale = isAdaClassifier ? boostWeights[itree] / norm : 1.0;
233 tinyxml2::XMLElement* root =
e->FirstChildElement(
"Node");
234 forest->Trees().push_back(
GBRTree(countIntermediateNodes(root), countTerminalNodes(root)));
235 auto& tree = forest->Trees().back();
237 addNode(tree, root, scale, isRegression, useYesNoLeaf, adjustBoundaries, isAdaClassifier);
268 std::unique_ptr<GBRForest> gbrForest;
270 if (weightsFile[0] ==
'/') {
271 gbrForest =
init(weightsFile, varNames);
280 std::vector<std::string>& varNames) {
std::vector< float > & Responses()
bool hasEnding(std::string const &fullString, std::string const &ending)
const std::string names[nVars_]
std::vector< float > & CutVals()
void swap(edm::DataFrameContainer &lhs, edm::DataFrameContainer &rhs)
std::vector< int > & LeftIndices()
char * readGzipFile(const std::string &weightFile)
std::string fullPath() const
std::vector< int > & RightIndices()
std::vector< unsigned char > & CutIndices()