5 #include <xercesc/dom/DOM.hpp> 32 ~ProcMatrix()
override;
34 void configure(DOMElement *
elem)
override;
37 void trainBegin()
override;
38 void trainData(
const std::vector<double> *
values,
40 void trainEnd()
override;
54 typedef std::pair<unsigned int, double> Rank;
56 std::vector<Rank>
ranking()
const;
58 std::unique_ptr<LeastSquares> lsSignal, lsBackground;
59 std::unique_ptr<LeastSquares>
ls;
60 std::vector<double>
vars;
69 ProcMatrix::ProcMatrix(
const char *
name,
const AtomicId *
id,
77 ProcMatrix::~ProcMatrix()
81 void ProcMatrix::configure(DOMElement *
elem)
85 DOMNode *node = elem->getFirstChild();
86 while(node && node->getNodeType() != DOMNode::ELEMENT_NODE)
87 node = node->getNextSibling();
92 if (std::strcmp(
XMLSimpleStr(node->getNodeName()),
"fill") != 0)
94 <<
"Expected fill tag in config section." 97 elem =
static_cast<DOMElement*
>(node);
100 XMLDocument::readAttribute<bool>(
elem,
"signal",
false);
102 XMLDocument::readAttribute<bool>(
elem,
"background",
false);
104 XMLDocument::readAttribute<bool>(
elem,
"normalize",
false);
106 doRanking = XMLDocument::readAttribute<bool>(
elem,
"ranking",
false);
108 fillSignal = fillBackground = doNormalization =
true;
110 if (doNormalization && fillSignal && fillBackground) {
115 node = node->getNextSibling();
116 while(node && node->getNodeType() != DOMNode::ELEMENT_NODE)
117 node = node->getNextSibling();
121 <<
"Superfluous tags in config section." 124 if (!fillSignal && !fillBackground)
126 <<
"Filling neither background nor signal in config." 137 unsigned int n =
ls->getSize();
143 for(
unsigned int i = 0;
i <
n;
i++)
144 for(
unsigned int j = 0; j <
n; j++)
150 void ProcMatrix::trainBegin()
153 vars.resize(
ls->getSize());
156 void ProcMatrix::trainData(
const std::vector<double> *
values,
162 if (!(target ? fillSignal : fillBackground))
169 for(
unsigned int i = 0;
i < ls->
getSize();
i++, values++) {
173 << (
const char*)getInputs().get()[
i]->getName()
174 <<
"\" is not set in ProcMatrix trainer." 176 vars[
i] = values->front();
182 void ProcMatrix::trainEnd()
187 if (lsSignal.get()) {
188 unsigned int n = ls->
getSize();
189 double weight = lsSignal->getCoefficients()
192 ls->
add(*lsSignal, 1.0 / weight);
195 if (lsBackground.get()) {
196 unsigned int n = ls->
getSize();
197 double weight = lsBackground->getCoefficients()
200 ls->
add(*lsBackground, 1.0 / weight);
201 lsBackground.reset();
216 histo->SetNameTitle(
"CorrMatrix",
217 (fillSignal && fillBackground)
218 ?
"correlation matrix (signal + background)" 219 : (fillSignal ?
"correlation matrix (signal)" 220 :
"correlation matrix (background)"));
222 std::vector<SourceVariable*>
inputs = getInputs().get();
223 for(std::vector<SourceVariable*>::const_iterator iter =
224 inputs.begin(); iter != inputs.end(); ++iter) {
226 unsigned int idx = iter - inputs.begin();
233 histo->GetXaxis()->SetBinLabel(idx + 1, name.c_str());
234 histo->GetYaxis()->SetBinLabel(idx + 1, name.c_str());
235 histo->GetXaxis()->SetBinLabel(inputs.size() + 1,
237 histo->GetYaxis()->SetBinLabel(inputs.size() + 1,
240 histo->LabelsOption(
"d");
241 histo->SetMinimum(-1.0);
242 histo->SetMaximum(+1.0);
247 std::vector<Rank> ranks =
ranking();
248 TVectorD rankVector(ranks.size());
249 for(
unsigned int i = 0;
i < ranks.size();
i++)
251 TH1F *rank =
monitoring->book<TH1F>(
"Ranking", rankVector);
252 rank->SetNameTitle(
"Ranking",
"variable ranking");
253 rank->SetYTitle(
"correlation to target");
254 for(
unsigned int i = 0;
i < ranks.size();
i++) {
255 unsigned int v = ranks[
i].first;
261 rank->GetXaxis()->SetBinLabel(
i + 1, name.c_str());
266 void *ProcMatrix::requestObject(
const std::string &name)
const 268 if (name ==
"linearAnalyzer")
269 return static_cast<void*
>(ls.get());
277 if (!exists(filename))
281 DOMElement *elem = xml.getRootNode();
282 if (std::strcmp(
XMLSimpleStr(elem->getNodeName()),
"ProcMatrix") != 0)
284 <<
"XML training data file has bad root node." 287 DOMNode *node = elem->getFirstChild();
288 while(node && node->getNodeType() != DOMNode::ELEMENT_NODE)
289 node = node->getNextSibling();
293 <<
"Train data file empty." << std::endl;
295 ls->
load(static_cast<DOMElement*>(node));
297 node = elem->getNextSibling();
298 while(node && node->getNodeType() != DOMNode::ELEMENT_NODE)
299 node = node->getNextSibling();
303 <<
"Train data file contains superfluous tags." 313 XMLDocument xml(trainer->trainFileName(
this,
"xml"),
true);
314 DOMDocument *
doc = xml.createDocument(
"ProcMatrix");
316 xml.getRootNode()->appendChild(ls->
save(doc));
319 void maskLine(TMatrixDSym &
m,
unsigned int line)
321 unsigned int n = m.GetNrows();
322 for(
unsigned int i = 0;
i <
n;
i++)
323 m(
i, line) =
m(line,
i) = 0.;
327 void restoreLine(TMatrixDSym &m, TMatrixDSym &
o,
unsigned int line)
329 unsigned int n = m.GetNrows();
330 for(
unsigned int i = 0;
i <
n;
i++) {
331 m(
i, line) =
o(
i, line);
332 m(line,
i) =
o(line,
i);
336 double targetCorrelation(
const TMatrixDSym &coeffs,
337 const std::vector<bool> &use)
339 unsigned int n = coeffs.GetNrows() - 2;
341 TVectorD
weights = LeastSquares::solveFisher(coeffs);
342 weights.ResizeTo(n + 2);
343 weights[n + 1] = weights[
n];
348 double v3 = coeffs(n, n);
349 double N = coeffs(n + 1, n + 1);
351 for(
unsigned int i = 0;
i < n + 2;
i++) {
352 if (
i < n && !use[n])
354 double w = weights[
i];
355 for(
unsigned int j = 0; j < n + 2; j++) {
356 if (
i < n && !use[n])
358 v1 += w * weights[j] * coeffs(
i, j);
360 v2 += w * coeffs(
i, n);
361 M += w * coeffs(
i, n + 1);
364 double c1 = v1 * N - M * M;
365 double c2 = v2 * N - M * coeffs(n + 1, n);
366 double c3 = v3 * N - coeffs(n + 1, n) * coeffs(n + 1, n);
369 return (c > 1.0
e-9) ? c2 /
std::sqrt(c) : 0.0;
375 unsigned int n = coeffs.GetNrows() - 2;
377 typedef std::pair<unsigned int, double> Rank;
379 std::vector<bool> use(n,
true);
381 double corr = targetCorrelation(coeffs, use);
383 for(
unsigned int nVars = n; nVars > 1; nVars--) {
384 double bestCorr = -99999.0;
385 unsigned int bestIdx =
n;
386 TMatrixDSym origCoeffs = coeffs;
388 for(
unsigned int i = 0;
i <
n;
i++) {
394 double newCorr = targetCorrelation(coeffs, use);
396 restoreLine(coeffs, origCoeffs, i);
398 if (newCorr > bestCorr) {
404 ranking.push_back(Rank(bestIdx, corr));
406 use[bestIdx] =
false;
407 maskLine(coeffs, bestIdx);
410 for(
unsigned int i = 0; i <
n; i++)
412 ranking.push_back(Rank(i, corr));
U second(std::pair< T, U > const &p)
MVATrainerComputer * calib
def elem(elemtype, innerHTML='', html_class='', kwargs)