CMS 3D CMS Logo

ProcMatrix.cc
Go to the documentation of this file.
1 #include <cstring>
2 #include <vector>
3 #include <memory>
4 
5 #include <xercesc/dom/DOM.hpp>
6 
7 #include <TMatrixD.h>
8 #include <TMatrixF.h>
9 #include <TH2.h>
10 
12 
14 
19 
21 
22 using namespace PhysicsTools;
23 
24 namespace { // anonymous
25 
26 class ProcMatrix : public TrainProcessor {
27  public:
29 
30  ProcMatrix(const char *name, const AtomicId *id,
31  MVATrainer *trainer);
32  ~ProcMatrix() override;
33 
34  void configure(DOMElement *elem) override;
35  Calibration::VarProcessor *getCalibration() const override;
36 
37  void trainBegin() override;
38  void trainData(const std::vector<double> *values,
39  bool target, double weight) override;
40  void trainEnd() override;
41 
42  bool load() override;
43  void save() override;
44 
45  protected:
46  void *requestObject(const std::string &name) const override;
47 
48  private:
49  enum Iteration {
50  ITER_FILL,
51  ITER_DONE
52  } iteration;
53 
54  typedef std::pair<unsigned int, double> Rank;
55 
56  std::vector<Rank> ranking() const;
57 
58  std::unique_ptr<LeastSquares> lsSignal, lsBackground;
59  std::unique_ptr<LeastSquares> ls;
60  std::vector<double> vars;
61  bool fillSignal;
62  bool fillBackground;
63  bool doNormalization;
64  bool doRanking;
65 };
66 
67 ProcMatrix::Registry registry("ProcMatrix");
68 
69 ProcMatrix::ProcMatrix(const char *name, const AtomicId *id,
70  MVATrainer *trainer) :
71  TrainProcessor(name, id, trainer),
72  iteration(ITER_FILL), fillSignal(true), fillBackground(true),
73  doRanking(false)
74 {
75 }
76 
77 ProcMatrix::~ProcMatrix()
78 {
79 }
80 
81 void ProcMatrix::configure(DOMElement *elem)
82 {
83  ls.reset(new LeastSquares(getInputs().size()));
84 
85  DOMNode *node = elem->getFirstChild();
86  while(node && node->getNodeType() != DOMNode::ELEMENT_NODE)
87  node = node->getNextSibling();
88 
89  if (!node)
90  return;
91 
92  if (std::strcmp(XMLSimpleStr(node->getNodeName()), "fill") != 0)
93  throw cms::Exception("ProcMatrix")
94  << "Expected fill tag in config section."
95  << std::endl;
96 
97  elem = static_cast<DOMElement*>(node);
98 
99  fillSignal =
100  XMLDocument::readAttribute<bool>(elem, "signal", false);
101  fillBackground =
102  XMLDocument::readAttribute<bool>(elem, "background", false);
104  XMLDocument::readAttribute<bool>(elem, "normalize", false);
105 
106  doRanking = XMLDocument::readAttribute<bool>(elem, "ranking", false);
107  if (doRanking)
108  fillSignal = fillBackground = doNormalization = true;
109 
110  if (doNormalization && fillSignal && fillBackground) {
111  lsSignal.reset(new LeastSquares(getInputs().size()));
112  lsBackground.reset(new LeastSquares(getInputs().size()));
113  }
114 
115  node = node->getNextSibling();
116  while(node && node->getNodeType() != DOMNode::ELEMENT_NODE)
117  node = node->getNextSibling();
118 
119  if (node)
120  throw cms::Exception("ProcMatrix")
121  << "Superfluous tags in config section."
122  << std::endl;
123 
124  if (!fillSignal && !fillBackground)
125  throw cms::Exception("ProcMatrix")
126  << "Filling neither background nor signal in config."
127  << std::endl;
128 }
129 
130 Calibration::VarProcessor *ProcMatrix::getCalibration() const
131 {
132  if (doRanking)
133  return nullptr;
134 
136 
137  unsigned int n = ls->getSize();
138  const TMatrixD &rotation = ls->getRotation();
139 
140  calib->matrix.rows = n;
141  calib->matrix.columns = n;
142 
143  for(unsigned int i = 0; i < n; i++)
144  for(unsigned int j = 0; j < n; j++)
145  calib->matrix.elements.push_back(rotation(j, i));
146 
147  return calib;
148 }
149 
150 void ProcMatrix::trainBegin()
151 {
152  if (iteration == ITER_FILL)
153  vars.resize(ls->getSize());
154 }
155 
156 void ProcMatrix::trainData(const std::vector<double> *values,
157  bool target, double weight)
158 {
159  if (iteration != ITER_FILL)
160  return;
161 
162  if (!(target ? fillSignal : fillBackground))
163  return;
164 
165  LeastSquares *ls = target ? lsSignal.get() : lsBackground.get();
166  if (!ls)
167  ls = this->ls.get();
168 
169  for(unsigned int i = 0; i < ls->getSize(); i++, values++) {
170  if (values->empty())
171  throw cms::Exception("ProcMatrix")
172  << "Variable \""
173  << (const char*)getInputs().get()[i]->getName()
174  << "\" is not set in ProcMatrix trainer."
175  << std::endl;
176  vars[i] = values->front();
177  }
178 
179  ls->add(vars, target, weight);
180 }
181 
182 void ProcMatrix::trainEnd()
183 {
184  switch(iteration) {
185  case ITER_FILL:
186  vars.clear();
187  if (lsSignal.get()) {
188  unsigned int n = ls->getSize();
189  double weight = lsSignal->getCoefficients()
190  (n + 1, n + 1);
191  if (weight > 1.0e-9)
192  ls->add(*lsSignal, 1.0 / weight);
193  lsSignal.reset();
194  }
195  if (lsBackground.get()) {
196  unsigned int n = ls->getSize();
197  double weight = lsBackground->getCoefficients()
198  (n + 1, n + 1);
199  if (weight > 1.0e-9)
200  ls->add(*lsBackground, 1.0 / weight);
201  lsBackground.reset();
202  }
203  ls->calculate();
204 
205  iteration = ITER_DONE;
206  trained = true;
207  break;
208 
209  default:
210  /* shut up */;
211  }
212 
213  if (iteration == ITER_DONE && monitoring) {
214  TMatrixF matrix(ls->getCorrelations());
215  TH2F *histo = monitoring->book<TH2F>("CorrMatrix", matrix);
216  histo->SetNameTitle("CorrMatrix",
217  (fillSignal && fillBackground)
218  ? "correlation matrix (signal + background)"
219  : (fillSignal ? "correlation matrix (signal)"
220  : "correlation matrix (background)"));
221 
222  std::vector<SourceVariable*> inputs = getInputs().get();
223  for(std::vector<SourceVariable*>::const_iterator iter =
224  inputs.begin(); iter != inputs.end(); ++iter) {
225 
226  unsigned int idx = iter - inputs.begin();
227  SourceVariable *var = *iter;
228  std::string name =
229  (const char*)var->getSource()->getName()
230  + std::string("_")
231  + (const char*)var->getName();
232 
233  histo->GetXaxis()->SetBinLabel(idx + 1, name.c_str());
234  histo->GetYaxis()->SetBinLabel(idx + 1, name.c_str());
235  histo->GetXaxis()->SetBinLabel(inputs.size() + 1,
236  "target");
237  histo->GetYaxis()->SetBinLabel(inputs.size() + 1,
238  "target");
239  }
240  histo->LabelsOption("d");
241  histo->SetMinimum(-1.0);
242  histo->SetMaximum(+1.0);
243 
244  if (!doRanking)
245  return;
246 
247  std::vector<Rank> ranks = ranking();
248  TVectorD rankVector(ranks.size());
249  for(unsigned int i = 0; i < ranks.size(); i++)
250  rankVector[i] = ranks[i].second;
251  TH1F *rank = monitoring->book<TH1F>("Ranking", rankVector);
252  rank->SetNameTitle("Ranking", "variable ranking");
253  rank->SetYTitle("correlation to target");
254  for(unsigned int i = 0; i < ranks.size(); i++) {
255  unsigned int v = ranks[i].first;
257  SourceVariable *var = inputs[v];
258  name = (const char*)var->getSource()->getName()
259  + std::string("_")
260  + (const char*)var->getName();
261  rank->GetXaxis()->SetBinLabel(i + 1, name.c_str());
262  }
263  }
264 }
265 
266 void *ProcMatrix::requestObject(const std::string &name) const
267 {
268  if (name == "linearAnalyzer")
269  return static_cast<void*>(ls.get());
270 
271  return nullptr;
272 }
273 
274 bool ProcMatrix::load()
275 {
276  std::string filename = trainer->trainFileName(this, "xml");
277  if (!exists(filename))
278  return false;
279 
280  XMLDocument xml(filename);
281  DOMElement *elem = xml.getRootNode();
282  if (std::strcmp(XMLSimpleStr(elem->getNodeName()), "ProcMatrix") != 0)
283  throw cms::Exception("ProcMatrix")
284  << "XML training data file has bad root node."
285  << std::endl;
286 
287  DOMNode *node = elem->getFirstChild();
288  while(node && node->getNodeType() != DOMNode::ELEMENT_NODE)
289  node = node->getNextSibling();
290 
291  if (!node)
292  throw cms::Exception("ProcMatrix")
293  << "Train data file empty." << std::endl;
294 
295  ls->load(static_cast<DOMElement*>(node));
296 
297  node = elem->getNextSibling();
298  while(node && node->getNodeType() != DOMNode::ELEMENT_NODE)
299  node = node->getNextSibling();
300 
301  if (node)
302  throw cms::Exception("ProcMatrix")
303  << "Train data file contains superfluous tags."
304  << std::endl;
305 
306  iteration = ITER_DONE;
307  trained = true;
308  return true;
309 }
310 
311 void ProcMatrix::save()
312 {
313  XMLDocument xml(trainer->trainFileName(this, "xml"), true);
314  DOMDocument *doc = xml.createDocument("ProcMatrix");
315 
316  xml.getRootNode()->appendChild(ls->save(doc));
317 }
318 
319 void maskLine(TMatrixDSym &m, unsigned int line)
320 {
321  unsigned int n = m.GetNrows();
322  for(unsigned int i = 0; i < n; i++)
323  m(i, line) = m(line, i) = 0.;
324  m(line, line) = 1.;
325 }
326 
327 void restoreLine(TMatrixDSym &m, TMatrixDSym &o, unsigned int line)
328 {
329  unsigned int n = m.GetNrows();
330  for(unsigned int i = 0; i < n; i++) {
331  m(i, line) = o(i, line);
332  m(line, i) = o(line, i);
333  }
334 }
335 
336 double targetCorrelation(const TMatrixDSym &coeffs,
337  const std::vector<bool> &use)
338 {
339  unsigned int n = coeffs.GetNrows() - 2;
340 
341  TVectorD weights = LeastSquares::solveFisher(coeffs);
342  weights.ResizeTo(n + 2);
343  weights[n + 1] = weights[n];
344  weights[n] = 0.;
345 
346  double v1 = 0.;
347  double v2 = 0.;
348  double v3 = coeffs(n, n);
349  double N = coeffs(n + 1, n + 1);
350  double M = 0.;
351  for(unsigned int i = 0; i < n + 2; i++) {
352  if (i < n && !use[n])
353  continue;
354  double w = weights[i];
355  for(unsigned int j = 0; j < n + 2; j++) {
356  if (i < n && !use[n])
357  continue;
358  v1 += w * weights[j] * coeffs(i, j);
359  }
360  v2 += w * coeffs(i, n);
361  M += w * coeffs(i, n + 1);
362  }
363 
364  double c1 = v1 * N - M * M;
365  double c2 = v2 * N - M * coeffs(n + 1, n);
366  double c3 = v3 * N - coeffs(n + 1, n) * coeffs(n + 1, n);
367 
368  double c = c1 * c3;
369  return (c > 1.0e-9) ? c2 / std::sqrt(c) : 0.0;
370 }
371 
372 std::vector<ProcMatrix::Rank> ProcMatrix::ranking() const
373 {
374  TMatrixDSym coeffs = ls->getCoefficients();
375  unsigned int n = coeffs.GetNrows() - 2;
376 
377  typedef std::pair<unsigned int, double> Rank;
378  std::vector<Rank> ranking;
379  std::vector<bool> use(n, true);
380 
381  double corr = targetCorrelation(coeffs, use);
382 
383  for(unsigned int nVars = n; nVars > 1; nVars--) {
384  double bestCorr = -99999.0;
385  unsigned int bestIdx = n;
386  TMatrixDSym origCoeffs = coeffs;
387 
388  for(unsigned int i = 0; i < n; i++) {
389  if (!use[i])
390  continue;
391 
392  use[i] = false;
393  maskLine(coeffs, i);
394  double newCorr = targetCorrelation(coeffs, use);
395  use[i] = true;
396  restoreLine(coeffs, origCoeffs, i);
397 
398  if (newCorr > bestCorr) {
399  bestCorr = newCorr;
400  bestIdx = i;
401  }
402  }
403 
404  ranking.push_back(Rank(bestIdx, corr));
405  corr = bestCorr;
406  use[bestIdx] = false;
407  maskLine(coeffs, bestIdx);
408  }
409 
410  for(unsigned int i = 0; i < n; i++)
411  if (use[i])
412  ranking.push_back(Rank(i, corr));
413 
414  return ranking;
415 }
416 
417 } // anonymous namespace
size
Write out results.
Source * getSource() const
const double w
Definition: UKUtility.cc:23
XERCES_CPP_NAMESPACE_QUALIFIER DOMElement * save(XERCES_CPP_NAMESPACE_QUALIFIER DOMDocument *doc) const
void add(const std::vector< double > &values, double dest, double weight=1.0)
Definition: LeastSquares.cc:32
unsigned int getSize() const
Definition: LeastSquares.h:29
Definition: weight.py:1
template to generate a registry singleton for a type.
void load(XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *elem)
std::vector< double > elements
Definition: MVAComputer.h:43
Cheap generic unique keyword identifier class.
Definition: AtomicId.h:31
const TMatrixDSym & getCoefficients() const
Definition: LeastSquares.h:30
U second(std::pair< T, U > const &p)
MVATrainerComputer * calib
Definition: MVATrainer.cc:64
const AtomicId getName() const
Definition: Variable.h:143
T sqrt(T t)
Definition: SSEVec.h:18
JetCorrectorParameters corr
Definition: classes.h:5
def elem(elemtype, innerHTML='', html_class='', kwargs)
Definition: HTMLExport.py:18
def ls(path, rec=False)
Definition: eostools.py:348
#define N
Definition: blowfish.cc:9
def load(fileName)
Definition: svgfig.py:546
AtomicId getName() const
Definition: Source.h:19
vars
Definition: DeepTauId.cc:77
static Interceptor::Registry registry("Interceptor")
const TMatrixDSym & getCorrelations() const
Definition: LeastSquares.h:32
save
Definition: cuy.py:1164