CMS 3D CMS Logo

ProcLinear.cc

Go to the documentation of this file.
#include <cstring>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include <xercesc/dom/DOM.hpp>

#include "FWCore/Utilities/interface/Exception.h"

#include "PhysicsTools/MVAComputer/interface/AtomicId.h"

#include "PhysicsTools/MVATrainer/interface/XMLDocument.h"
#include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
#include "PhysicsTools/MVATrainer/interface/TrainProcessor.h"
#include "PhysicsTools/MVATrainer/interface/LeastSquares.h"
00015 
00016 XERCES_CPP_NAMESPACE_USE
00017 
00018 using namespace PhysicsTools;
00019 
00020 namespace { // anonymous
00021 
00022 class ProcLinear : public TrainProcessor {
00023     public:
00024         typedef TrainProcessor::Registry<ProcLinear>::Type Registry;
00025 
00026         ProcLinear(const char *name, const AtomicId *id,
00027                    MVATrainer *trainer);
00028         virtual ~ProcLinear();
00029 
00030         virtual void configure(DOMElement *elem);
00031         virtual Calibration::VarProcessor *getCalibration() const;
00032 
00033         virtual void trainBegin();
00034         virtual void trainData(const std::vector<double> *values,
00035                                bool target, double weight);
00036         virtual void trainEnd();
00037 
00038         virtual bool load();
00039         virtual void save();
00040 
00041     protected:
00042         virtual void *requestObject(const std::string &name) const;
00043 
00044     private:
00045         enum Iteration {
00046                 ITER_FILL,
00047                 ITER_DONE
00048         } iteration;
00049 
00050         std::auto_ptr<LeastSquares>     ls;
00051         std::vector<double>             vars;
00052 };
00053 
00054 static ProcLinear::Registry registry("ProcLinear");
00055 
00056 ProcLinear::ProcLinear(const char *name, const AtomicId *id,
00057                              MVATrainer *trainer) :
00058         TrainProcessor(name, id, trainer),
00059         iteration(ITER_FILL)
00060 {
00061 }
00062 
00063 ProcLinear::~ProcLinear()
00064 {
00065 }
00066 
00067 void ProcLinear::configure(DOMElement *elem)
00068 {
00069         ls = std::auto_ptr<LeastSquares>(new LeastSquares(getInputs().size()));
00070 }
00071 
00072 Calibration::VarProcessor *ProcLinear::getCalibration() const
00073 {
00074         Calibration::ProcLinear *calib = new Calibration::ProcLinear;
00075 
00076         calib->coeffs = ls->getWeights();
00077         calib->offset = ls->getConstant();
00078 
00079         return calib;
00080 }
00081 
00082 void ProcLinear::trainBegin()
00083 {
00084         if (iteration == ITER_FILL)
00085                 vars.resize(ls->getSize());
00086 }
00087 
00088 void ProcLinear::trainData(const std::vector<double> *values,
00089                            bool target, double weight)
00090 {
00091         if (iteration != ITER_FILL)
00092                 return;
00093 
00094         for(unsigned int i = 0; i < ls->getSize(); i++, values++)
00095                 vars[i] = values->front();
00096 
00097         ls->add(vars, target, weight);
00098 }
00099 
00100 void ProcLinear::trainEnd()
00101 {
00102         switch(iteration) {
00103             case ITER_FILL:
00104                 vars.clear();
00105                 ls->calculate();
00106 
00107                 iteration = ITER_DONE;
00108                 trained = true;
00109                 break;
00110             default:
00111                 /* shut up */;
00112         }
00113 }
00114 
00115 void *ProcLinear::requestObject(const std::string &name) const
00116 {
00117         if (name == "linearAnalyzer")
00118                 return static_cast<void*>(ls.get());
00119 
00120         return 0;
00121 }
00122 
00123 bool ProcLinear::load()
00124 {
00125         std::string filename = trainer->trainFileName(this, "xml");
00126         if (!exists(filename))
00127                 return false;
00128 
00129         XMLDocument xml(filename);
00130         DOMElement *elem = xml.getRootNode();
00131         if (std::strcmp(XMLSimpleStr(elem->getNodeName()), "ProcLinear") != 0)
00132                 throw cms::Exception("ProcLinear")
00133                         << "XML training data file has bad root node."
00134                         << std::endl;
00135 
00136         DOMNode *node = elem->getFirstChild();
00137         while(node && node->getNodeType() != DOMNode::ELEMENT_NODE)
00138                 node = node->getNextSibling();
00139 
00140         if (!node)
00141                 throw cms::Exception("ProcLinear")
00142                         << "Train data file empty." << std::endl;
00143 
00144         ls->load(static_cast<DOMElement*>(node));
00145 
00146         node = elem->getNextSibling();
00147         while(node && node->getNodeType() != DOMNode::ELEMENT_NODE)
00148                 node = node->getNextSibling();
00149 
00150         if (node)
00151                 throw cms::Exception("ProcLinear")
00152                         << "Train data file contains superfluous tags."
00153                         << std::endl;
00154 
00155         iteration = ITER_DONE;
00156         trained = true;
00157         return true;
00158 }
00159 
00160 void ProcLinear::save()
00161 {
00162         XMLDocument xml(trainer->trainFileName(this, "xml"), true);
00163         DOMDocument *doc = xml.createDocument("ProcLinear");
00164 
00165         xml.getRootNode()->appendChild(ls->save(doc));
00166 }
00167 
00168 } // anonymous namespace

Generated on Tue Jun 9 17:41:28 2009 for CMSSW by  doxygen 1.5.4