CMS 3D CMS Logo

/data/refman/pasoursint/CMSSW_5_3_0/src/PhysicsTools/MVATrainer/src/LeastSquares.cc

Go to the documentation of this file.
00001 #include <iostream>
00002 #include <iomanip>
00003 #include <cstring>
00004 #include <vector>
00005 #include <cmath>
00006 
00007 #include <TMatrixD.h>
00008 #include <TVectorD.h>
00009 #include <TDecompSVD.h>
00010 
00011 #include "FWCore/Utilities/interface/Exception.h"
00012 
00013 #include "PhysicsTools/MVATrainer/interface/XMLDocument.h"
00014 #include "PhysicsTools/MVATrainer/interface/XMLSimpleStr.h"
00015 #include "PhysicsTools/MVATrainer/interface/XMLUniStr.h"
00016 #include "PhysicsTools/MVATrainer/interface/LeastSquares.h"
00017 
00018 XERCES_CPP_NAMESPACE_USE
00019 
00020 namespace PhysicsTools {
00021 
00022 LeastSquares::LeastSquares(unsigned int n) :
00023         coeffs(n + 2), covar(n + 1), corr(n + 1), rotation(n, n),
00024         weights(n + 1), variance(n + 1), trace(n), n(n)
00025 {
00026 }
00027 
00028 LeastSquares::~LeastSquares()
00029 {
00030 }
00031 
00032 void LeastSquares::add(const std::vector<double> &values,
00033                        double dest, double weight)
00034 {
00035         if (values.size() != n)
00036                 throw cms::Exception("LeastSquares")
00037                         << "add(): invalid array size!" << std::endl;
00038 
00039         for(unsigned int i = 0; i < n; i++) {
00040                 for(unsigned int j = 0; j < n; j++)
00041                         coeffs(i, j) += values[i] * values[j] * weight;
00042 
00043                 coeffs(n, i) += values[i] * dest * weight;
00044                 coeffs(i, n) += values[i] * dest * weight;
00045                 coeffs(n + 1, i) += values[i] * weight;
00046                 coeffs(i, n + 1) += values[i] * weight;
00047         } 
00048 
00049         coeffs(n, n) += dest * dest * weight;
00050         coeffs(n + 1, n) += dest * weight;
00051         coeffs(n, n + 1) += dest * weight;
00052         coeffs(n + 1, n + 1) += weight;
00053 }
00054 
00055 void LeastSquares::add(const LeastSquares &other, double weight)
00056 {
00057         if (other.getSize() != n)
00058                 throw cms::Exception("LeastSquares")
00059                         << "add(): invalid array size!" << std::endl;
00060 
00061         coeffs += weight * other.coeffs;
00062 }
00063 
00064 TVectorD LeastSquares::solveFisher(const TMatrixDSym &coeffs)
00065 {
00066         unsigned int n = coeffs.GetNrows() - 2;
00067 
00068         TMatrixDSym tmp;
00069         coeffs.GetSub(0, n, tmp);
00070         tmp[n] = TVectorD(n + 1, coeffs[n + 1].GetPtr());
00071         tmp(n, n) = coeffs(n + 1, n + 1);
00072 
00073         TDecompSVD decCoeffs(tmp);
00074         bool ok;
00075         return decCoeffs.Solve(TVectorD(n + 1, coeffs[n].GetPtr()), ok);
00076 }
00077 
00078 TMatrixD LeastSquares::solveRotation(const TMatrixDSym &covar, TVectorD &trace)
00079 {
00080         TMatrixDSym tmp;
00081         covar.GetSub(0, covar.GetNrows() - 2, tmp);
00082         TDecompSVD decCovar(tmp);
00083         trace = decCovar.GetSig();
00084         return decCovar.GetU();
00085 }
00086 
00087 void LeastSquares::calculate()
00088 {
00089         double N = coeffs(n + 1, n + 1);
00090 
00091         for(unsigned int i = 0; i <= n; i++) {
00092                 double M = coeffs(n + 1, i);
00093                 for(unsigned int j = 0; j <= n; j++)
00094                         covar(i, j) = coeffs(i, j) * N - M * coeffs(n + 1, j);
00095         }
00096 
00097         for(unsigned int i = 0; i <= n; i++) {
00098                 double c = covar(i, i);
00099                 variance[i] = c > 0.0 ? std::sqrt(c) : 0.0;
00100         }
00101 
00102         for(unsigned int i = 0; i <= n; i++) {
00103                 double M = variance[i];
00104                 for(unsigned int j = 0; j <= n; j++) {
00105                         double v = M * variance[j];
00106                         double w = covar(i, j);
00107 
00108                         corr(i, j) = (v >= 1.0e-9) ? (w / v) : (i == j);
00109                 }
00110         }
00111 
00112         weights = solveFisher(coeffs);
00113         rotation = solveRotation(covar, trace);
00114 }
00115 
00116 std::vector<double> LeastSquares::getWeights() const
00117 {
00118         std::vector<double> results;
00119         results.reserve(n);
00120         
00121         for(unsigned int i = 0; i < n; i++)
00122                 results.push_back(weights[i]);
00123         
00124         return results;
00125 }
00126 
00127 std::vector<double> LeastSquares::getMeans() const
00128 {
00129         std::vector<double> results;
00130         results.reserve(n);
00131 
00132         double N = coeffs(n + 1, n + 1);
00133         for(unsigned int i = 0; i < n; i++)
00134                 results.push_back(coeffs(n + 1, i) / N);
00135 
00136         return results;
00137 }
00138 
00139 double LeastSquares::getConstant() const
00140 {
00141         return weights[n];
00142 }
00143 
00144 static void loadMatrix(DOMElement *elem, unsigned int n, TMatrixDBase &matrix)
00145 {
00146         if (std::strcmp(XMLSimpleStr(elem->getNodeName()),
00147                                      "matrix") != 0)
00148                 throw cms::Exception("LeastSquares")
00149                         << "Expected matrix in data file."
00150                         << std::endl;
00151 
00152         unsigned int row = 0;
00153         for(DOMNode *node = elem->getFirstChild();
00154             node; node = node->getNextSibling()) {
00155                 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00156                         continue;
00157 
00158                 if (std::strcmp(XMLSimpleStr(node->getNodeName()),
00159                                              "row") != 0)
00160                         throw cms::Exception("LeastSquares")
00161                                 << "Expected row tag in data file."
00162                                 << std::endl;
00163 
00164                 if (row >= n)
00165                         throw cms::Exception("LeastSquares")
00166                                 << "Too many rows in data file." << std::endl;
00167 
00168                 elem = static_cast<DOMElement*>(node);
00169 
00170                 unsigned int col = 0;
00171                 for(DOMNode *subNode = elem->getFirstChild();
00172                     subNode; subNode = subNode->getNextSibling()) {
00173                         if (subNode->getNodeType() != DOMNode::ELEMENT_NODE)
00174                                 continue;
00175 
00176                         if (std::strcmp(XMLSimpleStr(subNode->getNodeName()),
00177                                                      "value") != 0)
00178                         throw cms::Exception("LeastSquares")
00179                                 << "Expected value tag in data file."
00180                                 << std::endl;
00181 
00182                         if (col >= n)
00183                                 throw cms::Exception("LeastSquares")
00184                                         << "Too many columns in data file."
00185                                         << std::endl;
00186 
00187                         matrix(row, col) =
00188                                 XMLDocument::readContent<double>(subNode);
00189                         col++;
00190                 }
00191 
00192                 if (col != n)
00193                         throw cms::Exception("LeastSquares")
00194                                 << "Missing columns in data file."
00195                                 << std::endl;
00196                 row++;
00197         }
00198 
00199         if (row != n)
00200                 throw cms::Exception("LeastSquares")
00201                         << "Missing rows in data file."
00202                         << std::endl;
00203 }
00204 
00205 static void loadVector(DOMElement *elem, unsigned int n, TVectorD &vector)
00206 {
00207         if (std::strcmp(XMLSimpleStr(elem->getNodeName()),
00208                                      "vector") != 0)
00209                 throw cms::Exception("LeastSquares")
00210                         << "Expected matrix in data file."
00211                         << std::endl;
00212 
00213         unsigned int col = 0;
00214         for(DOMNode *node = elem->getFirstChild();
00215             node; node = node->getNextSibling()) {
00216                 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00217                         continue;
00218 
00219                 if (std::strcmp(XMLSimpleStr(node->getNodeName()),
00220                                              "value") != 0)
00221                 throw cms::Exception("LeastSquares")
00222                         << "Expected value tag in data file."
00223                         << std::endl;
00224 
00225                 if (col >= n)
00226                         throw cms::Exception("LeastSquares")
00227                                 << "Too many columns in data file."
00228                                 << std::endl;
00229 
00230                 vector(col) = XMLDocument::readContent<double>(node);
00231                 col++;
00232         }
00233 
00234         if (col != n)
00235                 throw cms::Exception("LeastSquares")
00236                         << "Missing columns in data file."
00237                         << std::endl;
00238 }
00239 
00240 static DOMElement *saveMatrix(DOMDocument *doc, unsigned int n,
00241                               const TMatrixDBase &matrix)
00242 {
00243         DOMElement *root = doc->createElement(XMLUniStr("matrix"));
00244         XMLDocument::writeAttribute<unsigned int>(root, "size", n);
00245 
00246         for(unsigned int i = 0; i < n; i++) {
00247                 DOMElement *row = doc->createElement(XMLUniStr("row"));
00248                 root->appendChild(row);
00249 
00250                 for(unsigned int j = 0; j < n; j++) {
00251                         DOMElement *value =
00252                                 doc->createElement(XMLUniStr("value"));
00253                         row->appendChild(value);
00254 
00255                         XMLDocument::writeContent<double>(value, doc,
00256                                                           matrix(i, j));
00257                 }
00258         }
00259 
00260         return root;
00261 }
00262 
00263 static DOMElement *saveVector(DOMDocument *doc, unsigned int n,
00264                               const TVectorD &vector)
00265 {
00266         DOMElement *root = doc->createElement(XMLUniStr("vector"));
00267         XMLDocument::writeAttribute<unsigned int>(root, "size", n);
00268 
00269         for(unsigned int i = 0; i < n; i++) {
00270                 DOMElement *value =
00271                         doc->createElement(XMLUniStr("value"));
00272                 root->appendChild(value);
00273 
00274                 XMLDocument::writeContent<double>(value, doc, vector(i));
00275         }
00276 
00277         return root;
00278 }
00279 
00280 void LeastSquares::load(DOMElement *elem)
00281 {
00282         if (std::strcmp(XMLSimpleStr(elem->getNodeName()),
00283                                      "LinearAnalysis") != 0)
00284                 throw cms::Exception("LeastSquares")
00285                         << "Expected LinearAnalysis in data file."
00286                         << std::endl;
00287 
00288         unsigned int version = XMLDocument::readAttribute<unsigned int>(
00289                                                         elem, "version", 1);
00290 
00291         enum Position {
00292                 POS_COEFFS, POS_COVAR, POS_CORR, POS_ROTATION,
00293                 POS_SUMS, POS_WEIGHTS, POS_VARIANCE, POS_TRACE, POS_DONE
00294         } pos = POS_COEFFS;
00295 
00296         for(DOMNode *node = elem->getFirstChild();
00297             node; node = node->getNextSibling()) {
00298                 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00299                         continue;
00300 
00301                 DOMElement *elem = static_cast<DOMElement*>(node);
00302 
00303                 switch(pos) {
00304                     case POS_COEFFS:
00305                         if (version < 2) {
00306                                 loadMatrix(elem, n + 1, coeffs);
00307                                 coeffs.ResizeTo(n + 2, n + 2);
00308                                 for(unsigned int i = 0; i <= n; i++) {
00309                                         coeffs(n + 1, i) = coeffs(n, i);
00310                                         coeffs(i, n + 1) = coeffs(i, n);
00311                                 }
00312                                 coeffs(n + 1, n + 1) = coeffs(n + 1, n);
00313                         } else
00314                                 loadMatrix(elem, n + 2, coeffs);
00315                         break;
00316                     case POS_COVAR:
00317                         if (version < 2)
00318                                 loadMatrix(elem, n, covar);
00319                         else
00320                                 loadMatrix(elem, n + 1, covar);
00321                         break;
00322                     case POS_CORR:
00323                         if (version < 2)
00324                                 loadMatrix(elem, n, corr);
00325                         else
00326                                 loadMatrix(elem, n + 1, corr);
00327                         break;
00328                     case POS_ROTATION:
00329                         loadMatrix(elem, n, rotation);
00330                         break;
00331                     case POS_SUMS:
00332                         if (version < 2) {
00333                                 TVectorD tmp(n + 1);
00334                                 loadVector(elem, n + 1, tmp);
00335 
00336                                 double M = coeffs(n + 1, n);
00337                                 double N = coeffs(n + 1, n + 1);
00338 
00339                                 for(unsigned int i = 0; i <= n; i++) {
00340                                         double v = coeffs(n, i) * N -
00341                                                    M * coeffs(n + 1, i);
00342                                         double w = coeffs(n, i) * N - v;
00343 
00344                                         covar(n, i) = w;
00345                                         covar(i, n) = w;
00346                                 }
00347 
00348                                 break;
00349                         } else
00350                                 pos = (Position)(pos + 1);
00351                     case POS_WEIGHTS:
00352                         loadVector(elem, n + 1, weights);
00353                         break;
00354                     case POS_VARIANCE:
00355                         if (version < 2) {
00356                                 loadVector(elem, n, variance);
00357 
00358                                 double M = covar(n, n);
00359                                 M = M > 0.0 ? std::sqrt(M) : 0.0;
00360                                 variance[n] = M;
00361 
00362                                 for(unsigned int i = 0; i <= n; i++) {
00363                                         double v = M * variance[i];
00364                                         double w = covar(n, i);
00365                                         double c = (v >= 1.0e-9)
00366                                                         ? (w / v) : (i == n);
00367 
00368                                         corr(n, i) = c;
00369                                         corr(i, n) = c;
00370                                 }
00371                         } else
00372                                 loadVector(elem, n + 1, variance);
00373                         break;
00374                     case POS_TRACE:
00375                         loadVector(elem, n, trace);
00376                         break;
00377                     default:
00378                         throw cms::Exception("LeastSquares")
00379                                 << "Superfluous content in data file."
00380                                 << std::endl;
00381                 }
00382 
00383                 pos = (Position)(pos + 1);
00384         }
00385 
00386         if (pos != POS_DONE)
00387                 throw cms::Exception("LeastSquares")
00388                         << "Missing objects in data file."
00389                         << std::endl;
00390 }
00391 
00392 DOMElement *LeastSquares::save(DOMDocument *doc) const
00393 {
00394         DOMElement *root = doc->createElement(XMLUniStr("LinearAnalysis"));
00395         XMLDocument::writeAttribute<unsigned int>(root, "version", 2);
00396         XMLDocument::writeAttribute<unsigned int>(root, "size", n);
00397 
00398         root->appendChild(saveMatrix(doc, n + 2, coeffs));
00399         root->appendChild(saveMatrix(doc, n + 1, covar));
00400         root->appendChild(saveMatrix(doc, n + 1, corr));
00401         root->appendChild(saveMatrix(doc, n, rotation));
00402         root->appendChild(saveVector(doc, n + 1, weights));
00403         root->appendChild(saveVector(doc, n + 1, variance));
00404         root->appendChild(saveVector(doc, n, trace));
00405 
00406         return root;
00407 }
00408 
00409 } // namespace PhysicsTools