00001 #include <iostream>
00002 #include <iomanip>
00003 #include <cstring>
00004 #include <vector>
00005 #include <cmath>
00006
00007 #include <TMatrixD.h>
00008 #include <TVectorD.h>
00009 #include <TDecompSVD.h>
00010
00011 #include "FWCore/Utilities/interface/Exception.h"
00012
00013 #include "PhysicsTools/MVATrainer/interface/XMLDocument.h"
00014 #include "PhysicsTools/MVATrainer/interface/XMLSimpleStr.h"
00015 #include "PhysicsTools/MVATrainer/interface/XMLUniStr.h"
00016 #include "PhysicsTools/MVATrainer/interface/LeastSquares.h"
00017
00018 XERCES_CPP_NAMESPACE_USE
00019
00020 namespace PhysicsTools {
00021
00022 LeastSquares::LeastSquares(unsigned int n) :
00023 coeffs(n + 2), covar(n + 1), corr(n + 1), rotation(n, n),
00024 weights(n + 1), variance(n + 1), trace(n), n(n)
00025 {
00026 }
00027
00028 LeastSquares::~LeastSquares()
00029 {
00030 }
00031
00032 void LeastSquares::add(const std::vector<double> &values,
00033 double dest, double weight)
00034 {
00035 if (values.size() != n)
00036 throw cms::Exception("LeastSquares")
00037 << "add(): invalid array size!" << std::endl;
00038
00039 for(unsigned int i = 0; i < n; i++) {
00040 for(unsigned int j = 0; j < n; j++)
00041 coeffs(i, j) += values[i] * values[j] * weight;
00042
00043 coeffs(n, i) += values[i] * dest * weight;
00044 coeffs(i, n) += values[i] * dest * weight;
00045 coeffs(n + 1, i) += values[i] * weight;
00046 coeffs(i, n + 1) += values[i] * weight;
00047 }
00048
00049 coeffs(n, n) += dest * dest * weight;
00050 coeffs(n + 1, n) += dest * weight;
00051 coeffs(n, n + 1) += dest * weight;
00052 coeffs(n + 1, n + 1) += weight;
00053 }
00054
00055 void LeastSquares::add(const LeastSquares &other, double weight)
00056 {
00057 if (other.getSize() != n)
00058 throw cms::Exception("LeastSquares")
00059 << "add(): invalid array size!" << std::endl;
00060
00061 coeffs += weight * other.coeffs;
00062 }
00063
00064 TVectorD LeastSquares::solveFisher(const TMatrixDSym &coeffs)
00065 {
00066 unsigned int n = coeffs.GetNrows() - 2;
00067
00068 TMatrixDSym tmp;
00069 coeffs.GetSub(0, n, tmp);
00070 tmp[n] = TVectorD(n + 1, coeffs[n + 1].GetPtr());
00071 tmp(n, n) = coeffs(n + 1, n + 1);
00072
00073 TDecompSVD decCoeffs(tmp);
00074 bool ok;
00075 return decCoeffs.Solve(TVectorD(n + 1, coeffs[n].GetPtr()), ok);
00076 }
00077
00078 TMatrixD LeastSquares::solveRotation(const TMatrixDSym &covar, TVectorD &trace)
00079 {
00080 TMatrixDSym tmp;
00081 covar.GetSub(0, covar.GetNrows() - 2, tmp);
00082 TDecompSVD decCovar(tmp);
00083 trace = decCovar.GetSig();
00084 return decCovar.GetU();
00085 }
00086
00087 void LeastSquares::calculate()
00088 {
00089 double N = coeffs(n + 1, n + 1);
00090
00091 for(unsigned int i = 0; i <= n; i++) {
00092 double M = coeffs(n + 1, i);
00093 for(unsigned int j = 0; j <= n; j++)
00094 covar(i, j) = coeffs(i, j) * N - M * coeffs(n + 1, j);
00095 }
00096
00097 for(unsigned int i = 0; i <= n; i++) {
00098 double c = covar(i, i);
00099 variance[i] = c > 0.0 ? std::sqrt(c) : 0.0;
00100 }
00101
00102 for(unsigned int i = 0; i <= n; i++) {
00103 double M = variance[i];
00104 for(unsigned int j = 0; j <= n; j++) {
00105 double v = M * variance[j];
00106 double w = covar(i, j);
00107
00108 corr(i, j) = (v >= 1.0e-9) ? (w / v) : (i == j);
00109 }
00110 }
00111
00112 weights = solveFisher(coeffs);
00113 rotation = solveRotation(covar, trace);
00114 }
00115
00116 std::vector<double> LeastSquares::getWeights() const
00117 {
00118 std::vector<double> results;
00119 results.reserve(n);
00120
00121 for(unsigned int i = 0; i < n; i++)
00122 results.push_back(weights[i]);
00123
00124 return results;
00125 }
00126
00127 std::vector<double> LeastSquares::getMeans() const
00128 {
00129 std::vector<double> results;
00130 results.reserve(n);
00131
00132 double N = coeffs(n + 1, n + 1);
00133 for(unsigned int i = 0; i < n; i++)
00134 results.push_back(coeffs(n + 1, i) / N);
00135
00136 return results;
00137 }
00138
00139 double LeastSquares::getConstant() const
00140 {
00141 return weights[n];
00142 }
00143
00144 static void loadMatrix(DOMElement *elem, unsigned int n, TMatrixDBase &matrix)
00145 {
00146 if (std::strcmp(XMLSimpleStr(elem->getNodeName()),
00147 "matrix") != 0)
00148 throw cms::Exception("LeastSquares")
00149 << "Expected matrix in data file."
00150 << std::endl;
00151
00152 unsigned int row = 0;
00153 for(DOMNode *node = elem->getFirstChild();
00154 node; node = node->getNextSibling()) {
00155 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00156 continue;
00157
00158 if (std::strcmp(XMLSimpleStr(node->getNodeName()),
00159 "row") != 0)
00160 throw cms::Exception("LeastSquares")
00161 << "Expected row tag in data file."
00162 << std::endl;
00163
00164 if (row >= n)
00165 throw cms::Exception("LeastSquares")
00166 << "Too many rows in data file." << std::endl;
00167
00168 elem = static_cast<DOMElement*>(node);
00169
00170 unsigned int col = 0;
00171 for(DOMNode *subNode = elem->getFirstChild();
00172 subNode; subNode = subNode->getNextSibling()) {
00173 if (subNode->getNodeType() != DOMNode::ELEMENT_NODE)
00174 continue;
00175
00176 if (std::strcmp(XMLSimpleStr(subNode->getNodeName()),
00177 "value") != 0)
00178 throw cms::Exception("LeastSquares")
00179 << "Expected value tag in data file."
00180 << std::endl;
00181
00182 if (col >= n)
00183 throw cms::Exception("LeastSquares")
00184 << "Too many columns in data file."
00185 << std::endl;
00186
00187 matrix(row, col) =
00188 XMLDocument::readContent<double>(subNode);
00189 col++;
00190 }
00191
00192 if (col != n)
00193 throw cms::Exception("LeastSquares")
00194 << "Missing columns in data file."
00195 << std::endl;
00196 row++;
00197 }
00198
00199 if (row != n)
00200 throw cms::Exception("LeastSquares")
00201 << "Missing rows in data file."
00202 << std::endl;
00203 }
00204
00205 static void loadVector(DOMElement *elem, unsigned int n, TVectorD &vector)
00206 {
00207 if (std::strcmp(XMLSimpleStr(elem->getNodeName()),
00208 "vector") != 0)
00209 throw cms::Exception("LeastSquares")
00210 << "Expected matrix in data file."
00211 << std::endl;
00212
00213 unsigned int col = 0;
00214 for(DOMNode *node = elem->getFirstChild();
00215 node; node = node->getNextSibling()) {
00216 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00217 continue;
00218
00219 if (std::strcmp(XMLSimpleStr(node->getNodeName()),
00220 "value") != 0)
00221 throw cms::Exception("LeastSquares")
00222 << "Expected value tag in data file."
00223 << std::endl;
00224
00225 if (col >= n)
00226 throw cms::Exception("LeastSquares")
00227 << "Too many columns in data file."
00228 << std::endl;
00229
00230 vector(col) = XMLDocument::readContent<double>(node);
00231 col++;
00232 }
00233
00234 if (col != n)
00235 throw cms::Exception("LeastSquares")
00236 << "Missing columns in data file."
00237 << std::endl;
00238 }
00239
00240 static DOMElement *saveMatrix(DOMDocument *doc, unsigned int n,
00241 const TMatrixDBase &matrix)
00242 {
00243 DOMElement *root = doc->createElement(XMLUniStr("matrix"));
00244 XMLDocument::writeAttribute<unsigned int>(root, "size", n);
00245
00246 for(unsigned int i = 0; i < n; i++) {
00247 DOMElement *row = doc->createElement(XMLUniStr("row"));
00248 root->appendChild(row);
00249
00250 for(unsigned int j = 0; j < n; j++) {
00251 DOMElement *value =
00252 doc->createElement(XMLUniStr("value"));
00253 row->appendChild(value);
00254
00255 XMLDocument::writeContent<double>(value, doc,
00256 matrix(i, j));
00257 }
00258 }
00259
00260 return root;
00261 }
00262
00263 static DOMElement *saveVector(DOMDocument *doc, unsigned int n,
00264 const TVectorD &vector)
00265 {
00266 DOMElement *root = doc->createElement(XMLUniStr("vector"));
00267 XMLDocument::writeAttribute<unsigned int>(root, "size", n);
00268
00269 for(unsigned int i = 0; i < n; i++) {
00270 DOMElement *value =
00271 doc->createElement(XMLUniStr("value"));
00272 root->appendChild(value);
00273
00274 XMLDocument::writeContent<double>(value, doc, vector(i));
00275 }
00276
00277 return root;
00278 }
00279
00280 void LeastSquares::load(DOMElement *elem)
00281 {
00282 if (std::strcmp(XMLSimpleStr(elem->getNodeName()),
00283 "LinearAnalysis") != 0)
00284 throw cms::Exception("LeastSquares")
00285 << "Expected LinearAnalysis in data file."
00286 << std::endl;
00287
00288 unsigned int version = XMLDocument::readAttribute<unsigned int>(
00289 elem, "version", 1);
00290
00291 enum Position {
00292 POS_COEFFS, POS_COVAR, POS_CORR, POS_ROTATION,
00293 POS_SUMS, POS_WEIGHTS, POS_VARIANCE, POS_TRACE, POS_DONE
00294 } pos = POS_COEFFS;
00295
00296 for(DOMNode *node = elem->getFirstChild();
00297 node; node = node->getNextSibling()) {
00298 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00299 continue;
00300
00301 DOMElement *elem = static_cast<DOMElement*>(node);
00302
00303 switch(pos) {
00304 case POS_COEFFS:
00305 if (version < 2) {
00306 loadMatrix(elem, n + 1, coeffs);
00307 coeffs.ResizeTo(n + 2, n + 2);
00308 for(unsigned int i = 0; i <= n; i++) {
00309 coeffs(n + 1, i) = coeffs(n, i);
00310 coeffs(i, n + 1) = coeffs(i, n);
00311 }
00312 coeffs(n + 1, n + 1) = coeffs(n + 1, n);
00313 } else
00314 loadMatrix(elem, n + 2, coeffs);
00315 break;
00316 case POS_COVAR:
00317 if (version < 2)
00318 loadMatrix(elem, n, covar);
00319 else
00320 loadMatrix(elem, n + 1, covar);
00321 break;
00322 case POS_CORR:
00323 if (version < 2)
00324 loadMatrix(elem, n, corr);
00325 else
00326 loadMatrix(elem, n + 1, corr);
00327 break;
00328 case POS_ROTATION:
00329 loadMatrix(elem, n, rotation);
00330 break;
00331 case POS_SUMS:
00332 if (version < 2) {
00333 TVectorD tmp(n + 1);
00334 loadVector(elem, n + 1, tmp);
00335
00336 double M = coeffs(n + 1, n);
00337 double N = coeffs(n + 1, n + 1);
00338
00339 for(unsigned int i = 0; i <= n; i++) {
00340 double v = coeffs(n, i) * N -
00341 M * coeffs(n + 1, i);
00342 double w = coeffs(n, i) * N - v;
00343
00344 covar(n, i) = w;
00345 covar(i, n) = w;
00346 }
00347
00348 break;
00349 } else
00350 pos = (Position)(pos + 1);
00351 case POS_WEIGHTS:
00352 loadVector(elem, n + 1, weights);
00353 break;
00354 case POS_VARIANCE:
00355 if (version < 2) {
00356 loadVector(elem, n, variance);
00357
00358 double M = covar(n, n);
00359 M = M > 0.0 ? std::sqrt(M) : 0.0;
00360 variance[n] = M;
00361
00362 for(unsigned int i = 0; i <= n; i++) {
00363 double v = M * variance[i];
00364 double w = covar(n, i);
00365 double c = (v >= 1.0e-9)
00366 ? (w / v) : (i == n);
00367
00368 corr(n, i) = c;
00369 corr(i, n) = c;
00370 }
00371 } else
00372 loadVector(elem, n + 1, variance);
00373 break;
00374 case POS_TRACE:
00375 loadVector(elem, n, trace);
00376 break;
00377 default:
00378 throw cms::Exception("LeastSquares")
00379 << "Superfluous content in data file."
00380 << std::endl;
00381 }
00382
00383 pos = (Position)(pos + 1);
00384 }
00385
00386 if (pos != POS_DONE)
00387 throw cms::Exception("LeastSquares")
00388 << "Missing objects in data file."
00389 << std::endl;
00390 }
00391
00392 DOMElement *LeastSquares::save(DOMDocument *doc) const
00393 {
00394 DOMElement *root = doc->createElement(XMLUniStr("LinearAnalysis"));
00395 XMLDocument::writeAttribute<unsigned int>(root, "version", 2);
00396 XMLDocument::writeAttribute<unsigned int>(root, "size", n);
00397
00398 root->appendChild(saveMatrix(doc, n + 2, coeffs));
00399 root->appendChild(saveMatrix(doc, n + 1, covar));
00400 root->appendChild(saveMatrix(doc, n + 1, corr));
00401 root->appendChild(saveMatrix(doc, n, rotation));
00402 root->appendChild(saveVector(doc, n + 1, weights));
00403 root->appendChild(saveVector(doc, n + 1, variance));
00404 root->appendChild(saveVector(doc, n, trace));
00405
00406 return root;
00407 }
00408
00409 }