17 size_t readVariables(tinyxml2::XMLElement*
root,
const char*
key, std::vector<std::string>&
names) {
21 if (
root !=
nullptr) {
22 for (tinyxml2::XMLElement*
e =
root->FirstChildElement(
key);
e !=
nullptr;
e =
e->NextSiblingElement(
key)) {
23 names.push_back(
e->Attribute(
"Expression"));
31 bool isTerminal(tinyxml2::XMLElement* node) {
33 for (tinyxml2::XMLElement*
e = node->FirstChildElement(
"Node");
e !=
nullptr;
e =
e->NextSiblingElement(
"Node")) {
39 unsigned int countIntermediateNodes(tinyxml2::XMLElement* node) {
40 unsigned int count = 0;
41 for (tinyxml2::XMLElement*
e = node->FirstChildElement(
"Node");
e !=
nullptr;
e =
e->NextSiblingElement(
"Node")) {
42 count += countIntermediateNodes(
e);
47 unsigned int countTerminalNodes(tinyxml2::XMLElement* node) {
48 unsigned int count = 0;
49 for (tinyxml2::XMLElement*
e = node->FirstChildElement(
"Node");
e !=
nullptr;
e =
e->NextSiblingElement(
"Node")) {
50 count += countTerminalNodes(
e);
56 tinyxml2::XMLElement* node,
61 bool isAdaClassifier) {
62 bool nodeIsTerminal = isTerminal(node);
66 node->QueryDoubleAttribute(
"res", &response);
69 node->QueryDoubleAttribute(
"nType", &response);
71 if (isAdaClassifier) {
72 node->QueryDoubleAttribute(
"purity", &response);
74 node->QueryDoubleAttribute(
"res", &response);
79 tree.Responses().push_back(response);
81 int thisidx =
tree.CutIndices().size();
87 node->QueryIntAttribute(
"IVar", &selector);
88 node->QueryFloatAttribute(
"Cut", &cutval);
89 node->QueryBoolAttribute(
"cType", &ctype);
91 tree.CutIndices().push_back(static_cast<unsigned char>(selector));
96 cutval = std::nextafter(cutval, std::numeric_limits<float>::lowest());
98 tree.CutVals().push_back(cutval);
99 tree.LeftIndices().push_back(0);
100 tree.RightIndices().push_back(0);
102 tinyxml2::XMLElement* left =
nullptr;
103 tinyxml2::XMLElement* right =
nullptr;
104 for (tinyxml2::XMLElement*
e = node->FirstChildElement(
"Node");
e !=
nullptr;
e =
e->NextSiblingElement(
"Node")) {
105 if (*(
e->Attribute(
"pos")) ==
'l')
107 else if (*(
e->Attribute(
"pos")) ==
'r')
114 tree.LeftIndices()[thisidx] = isTerminal(left) ? -
tree.Responses().size() :
tree.CutIndices().size();
115 addNode(
tree, left,
scale, isRegression, useYesNoLeaf, adjustboundary, isAdaClassifier);
117 tree.RightIndices()[thisidx] = isTerminal(right) ? -
tree.Responses().size() :
tree.CutIndices().size();
118 addNode(
tree, right,
scale, isRegression, useYesNoLeaf, adjustboundary, isAdaClassifier);
127 TFile gbrForestFile(weightsFileFullPath.c_str());
128 std::unique_ptr<GBRForest>
up(gbrForestFile.Get<
GBRForest>(
"gbrForest"));
129 std::unique_ptr<std::vector<std::string>>
vars(gbrForestFile.Get<std::vector<std::string>>(
"variableNames"));
130 gbrForestFile.Close(
"nodelete");
140 tinyxml2::XMLDocument xmlDoc;
144 if (
hasEnding(weightsFileFullPath,
".xml")) {
145 xmlDoc.LoadFile(weightsFileFullPath.c_str());
146 }
else if (
hasEnding(weightsFileFullPath,
".gz") ||
hasEnding(weightsFileFullPath,
".gzip")) {
152 tinyxml2::XMLElement*
root = xmlDoc.FirstChildElement(
"MethodSetup");
153 readVariables(
root->FirstChildElement(
"Variables"),
"Variable",
varNames);
156 std::map<std::string, std::string>
info;
157 tinyxml2::XMLElement* infoElem = xmlDoc.FirstChildElement(
"MethodSetup")->FirstChildElement(
"GeneralInfo");
158 if (infoElem ==
nullptr) {
159 throw cms::Exception(
"XMLError") <<
"No GeneralInfo found in " << weightsFileFullPath <<
" !!\n";
161 for (tinyxml2::XMLElement*
e = infoElem->FirstChildElement(
"Info");
e !=
nullptr;
162 e =
e->NextSiblingElement(
"Info")) {
165 if (tinyxml2::XML_SUCCESS !=
e->QueryStringAttribute(
"name", &
name)) {
166 throw cms::Exception(
"XMLERROR") <<
"no 'name' attribute found in 'Info' element in " << weightsFileFullPath;
168 if (tinyxml2::XML_SUCCESS !=
e->QueryStringAttribute(
"value", &
value)) {
169 throw cms::Exception(
"XMLERROR") <<
"no 'value' attribute found in 'Info' element in " << weightsFileFullPath;
175 std::map<std::string, std::string>
options;
176 tinyxml2::XMLElement* optionsElem = xmlDoc.FirstChildElement(
"MethodSetup")->FirstChildElement(
"Options");
177 if (optionsElem ==
nullptr) {
178 throw cms::Exception(
"XMLError") <<
"No Options found in " << weightsFileFullPath <<
" !!\n";
180 for (tinyxml2::XMLElement*
e = optionsElem->FirstChildElement(
"Option");
e !=
nullptr;
181 e =
e->NextSiblingElement(
"Option")) {
183 if (tinyxml2::XML_SUCCESS !=
e->QueryStringAttribute(
"name", &
name)) {
184 throw cms::Exception(
"XMLERROR") <<
"no 'name' attribute found in 'Option' element in " << weightsFileFullPath;
190 int rootTrainingVersion(0);
191 if (
info.find(
"ROOT Release") !=
info.end()) {
193 rootTrainingVersion = std::stoi(
s.substr(
s.find(
'[') + 1,
s.find(
']') -
s.find(
'[') - 1));
197 std::vector<double> boostWeights;
198 tinyxml2::XMLElement* weightsElem = xmlDoc.FirstChildElement(
"MethodSetup")->FirstChildElement(
"Weights");
199 if (weightsElem ==
nullptr) {
200 throw cms::Exception(
"XMLError") <<
"No Weights found in " << weightsFileFullPath <<
" !!\n";
202 bool hasTrees =
false;
203 for (tinyxml2::XMLElement*
e = weightsElem->FirstChildElement(
"BinaryTree");
e !=
nullptr;
204 e =
e->NextSiblingElement(
"BinaryTree")) {
207 if (tinyxml2::XML_SUCCESS !=
e->QueryDoubleAttribute(
"boostWeight", &
w)) {
208 throw cms::Exception(
"XMLERROR") <<
"problem with 'boostWeight' attribute found in 'BinaryTree' element in " 209 << weightsFileFullPath;
211 boostWeights.push_back(
w);
214 throw cms::Exception(
"XMLError") <<
"No BinaryTrees found in " << weightsFileFullPath <<
" !!\n";
217 bool isRegression =
info[
"AnalysisType"] ==
"Regression";
221 bool isAdaClassifier = !isRegression &&
options[
"BoostType"] !=
"Grad";
222 bool useYesNoLeaf = isAdaClassifier &&
options[
"UseYesNoLeaf"] ==
"True";
226 bool adjustBoundaries =
227 (rootTrainingVersion >= ROOT_VERSION(5, 34, 20) && rootTrainingVersion < ROOT_VERSION(6, 0, 0)) ||
228 rootTrainingVersion >= ROOT_VERSION(6, 2, 0);
230 auto forest = std::make_unique<GBRForest>();
231 forest->SetInitialResponse(isRegression ? boostWeights[0] : 0.);
234 if (isAdaClassifier) {
235 for (
double w : boostWeights) {
240 forest->Trees().reserve(boostWeights.size());
243 for (tinyxml2::XMLElement*
e = weightsElem->FirstChildElement(
"BinaryTree");
e !=
nullptr;
244 e =
e->NextSiblingElement(
"BinaryTree")) {
245 double scale = isAdaClassifier ? boostWeights[itree] / norm : 1.0;
247 tinyxml2::XMLElement*
root =
e->FirstChildElement(
"Node");
248 forest->Trees().push_back(
GBRTree(countIntermediateNodes(
root), countTerminalNodes(
root)));
249 auto&
tree = forest->Trees().back();
251 addNode(
tree,
root,
scale, isRegression, useYesNoLeaf, adjustBoundaries, isAdaClassifier);
254 if (
tree.CutIndices().empty()) {
255 tree.CutIndices().push_back(0);
256 tree.CutVals().push_back(0);
257 tree.LeftIndices().push_back(0);
258 tree.RightIndices().push_back(0);
282 std::unique_ptr<GBRForest> gbrForest;
294 std::vector<std::string>&
varNames) {
bool hasEnding(std::string const &fullString, std::string const &ending)
std::string fullPath() const
constexpr char const * varNames[]
const std::string names[nVars_]
void swap(edm::DataFrameContainer &lhs, edm::DataFrameContainer &rhs)
char * readGzipFile(const std::string &weightFile)