CMS 3D CMS Logo

ProcLinear.cc
Go to the documentation of this file.
1 #include <iostream>
2 #include <vector>
3 #include <memory>
4 
5 #include <xercesc/dom/DOM.hpp>
6 
8 
10 
15 
17 
18 using namespace PhysicsTools;
19 
20 namespace { // anonymous
21 
// ProcLinear: a TrainProcessor that produces a linear-combination
// calibration (see getCalibration) from coefficients and an offset read from
// its XML config section. During training it also feeds every event into a
// LeastSquares helper, which can be saved to / loaded from an XML train file.
// NOTE(review): original source line 24 is missing from this listing; it is
// presumably the `Registry` typedef that `ProcLinear::Registry` at namespace
// scope relies on — confirm against the real file.
22 class ProcLinear : public TrainProcessor {
23  public:
25 
26  ProcLinear(const char *name, const AtomicId *id,
27  MVATrainer *trainer);
28  ~ProcLinear() override;
29 
30  void configure(DOMElement *elem) override;
31  Calibration::VarProcessor *getCalibration() const override;
32 
33  void trainBegin() override;
34  void trainData(const std::vector<double> *values,
35  bool target, double weight) override;
36  void trainEnd() override;
37 
38  bool load() override;
39  void save() override;
40 
41  protected:
// Returns the LeastSquares helper for the name "linearAnalyzer", else nullptr.
42  void *requestObject(const std::string &name) const override;
43 
44  private:
// Training state: ITER_FILL while events are being collected, ITER_DONE once
// trainEnd() has run or a saved result has been load()ed.
45  enum Iteration {
46  ITER_FILL,
47  ITER_DONE
48  } iteration;
49 
// Least-squares accumulator, created in configure() with one slot per input.
50  std::unique_ptr<LeastSquares> ls;
// Per-event scratch buffer holding the first value of each input variable.
51  std::vector<double> vars;
// Coefficients and offset from the config <coefficients> tag ("coeff1",
// "coeff2", "offset" attributes); copied into the calibration object.
52  std::vector<double> coefficients;
53  double theoffset;
54 };
55 
// Registers this processor type under the name "ProcLinear" (presumably so
// the trainer can instantiate it by name from its configuration — confirm
// against TrainProcessor's registry template).
56 ProcLinear::Registry registry("ProcLinear");
57 
58 ProcLinear::ProcLinear(const char *name, const AtomicId *id,
59  MVATrainer *trainer) :
60  TrainProcessor(name, id, trainer),
61  iteration(ITER_FILL)
62 {
63 }
64 
65 ProcLinear::~ProcLinear()
66 {
67 }
68 
69 void ProcLinear::configure(DOMElement *elem)
70 {
71  ls = std::unique_ptr<LeastSquares>(new LeastSquares(getInputs().size()));
72 
73  DOMNode *node = elem->getFirstChild();
74  while(node && node->getNodeType() != DOMNode::ELEMENT_NODE)
75  node = node->getNextSibling();
76 
77  if (!node)
78  return;
79 
80  if (std::strcmp(XMLSimpleStr(node->getNodeName()), "coefficients") != 0)
81  throw cms::Exception("ProcLinear")
82  << "Expected coefficients tag in config section."
83  << std::endl;
84 
85  elem = static_cast<DOMElement*>(node);
86 
87  //if (XMLDocument::hasAttribute(elem, "offset"))
88  theoffset= XMLDocument::readAttribute<double>(elem, "offset", 0.0);
89  if (XMLDocument::hasAttribute(elem, "coeff1"))
90  coefficients.push_back(XMLDocument::readAttribute<double>(elem, "coeff1", 1.0));
91  if (XMLDocument::hasAttribute(elem, "coeff2"))
92  coefficients.push_back(XMLDocument::readAttribute<double>(elem, "coeff2", 1.0));
93 
94 }
95 
// Build the calibration object for this processor: copies the configured
// coefficients and offset into `calib` and returns it.
// NOTE(review): the declaration of `calib` (original source line 98) is
// missing from this listing; it is presumably a freshly allocated
// Calibration::ProcLinear whose ownership passes to the caller — confirm.
96 Calibration::VarProcessor *ProcLinear::getCalibration() const
97 {
99  /*std::vector<double> a;
100  a.push_back(0.75);
101  a.push_back(0.25);
102  calib->coeffs = a;
103  calib->offset = 0.0;
104  */
// The values come from the config <coefficients> tag read in configure(),
// not from the least-squares fit accumulated during training.
105  calib->coeffs = coefficients;
106  calib->offset = theoffset;
107 
// Previous behaviour (disabled): use the fitted weights and constant instead.
108 // calib->coeffs = ls->getWeights();
109 // calib->offset = ls->getConstant();
110  return calib;
111 }
112 
113 void ProcLinear::trainBegin()
114 {
115  if (iteration == ITER_FILL)
116  vars.resize(ls->getSize());
117 }
118 
119 void ProcLinear::trainData(const std::vector<double> *values,
120  bool target, double weight)
121 {
122  if (iteration != ITER_FILL)
123  return;
124 
125  for(unsigned int i = 0; i < ls->getSize(); i++, values++)
126  vars[i] = values->front();
127 
128  ls->add(vars, target, weight);
129 }
130 
131 void ProcLinear::trainEnd()
132 {
133  switch(iteration) {
134  case ITER_FILL:
135  vars.clear();
136  ls->calculate();
137 
138  iteration = ITER_DONE;
139  trained = true;
140  break;
141  default:
142  /* shut up */;
143  }
144 }
145 
146 void *ProcLinear::requestObject(const std::string &name) const
147 {
148  if (name == "linearAnalyzer")
149  return static_cast<void*>(ls.get());
150 
151  return nullptr;
152 }
153 
154 bool ProcLinear::load()
155 {
156  std::string filename = trainer->trainFileName(this, "xml");
157  if (!exists(filename))
158  return false;
159 
160  XMLDocument xml(filename);
161  DOMElement *elem = xml.getRootNode();
162  if (std::strcmp(XMLSimpleStr(elem->getNodeName()), "ProcLinear") != 0)
163  throw cms::Exception("ProcLinear")
164  << "XML training data file has bad root node."
165  << std::endl;
166 
167  DOMNode *node = elem->getFirstChild();
168  while(node && node->getNodeType() != DOMNode::ELEMENT_NODE)
169  node = node->getNextSibling();
170 
171  if (!node)
172  throw cms::Exception("ProcLinear")
173  << "Train data file empty." << std::endl;
174 
175  ls->load(static_cast<DOMElement*>(node));
176 
177  node = elem->getNextSibling();
178  while(node && node->getNodeType() != DOMNode::ELEMENT_NODE)
179  node = node->getNextSibling();
180 
181  if (node)
182  throw cms::Exception("ProcLinear")
183  << "Train data file contains superfluous tags."
184  << std::endl;
185 
186  iteration = ITER_DONE;
187  trained = true;
188  return true;
189 }
190 
191 void ProcLinear::save()
192 {
193  XMLDocument xml(trainer->trainFileName(this, "xml"), true);
194  DOMDocument *doc = xml.createDocument("ProcLinear");
195 
196  xml.getRootNode()->appendChild(ls->save(doc));
197 }
198 
199 } // anonymous namespace
size
Write out results.
Definition: weight.py:1
template to generate a registry singleton for a type.
Cheap generic unique keyword identifier class.
Definition: AtomicId.h:31
MVATrainerComputer * calib
Definition: MVATrainer.cc:64
static bool hasAttribute(XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *elem, const char *name)
Definition: XMLDocument.cc:305
def elem(elemtype, innerHTML='', html_class='', kwargs)
Definition: HTMLExport.py:18
def ls(path, rec=False)
Definition: eostools.py:348
def load(fileName)
Definition: svgfig.py:546
vars
Definition: DeepTauId.cc:77
static Interceptor::Registry registry("Interceptor")
save
Definition: cuy.py:1164