00001 #include <iostream>
00002 #include <vector>
00003 #include <memory>
00004
00005 #include <xercesc/dom/DOM.hpp>
00006
00007 #include "FWCore/Utilities/interface/Exception.h"
00008
00009 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00010
00011 #include "PhysicsTools/MVATrainer/interface/XMLDocument.h"
00012 #include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
00013 #include "PhysicsTools/MVATrainer/interface/TrainProcessor.h"
00014 #include "PhysicsTools/MVATrainer/interface/LeastSquares.h"
00015
00016 XERCES_CPP_NAMESPACE_USE
00017
00018 using namespace PhysicsTools;
00019
00020 namespace {
00021
00022 class ProcLinear : public TrainProcessor {
00023 public:
00024 typedef TrainProcessor::Registry<ProcLinear>::Type Registry;
00025
00026 ProcLinear(const char *name, const AtomicId *id,
00027 MVATrainer *trainer);
00028 virtual ~ProcLinear();
00029
00030 virtual void configure(DOMElement *elem);
00031 virtual Calibration::VarProcessor *getCalibration() const;
00032
00033 virtual void trainBegin();
00034 virtual void trainData(const std::vector<double> *values,
00035 bool target, double weight);
00036 virtual void trainEnd();
00037
00038 virtual bool load();
00039 virtual void save();
00040
00041 protected:
00042 virtual void *requestObject(const std::string &name) const;
00043
00044 private:
00045 enum Iteration {
00046 ITER_FILL,
00047 ITER_DONE
00048 } iteration;
00049
00050 std::auto_ptr<LeastSquares> ls;
00051 std::vector<double> vars;
00052 };
00053
// File-local registry object constructed with the name "ProcLinear".
// NOTE(review): presumably this is what makes the processor discoverable
// by name through TrainProcessor::Registry -- confirm against that class.
static ProcLinear::Registry registry("ProcLinear");
00055
00056 ProcLinear::ProcLinear(const char *name, const AtomicId *id,
00057 MVATrainer *trainer) :
00058 TrainProcessor(name, id, trainer),
00059 iteration(ITER_FILL)
00060 {
00061 }
00062
00063 ProcLinear::~ProcLinear()
00064 {
00065 }
00066
00067 void ProcLinear::configure(DOMElement *elem)
00068 {
00069 ls = std::auto_ptr<LeastSquares>(new LeastSquares(getInputs().size()));
00070 }
00071
00072 Calibration::VarProcessor *ProcLinear::getCalibration() const
00073 {
00074 Calibration::ProcLinear *calib = new Calibration::ProcLinear;
00075
00076 calib->coeffs = ls->getWeights();
00077 calib->offset = ls->getConstant();
00078
00079 return calib;
00080 }
00081
00082 void ProcLinear::trainBegin()
00083 {
00084 if (iteration == ITER_FILL)
00085 vars.resize(ls->getSize());
00086 }
00087
00088 void ProcLinear::trainData(const std::vector<double> *values,
00089 bool target, double weight)
00090 {
00091 if (iteration != ITER_FILL)
00092 return;
00093
00094 for(unsigned int i = 0; i < ls->getSize(); i++, values++)
00095 vars[i] = values->front();
00096
00097 ls->add(vars, target, weight);
00098 }
00099
00100 void ProcLinear::trainEnd()
00101 {
00102 switch(iteration) {
00103 case ITER_FILL:
00104 vars.clear();
00105 ls->calculate();
00106
00107 iteration = ITER_DONE;
00108 trained = true;
00109 break;
00110 default:
00111 ;
00112 }
00113 }
00114
00115 void *ProcLinear::requestObject(const std::string &name) const
00116 {
00117 if (name == "linearAnalyzer")
00118 return static_cast<void*>(ls.get());
00119
00120 return 0;
00121 }
00122
00123 bool ProcLinear::load()
00124 {
00125 std::string filename = trainer->trainFileName(this, "xml");
00126 if (!exists(filename))
00127 return false;
00128
00129 XMLDocument xml(filename);
00130 DOMElement *elem = xml.getRootNode();
00131 if (std::strcmp(XMLSimpleStr(elem->getNodeName()), "ProcLinear") != 0)
00132 throw cms::Exception("ProcLinear")
00133 << "XML training data file has bad root node."
00134 << std::endl;
00135
00136 DOMNode *node = elem->getFirstChild();
00137 while(node && node->getNodeType() != DOMNode::ELEMENT_NODE)
00138 node = node->getNextSibling();
00139
00140 if (!node)
00141 throw cms::Exception("ProcLinear")
00142 << "Train data file empty." << std::endl;
00143
00144 ls->load(static_cast<DOMElement*>(node));
00145
00146 node = elem->getNextSibling();
00147 while(node && node->getNodeType() != DOMNode::ELEMENT_NODE)
00148 node = node->getNextSibling();
00149
00150 if (node)
00151 throw cms::Exception("ProcLinear")
00152 << "Train data file contains superfluous tags."
00153 << std::endl;
00154
00155 iteration = ITER_DONE;
00156 trained = true;
00157 return true;
00158 }
00159
00160 void ProcLinear::save()
00161 {
00162 XMLDocument xml(trainer->trainFileName(this, "xml"), true);
00163 DOMDocument *doc = xml.createDocument("ProcLinear");
00164
00165 xml.getRootNode()->appendChild(ls->save(doc));
00166 }
00167
00168 }