CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
ProcLinear.cc
Go to the documentation of this file.
1 #include <iostream>
2 #include <vector>
3 #include <memory>
4 
5 #include <xercesc/dom/DOM.hpp>
6 
8 
10 
15 
16 XERCES_CPP_NAMESPACE_USE
17 
18 using namespace PhysicsTools;
19 
20 namespace { // anonymous
21 
// Training processor that fits a linear combination of its inputs with a
// LeastSquares fitter and exports the result via getCalibration().
22 class ProcLinear : public TrainProcessor {
23  public:
// NOTE(review): original source line 24 is missing from this rendering. The
// registry definition further down uses `ProcLinear::Registry`, so that line
// presumably declares the Registry typedef — confirm against the original file.
25 
26  ProcLinear(const char *name, const AtomicId *id,
27  MVATrainer *trainer);
28  virtual ~ProcLinear();
29 
30  virtual void configure(DOMElement *elem);
31  virtual Calibration::VarProcessor *getCalibration() const;
32 
33  virtual void trainBegin();
34  virtual void trainData(const std::vector<double> *values,
35  bool target, double weight);
36  virtual void trainEnd();
37 
38  virtual bool load();
39  virtual void save();
40 
41  protected:
42  virtual void *requestObject(const std::string &name) const;
43 
44  private:
// Training state: ITER_FILL while samples are collected by trainData(),
// ITER_DONE once trainEnd() has computed the fit or load() has restored it.
45  enum Iteration {
46  ITER_FILL,
47  ITER_DONE
48  } iteration;
49 
// Least-squares fitter; (re)created in configure().
50  std::auto_ptr<LeastSquares> ls;
// Scratch buffer holding one event's input values; sized in trainBegin().
51  std::vector<double> vars;
52 };
53 
54 static ProcLinear::Registry registry("ProcLinear");
55 
56 ProcLinear::ProcLinear(const char *name, const AtomicId *id,
57  MVATrainer *trainer) :
58  TrainProcessor(name, id, trainer),
59  iteration(ITER_FILL)
60 {
61 }
62 
63 ProcLinear::~ProcLinear()
64 {
65 }
66 
67 void ProcLinear::configure(DOMElement *elem)
68 {
69  ls = std::auto_ptr<LeastSquares>(new LeastSquares(getInputs().size()));
70 }
71 
// Builds the calibration object from the fitted linear weights and constant
// term held by the LeastSquares fitter.
72 Calibration::VarProcessor *ProcLinear::getCalibration() const
73 {
// NOTE(review): original source line 74 is missing from this rendering. It
// must declare and allocate `calib` (presumably a Calibration::ProcLinear,
// which would provide the `coeffs` and `offset` members assigned below) —
// confirm against the original file.
75 
// Copy the fit result (per-input weights and constant offset) into the
// calibration object.
76  calib->coeffs = ls->getWeights();
77  calib->offset = ls->getConstant();
78 
// NOTE(review): presumably ownership of the returned object passes to the
// caller — confirm against TrainProcessor's contract.
79  return calib;
80 }
81 
82 void ProcLinear::trainBegin()
83 {
84  if (iteration == ITER_FILL)
85  vars.resize(ls->getSize());
86 }
87 
88 void ProcLinear::trainData(const std::vector<double> *values,
89  bool target, double weight)
90 {
91  if (iteration != ITER_FILL)
92  return;
93 
94  for(unsigned int i = 0; i < ls->getSize(); i++, values++)
95  vars[i] = values->front();
96 
97  ls->add(vars, target, weight);
98 }
99 
100 void ProcLinear::trainEnd()
101 {
102  switch(iteration) {
103  case ITER_FILL:
104  vars.clear();
105  ls->calculate();
106 
107  iteration = ITER_DONE;
108  trained = true;
109  break;
110  default:
111  /* shut up */;
112  }
113 }
114 
115 void *ProcLinear::requestObject(const std::string &name) const
116 {
117  if (name == "linearAnalyzer")
118  return static_cast<void*>(ls.get());
119 
120  return 0;
121 }
122 
123 bool ProcLinear::load()
124 {
125  std::string filename = trainer->trainFileName(this, "xml");
126  if (!exists(filename))
127  return false;
128 
129  XMLDocument xml(filename);
130  DOMElement *elem = xml.getRootNode();
131  if (std::strcmp(XMLSimpleStr(elem->getNodeName()), "ProcLinear") != 0)
132  throw cms::Exception("ProcLinear")
133  << "XML training data file has bad root node."
134  << std::endl;
135 
136  DOMNode *node = elem->getFirstChild();
137  while(node && node->getNodeType() != DOMNode::ELEMENT_NODE)
138  node = node->getNextSibling();
139 
140  if (!node)
141  throw cms::Exception("ProcLinear")
142  << "Train data file empty." << std::endl;
143 
144  ls->load(static_cast<DOMElement*>(node));
145 
146  node = elem->getNextSibling();
147  while(node && node->getNodeType() != DOMNode::ELEMENT_NODE)
148  node = node->getNextSibling();
149 
150  if (node)
151  throw cms::Exception("ProcLinear")
152  << "Train data file contains superfluous tags."
153  << std::endl;
154 
155  iteration = ITER_DONE;
156  trained = true;
157  return true;
158 }
159 
160 void ProcLinear::save()
161 {
162  XMLDocument xml(trainer->trainFileName(this, "xml"), true);
163  DOMDocument *doc = xml.createDocument("ProcLinear");
164 
165  xml.getRootNode()->appendChild(ls->save(doc));
166 }
167 
168 } // anonymous namespace
int i
Definition: DBlmapReader.cc:9
detail::ThreadSafeRegistry< ParameterSetID, ParameterSet, ProcessParameterSetIDCache > Registry
Definition: Registry.h:37
tuple node
Definition: Node.py:50
Cheap generic unique keyword identifier class.
Definition: AtomicId.h:32
MVATrainerComputer * calib
Definition: MVATrainer.cc:64
tuple iteration
Definition: align_cfg.py:5
def load
Definition: svgfig.py:546
tuple doc
Definition: asciidump.py:381
tuple filename
Definition: lut2db_cfg.py:20
template to generate a registry singleton for a type.
static Interceptor::Registry registry("Interceptor")
tuple size
Write out results.