CMS 3D CMS Logo

MVATrainer.cc
Go to the documentation of this file.
1 #include <cassert>
2 #include <functional>
3 #include <ext/functional>
4 #include <algorithm>
5 #include <iostream>
6 #include <cstdarg>
7 #include <cstring>
8 #include <cstdio>
9 #include <string>
10 #include <vector>
11 #include <memory>
12 #include <map>
13 #include <set>
14 
15 #include <xercesc/dom/DOM.hpp>
16 
17 #include <TRandom.h>
18 
22 
27 
37 
39 
40 namespace PhysicsTools {
41 
42 namespace { // anonymous
43  class MVATrainerComputer;
44 
45  class BaseInterceptor : public Calibration::Interceptor {
46  public:
47  BaseInterceptor() : calib(nullptr) {}
48  ~BaseInterceptor() override {}
49 
50  inline void setCalibration(MVATrainerComputer *calib)
51  { this->calib = calib; }
52 
53  std::vector<Variable::Flags>
54  configure(const MVAComputer *computer, unsigned int n,
55  const std::vector<Variable::Flags> &flags) override = 0;
56 
57  double
58  intercept(const std::vector<double> *values) const override = 0;
59 
60  virtual void init() {}
61  virtual void finish(bool save) {}
62 
63  protected:
64  MVATrainerComputer *calib;
65  };
66 
67  class InitInterceptor : public BaseInterceptor {
68  public:
69  InitInterceptor() {}
70  ~InitInterceptor() override {}
71 
72  std::vector<Variable::Flags>
73  configure(const MVAComputer *computer, unsigned int n,
74  const std::vector<Variable::Flags> &flags) override;
75 
76  double
77  intercept(const std::vector<double> *values) const override;
78  };
79 
80  class TrainInterceptor : public BaseInterceptor {
81  public:
82  TrainInterceptor(TrainProcessor *proc) : proc(proc) {}
83  ~TrainInterceptor() override {}
84 
85  inline TrainProcessor *getProcessor() const { return proc; }
86 
87  std::vector<Variable::Flags>
88  configure(const MVAComputer *computer, unsigned int n,
89  const std::vector<Variable::Flags> &flags) override;
90 
91  double
92  intercept(const std::vector<double> *values) const override;
93 
94  void init() override;
95  void finish(bool save) override;
96 
97  private:
98  unsigned int targetIdx;
99  unsigned int weightIdx;
100  mutable std::vector<std::vector<double> > tmp;
101  TrainProcessor *const proc;
102  };
103 
// Training-time calibration.  It owns a list of interceptors (pairs of
// "position in the processor chain, interceptor"; see getProcessors()) and
// coordinates them: configured() starts an iteration once every interceptor
// has been configured, done() finishes it, and next() draws the per-event
// train/test assignment.
// NOTE(review): the destructor deletes the owned interceptors, but nothing in
// this listing disables copying; a copy would lead to a double delete.
104  class MVATrainerComputer : public TrainMVAComputerCalibration {
105  public:
106  typedef std::pair<unsigned int, BaseInterceptor*> Interceptor;
107 
108  MVATrainerComputer(const std::vector<Interceptor>
109  &interceptors,
110  bool autoSave, UInt_t seed, double split);
111 
112  ~MVATrainerComputer() override;
113 
114  std::vector<Calibration::VarProcessor*>
115  getProcessors() const override;
116  void initFlags(std::vector<Variable::Flags>
117  &flags) const override;
118 
119  void configured(BaseInterceptor *interceptor) const;
120  void next();
121  void done();
122 
123  inline void addFlag(Variable::Flags flag)
124  { flags.push_back(flag); }
125 
// With split <= 0 every event is used for testing as well as for training.
126  inline bool useForTraining() const { return splitResult; }
127  inline bool useForTesting() const
128  { return split <= 0.0 || !splitResult; }
129 
130  inline bool isConfigured() const
131  { return nConfigured == interceptors.size(); }
132 
133  private:
134  std::vector<Interceptor> interceptors;
135  std::vector<Variable::Flags> flags;
136  mutable unsigned int nConfigured;
// NOTE(review): the declarations of doAutoSave (initialised by the
// constructor) and splitResult (set by next(), read by useForTraining() and
// useForTesting()) are not visible in this listing.
138  TRandom random;
139  double split;
141  };
142 
143  // useful little helpers
144 
// Function object that deletes the pointer it is given.  It is used with
// std::for_each to release containers of owned raw pointers.
// It no longer derives from std::unary_function, which was deprecated in
// C++11 and removed in C++17; the adaptable-functor typedefs are kept for
// any code that still relies on them.
template<typename T>
struct deleter {
	typedef T *argument_type;
	typedef void result_type;

	inline void operator() (T *ptr) const { delete ptr; }
};
149 
// Scope guard that owns a list of heap objects and deletes them when it goes
// out of scope.  Because it stores owning raw pointers, copying it would
// delete every object twice, so copying is disabled.
template<typename T>
struct auto_cleaner {
	auto_cleaner() = default;
	auto_cleaner(const auto_cleaner &) = delete;
	auto_cleaner &operator = (const auto_cleaner &) = delete;

	inline ~auto_cleaner()
	{
		for(T *ptr : clean)
			delete ptr;
	}

	// Takes ownership of ptr.  If it cannot be recorded (push_back
	// throws), the object is deleted here instead of leaking.
	inline void add(T *ptr)
	{
		try {
			clean.push_back(ptr);
		} catch(...) {
			delete ptr;
			throw;
		}
	}

	std::vector<T*> clean;
};
158 } // anonymous namespace
159 
// Formats `format` with the arguments in `va` (printf conventions) and
// returns the result.  `va` is only copied, never consumed: the caller still
// owns it and is responsible for calling va_end() on it.
static std::string stdStringVPrintf(const char *format, std::va_list va)
{
	// Start large enough for typical results so the first attempt usually
	// succeeds; the loop grows the buffer whenever the output is truncated.
	std::vector<char> buffer(std::max<std::size_t>(128, std::strlen(format) + 1));
	for(;;) {
		// A va_list can be traversed only once: after vsnprintf() it is
		// indeterminate.  Every attempt therefore formats from a fresh copy.
		std::va_list args;
		va_copy(args, va);
		int n = std::vsnprintf(&buffer.front(), buffer.size(), format, args);
		va_end(args);

		if (n >= 0 && static_cast<std::size_t>(n) < buffer.size())
			break;

		if (n >= 0)
			buffer.resize(static_cast<std::size_t>(n) + 1);	// C99: n = needed length
		else
			buffer.resize(buffer.size() * 2);	// pre-C99 libcs return -1 when truncating
	}

	return std::string(&buffer.front());
}

// printf-style formatting into a std::string.
static std::string stdStringPrintf(const char *format, ...)
{
	std::va_list va;
	va_start(va, format);
	std::string result = stdStringVPrintf(format, va);
	va_end(va);
	return result;
}
191 
192 // implementation for InitInterceptor
193 
194 std::vector<Variable::Flags>
195 InitInterceptor::configure(const MVAComputer *computer, unsigned int n,
196  const std::vector<Variable::Flags> &flags)
197 {
198  calib->configured(this);
199  return std::vector<Variable::Flags>(n, Variable::FLAG_ALL);
200 }
201 
202 double
203 InitInterceptor::intercept(const std::vector<double> *values) const
204 {
205  calib->next();
206  return 0.0;
207 }
208 
209 // implementation for TrainInterceptor
210 
// Locates the target and weight variables among the processor's inputs,
// passes the flags of the remaining inputs to the processor, and requests
// the processor's default flags for all n inputs (the target gets FLAG_NONE).
211 std::vector<Variable::Flags>
212 TrainInterceptor::configure(const MVAComputer *computer, unsigned int n,
213  const std::vector<Variable::Flags> &flags)
214 {
215  const SourceVariableSet &inputSet =
216  const_cast<const TrainProcessor*>(proc)->getInputs();
// NOTE(review): source lines 217-218, presumably the lookups that define
// `target` and `weight` used below, are missing from this listing.
219 
220  std::vector<SourceVariable*> inputs = inputSet.get(true);
221 
222  std::vector<SourceVariable*>::const_iterator pos;
223  pos = std::find(inputs.begin(), inputs.end(), target);
224  assert(pos != inputs.end());
225  targetIdx = pos - inputs.begin();
226  pos = std::find(inputs.begin(), inputs.end(), weight);
227  assert(pos != inputs.end());
228  weightIdx = pos - inputs.begin();
229 
230  calib->configured(this);
231 
// Erase the higher index first so the lower index stays valid.
232  std::vector<Variable::Flags> result = flags;
233  if (targetIdx < weightIdx) {
234  result.erase(result.begin() + weightIdx);
235  result.erase(result.begin() + targetIdx);
236  } else {
237  result.erase(result.begin() + targetIdx);
238  result.erase(result.begin() + weightIdx);
239  }
240 
241  proc->passFlags(result);
242 
243  result.clear();
244  result.resize(n, proc->getDefaultFlags());
245  result[targetIdx] = Variable::FLAG_NONE;
// NOTE(review): source line 246 is missing from this listing.
247 
// If target and weight are not the first two inputs, intercept() cannot pass
// `values + 2` directly and compacts the remaining inputs into tmp instead.
248  if (targetIdx >= 2 || weightIdx >= 2)
249  tmp.resize(n - 2);
250 
251  return result;
252 }
253 
255 {
// Start of a training iteration: log it and let the processor prepare.
256  edm::LogInfo("MVATrainer")
257  << "TrainProcessor \"" << (const char*)proc->getName()
258  << "\" training iteration starting...";
259 
260  proc->doTrainBegin();
261 }
262 
263 double
264 TrainInterceptor::intercept(const std::vector<double> *values) const
265 {
266  if (values[targetIdx].size() != 1) {
267  if (values[targetIdx].empty())
268  throw cms::Exception("MVATrainer")
269  << "Trainer input lacks target variable."
270  << std::endl;
271  else
272  throw cms::Exception("MVATrainer")
273  << "Multiple targets supplied in input."
274  << std::endl;
275  }
276  double target = values[targetIdx].front();
277 
278  double weight = 1.0;
279  if (values[weightIdx].size() > 1)
280  throw cms::Exception("MVATrainer")
281  << "Multiple weights supplied in input."
282  << std::endl;
283  else if (values[weightIdx].size() == 1)
284  weight = values[weightIdx].front();
285 
286  if (tmp.empty())
287  proc->doTrainData(values + 2, target > 0.5, weight,
288  calib->useForTraining(),
289  calib->useForTesting());
290  else {
291  std::vector<std::vector<double> >::iterator pos = tmp.begin();
292  for(unsigned int i = 0; pos != tmp.end(); i++)
293  if (i != targetIdx && i != weightIdx)
294  *pos++ = values[i];
295 
296  proc->doTrainData(&tmp.front(), target > 0.5, weight,
297  calib->useForTraining(),
298  calib->useForTesting());
299  }
300 
301  return target;
302 }
303 
304 void TrainInterceptor::finish(bool save)
305 {
306  proc->doTrainEnd();
307 
308  edm::LogInfo("MVATrainer")
309  << "... processor \"" << (const char*)proc->getName()
310  << "\" training iteration done.";
311 
312  if (proc->isTrained()) {
313  edm::LogInfo("MVATrainer")
314  << "* Completed training of \""
315  << (const char*)proc->getName() << "\".";
316 
317  if (save)
318  proc->save();
319  }
320 }
321 
322 // implementation for MVATrainerComputer
323 
324 MVATrainerComputer::MVATrainerComputer(const std::vector<Interceptor>
325  &interceptors, bool autoSave,
326  UInt_t seed, double split) :
327  interceptors(interceptors), nConfigured(0), doAutoSave(autoSave),
328  random(seed), split(split)
329 {
330  for(std::vector<Interceptor>::const_iterator iter =
331  interceptors.begin(); iter != interceptors.end(); ++iter)
332  iter->second->setCalibration(this);
333 }
334 
335 MVATrainerComputer::~MVATrainerComputer()
336 {
337  done();
338 
339  for(std::vector<Interceptor>::const_iterator iter =
340  interceptors.begin(); iter != interceptors.end(); ++iter)
341  delete iter->second;
342 }
343 
// Returns the calibration's processors with every interceptor inserted at its
// recorded position.  Insertions happen in list order, so each position is
// relative to the vector that already contains earlier interceptors.
344 std::vector<Calibration::VarProcessor*>
345 MVATrainerComputer::getProcessors() const
346 {
347  std::vector<Calibration::VarProcessor*> processors =
// NOTE(review): source line 348 (the initialiser, presumably the base-class
// MVAComputer::getProcessors() call) is missing from this listing.
349 
// The blank line 352 does not end the loop: the insert() call is its body.
350  for(std::vector<Interceptor>::const_iterator iter =
351  interceptors.begin(); iter != interceptors.end(); ++iter)
352 
353  processors.insert(processors.begin() + iter->first,
354  1, iter->second);
355 
356  return processors;
357 }
358 
359 void MVATrainerComputer::initFlags(std::vector<Variable::Flags> &flags) const
360 {
361  assert(flags.size() == this->flags.size());
362  flags = this->flags;
363 }
364 
365 void MVATrainerComputer::configured(BaseInterceptor *interceptor) const
366 {
367  nConfigured++;
368  if (isConfigured())
369  for(std::vector<Interceptor>::const_iterator iter =
370  interceptors.begin();
371  iter != interceptors.end(); ++iter)
372  iter->second->init();
373 }
374 
376 {
// Per-event draw: the event is used for training when a uniform random number
// (TRandom::Uniform(1.0)) is at least `split`; see useForTraining() and
// useForTesting().
377  splitResult = random.Uniform(1.0) >= split;
378 }
379 
380 void MVATrainerComputer::done()
381 {
382  if (isConfigured()) {
383  for(std::vector<Interceptor>::const_iterator iter =
384  interceptors.begin();
385  iter != interceptors.end(); ++iter)
386  iter->second->finish(doAutoSave);
387  nConfigured = 0;
388  }
389 }
390 
391 // implementation for MVATrainer
392 
393 const AtomicId MVATrainer::kTargetId("__TARGET__");
394 const AtomicId MVATrainer::kWeightId("__WEIGHT__");
395 
396 static const AtomicId kOutputId("__OUTPUT__");
397 
398 static bool isMagic(AtomicId id)
399 {
400  return id == MVATrainer::kTargetId ||
401  id == MVATrainer::kWeightId ||
402  id == kOutputId;
403 }
404 
406 {
// Quotes `in` for a POSIX shell: the result is wrapped in single quotes, and
// each embedded single quote becomes the sequence close-quote,
// backslash-quote, reopen-quote.  Used to build the xsltproc command line.
407  std::string result("'");
408  for(std::string::const_iterator iter = in.begin();
409  iter != in.end(); ++iter) {
410  switch(*iter) {
411  case '\'':
412  result += "'\\''";
413  break;
414  default:
415  result += *iter;
416  }
417  }
418  result += '\'';
419  return result;
420 }
421 
423  const char *styleSheet) :
424  input(nullptr), output(nullptr), name("MVATrainer"),
425  doAutoSave(true), doCleanup(false),
426  doMonitoring(false), randomSeed(65539), crossValidation(0.0)
427 {
// Parses the trainer description.  With useXSLT the file is first run through
// xsltproc, using the given style sheet or the packaged default.  Both paths
// are shell-quoted with escape() because preproc is presumably executed as a
// command line by XMLDocument (confirm).
428  if (useXSLT) {
429  std::string sheet;
430  if (!styleSheet)
431  sheet = edm::FileInPath(
432  "PhysicsTools/MVATrainer/data/MVATrainer.xsl")
433  .fullPath();
434  else
435  sheet = styleSheet;
436 
437  std::string preproc = "xsltproc --xinclude " + escape(sheet) +
438  " " + escape(fileName);
439  xml.reset(new XMLDocument(fileName, preproc));
440  } else
441  xml.reset(new XMLDocument(fileName));
442 
443  DOMNode *node = xml->getRootNode();
444 
445  if (std::strcmp(XMLSimpleStr(node->getNodeName()), "MVATrainer") != 0)
446  throw cms::Exception("MVATrainer")
447  << "Invalid XML root node." << std::endl;
448 
// The children of <MVATrainer> must come in this order: <general>, <input>,
// any number of <processor>, then <output>.  `state` tracks what is expected.
449  enum State {
450  STATE_GENERAL,
451  STATE_FIRST,
452  STATE_MIDDLE,
453  STATE_LAST
454  } state = STATE_GENERAL;
455 
456  for(node = node->getFirstChild();
457  node; node = node->getNextSibling()) {
458  if (node->getNodeType() != DOMNode::ELEMENT_NODE)
459  continue;
460 
461  std::string name = XMLSimpleStr(node->getNodeName());
462  DOMElement *elem = static_cast<DOMElement*>(node);
463 
464  switch(state) {
465  case STATE_GENERAL: {
466  if (name != "general")
467  throw cms::Exception("MVATrainer")
468  << "Expected general config as first "
469  "tag." << std::endl;
470 
471  for(DOMNode *subNode = elem->getFirstChild();
472  subNode; subNode = subNode->getNextSibling()) {
473  if (subNode->getNodeType() !=
474  DOMNode::ELEMENT_NODE)
475  continue;
476 
477  if (std::strcmp(XMLSimpleStr(
478  subNode->getNodeName()), "option") != 0)
479  throw cms::Exception("MVATrainer")
480  << "Expected option tag."
481  << std::endl;
482 
483  elem = static_cast<DOMElement*>(subNode);
484  name = XMLDocument::readAttribute<std::string>(
485  elem, "name");
487  elem->getTextContent());
488 
489  if (name == "id")
490  this->name = content;
491  else if (name == "trainfiles")
493  else
494  throw cms::Exception("MVATrainer")
495  << "Unknown option \""
496  << name << "\"." << std::endl;
497  }
498 
499  state = STATE_FIRST;
500  } break;
501  case STATE_FIRST: {
502  if (name != "input")
503  throw cms::Exception("MVATrainer")
504  << "Expected input config as second "
505  "tag." << std::endl;
506 
507  AtomicId id = XMLDocument::readAttribute<std::string>(
508  elem, "id");
// The input source's first outputs are the implicit target and weight
// variables; the declared input variables follow.
509  input = new Source(id, true);
510  input->getOutputs().append(
511  createVariable(input, kTargetId,
514  input->getOutputs().append(
515  createVariable(input, kWeightId,
518  sources.insert(std::make_pair(id, input));
519  fillOutputVars(input->getOutputs(), input, elem);
520 
521  state = STATE_MIDDLE;
522  } break;
523  case STATE_MIDDLE: {
524  if (name == "output") {
525  AtomicId zero;
526  output = new TrainProcessor("output",
527  &zero, this);
529  state = STATE_LAST;
530  continue;
531  } else if (name != "processor")
532  throw cms::Exception("MVATrainer")
533  << "Unexpected tag after input "
534  "config." << std::endl;
535 
536  AtomicId id = XMLDocument::readAttribute<std::string>(
537  elem, "id");
538  std::string name =
539  XMLDocument::readAttribute<std::string>(
540  elem, "name");
541 
542  makeProcessor(elem, id, name.c_str());
543  } break;
544  case STATE_LAST:
545  throw cms::Exception("MVATrainer")
546  << "Unexpected tag found after output."
547  << std::endl;
548  break;
549  }
550  }
551 
// NOTE(review): a root element with no child elements leaves
// state == STATE_GENERAL.  Neither check below rejects that, so input and
// output stay nullptr.  Consider rejecting STATE_GENERAL as well.
552  if (state == STATE_FIRST)
553  throw cms::Exception("MVATrainer")
554  << "Expected input variable config." << std::endl;
555  else if (state == STATE_MIDDLE)
556  throw cms::Exception("MVATrainer")
557  << "Expected output variable config." << std::endl;
558 
// Default file name mask, expanded by trainFileName() with
// (processor name, "_<arg>" or "", extension).
559  if (trainFileMask.empty())
560  trainFileMask = this->name + "_%s%s.%s";
561 }
562 
564 {
// Writes monitoring output if any was booked.  Trainers may clean up when
// doCleanup is set.  Then all sources are deleted (the input source is
// registered in `sources` too), followed by the output pseudo-processor and
// all variables created through createVariable().
565  if (monitoring.get())
566  monitoring->write();
567 
568  for(std::map<AtomicId, Source*>::const_iterator iter = sources.begin();
569  iter != sources.end(); iter++) {
571  dynamic_cast<TrainProcessor*>(iter->second);
572 
573  if (proc && doCleanup)
574  proc->cleanup();
575 
576  delete iter->second;
577  }
578  delete output;
579  std::for_each(variables.begin(), variables.end(),
580  deleter<SourceVariable>());
581 }
582 
584 {
// Asks every declared processor, in declaration order, to load previously
// saved state, and logs each one whose load() succeeded.
585  for(std::vector<AtomicId>::const_iterator iter =
586  this->processors.begin();
587  iter != this->processors.end(); iter++) {
588  std::map<AtomicId, Source*>::const_iterator pos =
589  sources.find(*iter);
590  assert(pos != sources.end());
592  dynamic_cast<TrainProcessor*>(pos->second);
593  assert(source);
594 
595  if (source->load())
596  edm::LogInfo("MVATrainer")
597  << source->getId() << " configuration for \""
598  << (const char*)source->getName()
599  << "\" loaded from file.";
600  }
601 }
602 
604 {
// Saves every processor that reports itself trained.  doCleanup is cleared
// first, presumably so the destructor's cleanup() pass does not remove what
// was just saved (confirm).
605  doCleanup = false;
606 
607  for(std::vector<AtomicId>::const_iterator iter =
608  this->processors.begin();
609  iter != this->processors.end(); iter++) {
610  std::map<AtomicId, Source*>::const_iterator pos =
611  sources.find(*iter);
612  assert(pos != sources.end());
614  dynamic_cast<TrainProcessor*>(pos->second);
615  assert(source);
616 
617  if (source->isTrained())
618  source->save();
619  }
620 }
621 
622 void MVATrainer::makeProcessor(DOMElement *elem, AtomicId id, const char *name)
623 {
624  DOMElement *xmlInput = nullptr;
625  DOMElement *xmlConfig = nullptr;
626  DOMElement *xmlOutput = nullptr;
627  DOMElement *xmlData = nullptr;
628 
629  static struct NameExpect {
630  const char *tag;
631  bool mandatory;
632  DOMElement **elem;
633  } const expect[] = {
634  { "input", true, &xmlInput },
635  { "config", true, &xmlConfig },
636  { "output", true, &xmlOutput },
637  { "data", false, &xmlData },
638  { nullptr, }
639  };
640 
641  const NameExpect *cur = expect;
642  for(DOMNode *node = elem->getFirstChild();
643  node; node = node->getNextSibling()) {
644  if (node->getNodeType() != DOMNode::ELEMENT_NODE)
645  continue;
646 
647  std::string tag = XMLSimpleStr(node->getNodeName());
648  DOMElement *elem = static_cast<DOMElement*>(node);
649 
650  if (!cur->tag)
651  throw cms::Exception("MVATrainer")
652  << "Superfluous tag " << tag
653  << "encountered in processor." << std::endl;
654  else if (tag != cur->tag && cur->mandatory)
655  throw cms::Exception("MVATrainer")
656  << "Expected tag " << cur->tag << ", got "
657  << tag << " instead in processor."
658  << std::endl;
659  else if (tag != cur->tag) {
660  cur++;
661  continue;
662  }
663  *(cur++)->elem = elem;
664  }
665 
666  while(cur->tag && !cur->mandatory)
667  cur++;
668  if (cur->tag)
669  throw cms::Exception("MVATrainer")
670  << "Unexpected end of processor configuration, "
671  << "expected tag " << cur->tag << "." << std::endl;
672 
673  std::unique_ptr<TrainProcessor> proc(
674  TrainProcessor::create(name, &id, this));
675  if (!proc.get())
676  throw cms::Exception("MVATrainer")
677  << "Variable processor trainer " << name
678  << " could not be instantiated. Most likely because"
679  " the trainer plugin for \"" << name << "\""
680  " does not exist." << std::endl;
681 
682  if (sources.find(id) != sources.end())
683  throw cms::Exception("MVATrainer")
684  << "Duplicate variable processor id "
685  << (const char*)id << "."
686  << std::endl;
687 
688  fillInputVars(proc->getInputs(), xmlInput);
689  fillOutputVars(proc->getOutputs(), proc.get(), xmlOutput);
690 
691  edm::LogInfo("MVATrainer")
692  << "Configuring " << (const char*)proc->getId()
693  << " \"" << (const char*)proc->getName() << "\".";
694  proc->configure(xmlConfig);
695 
696  sources.insert(std::make_pair(id, proc.release()));
697  processors.push_back(id);
698 }
699 
701  const std::string &ext,
702  const std::string &arg) const
703 {
// Expands trainFileMask with (processor name, "_" + arg or "", ext).
// NOTE(review): the mask is used as a printf format, and by default it embeds
// the configured trainer id.  A '%' in that id (or in a "trainfiles" option)
// would be read as a conversion specifier; consider validating it.
704  std::string arg_ = !arg.empty() ? ("_" + arg) : "";
705  return stdStringPrintf(trainFileMask.c_str(),
706  (const char*)proc->getName(),
707  arg_.c_str(), ext.c_str());
708 }
709 
711 {
// Returns nullptr when monitoring is disabled.  Otherwise it creates the
// TrainerMonitoring on first use (file name from trainFileName(..., "root",
// "monitoring") - see missing lines 716-717) and books `name` in it.
712  if (!doMonitoring)
713  return nullptr;
714 
715  if (!monitoring.get()) {
718  "monitoring", "", "root");
719  monitoring.reset(new TrainerMonitoring(fileName));
720  }
721 
722  return monitoring->book(name);
723 }
724 
726 {
// Looks up output variable `name` of the source registered as `source`.
// Returns nullptr if no such source is registered; otherwise returns
// whatever Source::getOutput() yields.
727  std::map<AtomicId, Source*>::const_iterator pos = sources.find(source);
728  if (pos == sources.end())
729  return nullptr;
730 
731  return pos->second->getOutput(name);
732 }
733 
735  Variable::Flags flags)
736 {
// Creates a variable and records it in `variables`; the destructor deletes
// it.  Returns nullptr when the source already has an output of this name.
737  SourceVariable *var = getVariable(source->getName(), name);
738  if (var)
739  return nullptr;
740 
741  var = new SourceVariable(source, name, flags);
742  variables.push_back(var);
743  return var;
744 }
745 
748 {
// Reads the <var source=".." name=".."> children of `xml` into `vars`.  An
// entry may carry target="true" or weight="true" (each at most once).  When
// no target or weight is flagged, the input source's implicit __TARGET__ /
// __WEIGHT__ variables are used instead.
749  std::vector<SourceVariable*> tmp;
750  SourceVariable *target = nullptr;
751  SourceVariable *weight = nullptr;
752 
753  for(DOMNode *node = xml->getFirstChild(); node;
754  node = node->getNextSibling()) {
755  if (node->getNodeType() != DOMNode::ELEMENT_NODE)
756  continue;
757 
758  if (std::strcmp(XMLSimpleStr(node->getNodeName()), "var") != 0)
759  throw cms::Exception("MVATrainer")
760  << "Invalid input variable node." << std::endl;
761 
762  DOMElement *elem = static_cast<DOMElement*>(node);
763 
764  AtomicId source = XMLDocument::readAttribute<std::string>(
765  elem, "source");
766  AtomicId name = XMLDocument::readAttribute<std::string>(
767  elem, "name");
768 
769  SourceVariable *var = getVariable(source, name);
770  if (!var)
771  throw cms::Exception("MVATrainer")
772  << "Input variable " << (const char*)source
773  << ":" << (const char*)name
774  << " not found." << std::endl;
775 
776  if (XMLDocument::readAttribute<bool>(elem, "target", false)) {
777  if (target)
778  throw cms::Exception("MVATrainer")
779  << "Target variable defined twice"
780  << std::endl;
781  target = var;
782  }
783  if (XMLDocument::readAttribute<bool>(elem, "weight", false)) {
784  if (weight)
785  throw cms::Exception("MVATrainer")
786  << "Weight variable defined twice"
787  << std::endl;
788  weight = var;
789  }
790 
791  tmp.push_back(var);
792  }
793 
// Implicit weight: inserted at the front, or at index 1 when the flagged
// target is the implicit __TARGET__ variable.  Implicit target: inserted at
// the front.
794  if (!weight) {
795  weight = input->getOutput(kWeightId);
796  assert(weight);
797  tmp.insert(tmp.begin() +
798  (target == input->getOutput(kTargetId)),
799  1, weight);
800  }
801  if (!target) {
802  target = input->getOutput(kTargetId);
803  assert(target);
804  tmp.insert(tmp.begin(), 1, target);
805  }
806 
// Append the variables in creation order (the order of `variables`), each
// tagged with its position in the declared list.
807  unsigned int n = 0;
808  for(std::vector<SourceVariable*>::const_iterator iter = variables.begin();
809  iter != variables.end(); iter++) {
810  std::vector<SourceVariable*>::const_iterator pos =
811  std::find(tmp.begin(), tmp.end(), *iter);
812  if (pos == tmp.end())
813  continue;
814 
816  if (*iter == target)
818  else if (*iter == weight)
820  else
822 
823  if (vars.append(*iter, magic, pos - tmp.begin())) {
824  AtomicId source = (*iter)->getSource()->getName();
825  AtomicId name = (*iter)->getName();
826  throw cms::Exception("MVATrainer")
827  << "Input variable " << (const char*)source
828  << ":" << (const char*)name
829  << " defined twice." << std::endl;
830  }
831 
832  n++;
833  }
834 
// NOTE(review): if the same variable is listed twice in the XML, tmp holds it
// twice but the loop above counts it once.  Only this assert catches that,
// and it is compiled out under NDEBUG, so the duplicate is silently dropped.
// A cms::Exception would be more robust.
835  assert(tmp.size() == n);
836 }
837 
840 {
// Creates one output variable of `source` per <var name=".."> child of `xml`
// and appends it to `vars`.  Reserved (magic) names are rejected, as are
// names already defined for the source.
841  for(DOMNode *node = xml->getFirstChild(); node;
842  node = node->getNextSibling()) {
843  if (node->getNodeType() != DOMNode::ELEMENT_NODE)
844  continue;
845 
846  if (std::strcmp(XMLSimpleStr(node->getNodeName()), "var") != 0)
847  throw cms::Exception("MVATrainer")
848  << "Invalid output variable node."
849  << std::endl;
850 
851  DOMElement *elem = static_cast<DOMElement*>(node);
852 
853  AtomicId name = XMLDocument::readAttribute<std::string>(
854  elem, "name");
855  if (!name)
856  throw cms::Exception("MVATrainer")
857  << "Output variable tag missing name."
858  << std::endl;
859  if (isMagic(name))
860  throw cms::Exception("MVATrainer")
861  << "Cannot use magic variable names in output."
862  << std::endl;
863 
865 
// Both the "optional" and "multiple" attributes default to true when absent.
866  if (XMLDocument::readAttribute<bool>(elem, "optional", true))
868  (flags | Variable::FLAG_OPTIONAL);
869 
870  if (XMLDocument::readAttribute<bool>(elem, "multiple", true))
872  (flags | Variable::FLAG_MULTIPLE);
873 
874  SourceVariable *var = createVariable(source, name, flags);
875  if (!var || vars.append(var))
876  throw cms::Exception("MVATrainer")
877  << "Output variable "
878  << (const char*)source->getName()
879  << ":" << (const char*)name
880  << " defined twice." << std::endl;
881  }
882 }
883 
// Wires the processors in `procs` into `calib` by giving every variable a
// consecutive index.  The input source's outputs come first; the first two
// (target and weight) are skipped unless withTarget is set.  Each processor's
// input set is expressed as a BitSet over those indices.  Finally
// calib->output is pointed at the variable feeding the output processor.
884 void
886  const std::vector<CalibratedProcessor> &procs,
887  bool withTarget) const
888 {
889  std::map<SourceVariable*, unsigned int> vars;
890  unsigned int size = 0;
891 
892  MVATrainerComputer *trainCalib =
893  dynamic_cast<MVATrainerComputer*>(calib);
894 
895  for(unsigned int i = 0;
896  i < input->getOutputs().size(true); i++) {
897  if (i < 2 && !withTarget)
898  continue;
899 
901  vars[var] = size++;
902 
903  Calibration::Variable calibVar;
904  calibVar.name = (const char*)var->getName();
905  calib->inputSet.push_back(calibVar);
906  if (trainCalib)
907  trainCalib->addFlag(var->getFlags());
908  }
909 
910  for(std::vector<CalibratedProcessor>::const_iterator iter =
911  procs.begin(); iter != procs.end(); iter++) {
912  bool isInterceptor = dynamic_cast<BaseInterceptor*>(
913  iter->calib) != nullptr;
914 
915  BitSet inputSet(size);
916 
// Inputs must be declared in increasing index order; `last` enforces that.
917  unsigned int last = 0;
918  std::vector<SourceVariable*> inoutVars;
919  if (iter->processor)
920  inoutVars = iter->processor->getInputs().get(
921  isInterceptor);
922  for(std::vector<SourceVariable*>::const_iterator iter2 =
923  inoutVars.begin(); iter2 != inoutVars.end(); iter2++) {
925  unsigned int>::const_iterator pos =
926  vars.find(*iter2);
927 
928  assert(pos != vars.end());
929 
930  if (pos->second < last)
931  throw cms::Exception("MVATrainer")
932  << "Input variables not declared "
933  "in order of appearance in \""
934  << (const char*)iter->processor->getName()
935  << "\"." << std::endl;
936 
937  inputSet[last = pos->second] = true;
938  }
939 
940  assert(!isInterceptor || withTarget);
941 
942  iter->calib->inputVars = Calibration::convert(inputSet);
943 
// Provisional: overwritten on every iteration and finally by the lookup at
// the end, if the output variable is found.
944  calib->output = size;
945 
// An interceptor is not added via addProcessor() here (MVATrainerComputer
// inserts interceptors itself); it only occupies one index.
946  if (isInterceptor) {
947  size++;
948  continue;
949  }
950 
951  calib->addProcessor(iter->calib);
952 
// Note: this inner `iter` shadows the outer loop variable.
953  inoutVars = iter->processor->getOutputs().get();
954  for(std::vector<SourceVariable*>::const_iterator iter =
955  inoutVars.begin(); iter != inoutVars.end(); iter++) {
956 
957  vars[*iter] = size++;
958  }
959  }
960 
961  if (output->getInputs().size() != 1)
962  throw cms::Exception("MVATrainer")
963  << "Exactly one output variable has to be specified."
964  << std::endl;
965 
966  SourceVariable *outVar = output->getInputs().get()[0];
967  std::map<SourceVariable*, unsigned int>::const_iterator pos =
968  vars.find(outVar);
969  if (pos != vars.end())
970  calib->output = pos->second;
971 }
972 
975  const AtomicId *train) const
976 {
// Builds the calibration for one training iteration.  `compute` lists trained
// processors and `train` lists those to train (plus possibly kOutputId); both
// are terminated by a null AtomicId.  An InitInterceptor goes first, then the
// trained processors, each TrainInterceptor placed after the processors a
// ProcForeach groups with it, and any remaining TrainInterceptors at the end.
977  std::map<AtomicId, TrainInterceptor*> interceptors;
978  std::vector<MVATrainerComputer::Interceptor> baseInterceptors;
979  std::vector<CalibratedProcessor> processors;
980 
981  BaseInterceptor *interceptor = new InitInterceptor;
982  baseInterceptors.push_back(std::make_pair(0, interceptor));
983  processors.push_back(CalibratedProcessor(nullptr, interceptor));
984 
985  for(const AtomicId *iter = train; *iter; iter++) {
987  if (*iter == kOutputId)
988  source = output;
989  else {
990  std::map<AtomicId, Source*>::const_iterator pos =
991  sources.find(*iter);
992  assert(pos != sources.end());
993  source = dynamic_cast<TrainProcessor*>(pos->second);
994  }
995  assert(source);
996 
997  interceptors[*iter] = new TrainInterceptor(source);
998  }
999 
// The VarProcessor objects obtained below are deleted when this function
// returns.  Presumably Calibration::MVAComputer::addProcessor() (called from
// connectProcessors()) copies them; confirm.
1000  auto_cleaner<Calibration::VarProcessor> autoClean;
1001 
// Processors already emitted as part of a ProcForeach group.
1002  std::set<AtomicId> done;
1003  for(const AtomicId *iter = compute; *iter; iter++) {
1004  if (done.erase(*iter))
1005  continue;
1006 
1007  std::map<AtomicId, Source*>::const_iterator pos =
1008  sources.find(*iter);
1009  assert(pos != sources.end());
1011  dynamic_cast<TrainProcessor*>(pos->second);
1012  assert(source);
1013  assert(source->isTrained());
1014 
1015  Calibration::VarProcessor *proc = source->getCalibration();
1016  if (!proc)
1017  continue;
1018 
1019  autoClean.add(proc);
1020  processors.push_back(CalibratedProcessor(source, proc));
1021 
1023  dynamic_cast<Calibration::ProcForeach*>(proc);
1024  if (looper) {
1025  std::vector<AtomicId>::const_iterator pos2 =
1026  std::find(this->processors.begin(),
1027  this->processors.end(), *iter);
1028  assert(pos2 != this->processors.end());
1029  ++pos2;
1030  unsigned int n = 0;
1031  for(int i = 0; i < (int)looper->nProcs; ++i, ++pos2) {
1032  assert(pos2 != this->processors.end());
1033 
1034  const AtomicId *iter2 = compute;
1035  while(*iter2) {
1036  if (*iter2 == *pos2)
1037  break;
1038  iter2++;
1039  }
1040 
1041  if (*iter2) {
1042  n++;
1043  done.insert(*iter2);
1044  pos = sources.find(*iter2);
1045  assert(pos != sources.end());
1046  TrainProcessor *source =
1047  dynamic_cast<TrainProcessor*>(
1048  pos->second);
1049  assert(source);
1050  assert(source->isTrained());
1051 
1052  proc = source->getCalibration();
1053  if (proc) {
1054  autoClean.add(proc);
1055  processors.push_back(
1057  source, proc));
1058  }
1059  }
1060 
1061  std::map<AtomicId, TrainInterceptor*>::iterator
1062  pos3 = interceptors.find(*pos2);
1063  if (pos3 != interceptors.end()) {
1064  n++;
1065  baseInterceptors.push_back(
1066  std::make_pair(processors.size(),
1067  pos3->second));
1068  processors.push_back(
1070  pos3->second->getProcessor(),
1071  pos3->second));
1072  interceptors.erase(pos3);
1073  }
1074  }
1075 
// NOTE(review): when n == 0 nothing was added in the loop above, and the
// looper itself was pushed only to `processors` (line 1020), not to
// `baseInterceptors`.  Popping `baseInterceptors` therefore appears to drop an
// unrelated, earlier interceptor (possibly the InitInterceptor), leaking it
// and desynchronising the two lists.  Only processors.pop_back() looks
// intended; confirm.
1076  looper->nProcs = n;
1077  if (!n) {
1078  baseInterceptors.pop_back();
1079  processors.pop_back();
1080  }
1081  }
1082  }
1083 
1084  for(std::map<AtomicId, TrainInterceptor*>::const_iterator iter =
1085  interceptors.begin(); iter != interceptors.end(); ++iter) {
1086 
1087  TrainProcessor *proc = iter->second->getProcessor();
1088  baseInterceptors.push_back(std::make_pair(processors.size(),
1089  iter->second));
1090  processors.push_back(CalibratedProcessor(proc, iter->second));
1091  }
1092 
// Ownership of the interceptors passes to MVATrainerComputer only here.
// NOTE(review): if anything above throws, the interceptors created with `new`
// leak.
1093  std::unique_ptr<Calibration::MVAComputer> calib(
1094  new MVATrainerComputer(baseInterceptors, doAutoSave,
1096 
1097  connectProcessors(calib.get(), processors, true);
1098 
1099  return calib.release();
1100 }
1101 
1103 {
// Finishes the current training iteration of a calibration created by the
// training path (an MVATrainerComputer).  Any other calibration is rejected
// with a cms::Exception.
1104  MVATrainerComputer *calib =
1105  dynamic_cast<MVATrainerComputer*>(trainCalibration);
1106 
1107  if (!calib)
1108  throw cms::Exception("MVATrainer")
1109  << "Invalid training calibration passed to "
1110  "doneTraining()" << std::endl;
1111 
1112  calib->done();
1113 }
1114 
1115 std::vector<AtomicId> MVATrainer::findFinalProcessors() const
1116 {
1117  std::set<Source*> toCheck;
1118  toCheck.insert(output);
1119 
1120  std::set<Source*> done;
1121  while(!toCheck.empty()) {
1122  Source *source = *toCheck.begin();
1123  toCheck.erase(toCheck.begin());
1124 
1125  std::vector<SourceVariable*> inputs = source->inputs.get();
1126  for(std::vector<SourceVariable*>::const_iterator iter =
1127  inputs.begin(); iter != inputs.end(); ++iter) {
1128  source = (*iter)->getSource();
1129  if (done.insert(source).second)
1130  toCheck.insert(source);
1131  }
1132  }
1133 
1134  std::vector<AtomicId> result;
1135  for(std::vector<AtomicId>::const_iterator iter = processors.begin();
1136  iter != processors.end(); ++iter) {
1137  std::map<AtomicId, Source*>::const_iterator pos =
1138  sources.find(*iter);
1139  if (pos != sources.end() && done.count(pos->second))
1140  result.push_back(*iter);
1141  }
1142 
1143  return result;
1144 }
1145 
1147 {
// Builds the final (evaluation) calibration from the processors the output
// depends on.  Returns nullptr if any of them is not trained yet.  Target and
// weight are not connected (withTarget == false).
// NOTE(review): unlike makeTrainCalibration(), which deletes the objects from
// getCalibration() via auto_cleaner after connectProcessors(), nothing here
// frees them, either on success or on the early `return nullptr`.  One of the
// two ownership assumptions looks wrong; confirm what
// Calibration::MVAComputer::addProcessor() does with its argument.
1148  std::vector<CalibratedProcessor> processors;
1149 
1150  std::unique_ptr<Calibration::MVAComputer> calib(
1152 
1153  std::vector<AtomicId> used = findFinalProcessors();
1154  for(std::vector<AtomicId>::const_iterator iter = used.begin();
1155  iter != used.end(); iter++) {
1156  std::map<AtomicId, Source*>::const_iterator pos =
1157  sources.find(*iter);
1158  assert(pos != sources.end());
1160  dynamic_cast<TrainProcessor*>(pos->second);
1161  assert(source);
1162  if (!source->isTrained())
1163  return nullptr;
1164 
1165  Calibration::VarProcessor *proc = source->getCalibration();
1166  if (!proc)
1167  continue;
1168 
// A ProcForeach groups the nProcs processors declared after it; recount how
// many of those are actually used.
1169  Calibration::ProcForeach *foreach =
1170  dynamic_cast<Calibration::ProcForeach*>(proc);
1171  if (foreach) {
1172  std::vector<AtomicId>::const_iterator begin =
1173  std::find(this->processors.begin(),
1174  this->processors.end(), *iter);
// NOTE(review): the foreach plus its nProcs followers need only
// end - begin >= nProcs + 1.  The strict '>' looks off by one and would fire
// when the group ends at the last declared processor; confirm.
1175  assert(this->processors.end() - begin >
1176  (int)(foreach->nProcs + 1));
1177  ++begin;
1178  std::vector<AtomicId>::const_iterator end =
1179  begin + foreach->nProcs;
1180  foreach->nProcs = 0;
1181  for(std::vector<AtomicId>::const_iterator iter2 =
1182  iter; iter2 != used.end(); ++iter2)
1183  if (std::find(begin, end, *iter2) != end)
1184  foreach->nProcs++;
1185  }
1186 
1187  processors.push_back(CalibratedProcessor(source, proc));
1188  }
1189 
1190  connectProcessors(calib.get(), processors, false);
1191 
1192  return calib.release();
1193 }
1194 
1195 void MVATrainer::findUntrainedComputers(std::vector<AtomicId> &compute,
1196  std::vector<AtomicId> &train) const
1197 {
1198  compute.clear();
1199  train.clear();
1200 
1201  std::set<Source*> trainedSources;
1202  trainedSources.insert(input);
1203 
1204  for(std::vector<AtomicId>::const_iterator iter =
1205  processors.begin(); iter != processors.end(); iter++) {
1206  std::map<AtomicId, Source*>::const_iterator pos =
1207  sources.find(*iter);
1208  assert(pos != sources.end());
1209  TrainProcessor *proc =
1210  dynamic_cast<TrainProcessor*>(pos->second);
1211  assert(proc);
1212 
1213  bool trainedDeps = true;
1214  std::vector<SourceVariable*> inputVars =
1215  proc->getInputs().get();
1216  for(std::vector<SourceVariable*>::const_iterator iter2 =
1217  inputVars.begin(); iter2 != inputVars.end(); iter2++) {
1218  if (trainedSources.find((*iter2)->getSource())
1219  == trainedSources.end()) {
1220  trainedDeps = false;
1221  break;
1222  }
1223  }
1224 
1225  if (!trainedDeps)
1226  continue;
1227 
1228  if (proc->isTrained()) {
1229  trainedSources.insert(proc);
1230  compute.push_back(proc->getName());
1231  } else
1232  train.push_back(proc->getName());
1233  }
1234 
1235  if (doMonitoring && !output->isTrained() &&
1236  trainedSources.find(output->getInputs().get()[0]->getSource())
1237  != trainedSources.end())
1238  train.push_back(kOutputId);
1239 }
1240 
1242 {
// Returns nullptr when nothing remains to be trained.  Otherwise it builds the
// calibration for the next training iteration.  The id lists get a null
// AtomicId terminator because makeTrainCalibration() walks them while *iter
// is non-null.
1243  std::vector<AtomicId> compute, train;
1244  findUntrainedComputers(compute, train);
1245 
1246  if (train.empty())
1247  return nullptr;
1248 
1249  compute.push_back(nullptr);
1250  train.push_back(nullptr);
1251 
1252  return makeTrainCalibration(&compute.front(), &train.front());
1253 }
1254 
1255 } // namespace PhysicsTools
size
Write out results.
static std::string escape(const std::string &in)
Definition: MVATrainer.cc:405
unsigned int nConfigured
Definition: MVATrainer.cc:136
virtual std::vector< VarProcessor * > getProcessors() const
Definition: MVAComputer.cc:178
bool isTrained() const
Definition: Source.h:24
TrainProcessor *const proc
Definition: MVATrainer.cc:101
int init
Definition: HydjetWrapper.h:67
std::vector< Variable::Flags > flags
Definition: MVATrainer.cc:135
Definition: weight.py:1
SourceVariable * getVariable(AtomicId source, AtomicId name) const
Definition: MVATrainer.cc:725
#define XERCES_CPP_NAMESPACE_QUALIFIER
Definition: LHERunInfo.h:16
TRandom random
Definition: MVATrainer.cc:138
#define nullptr
void find(edm::Handle< EcalRecHitCollection > &hits, DetId thisDet, std::vector< EcalRecHitCollection::const_iterator > &hit, bool debug=false)
Definition: FindCaloHit.cc:20
A arg
Definition: Factorize.h:37
static const AtomicId kTargetId
Definition: MVATrainer.h:59
SourceVariable * getOutput(AtomicId name) const
Definition: Source.h:21
const SourceVariableSet & getInputs() const
Definition: Source.h:26
void fillOutputVars(SourceVariableSet &vars, Source *source, XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
Definition: MVATrainer.cc:838
Cheap generic unique keyword identifier class.
Definition: AtomicId.h:31
unsigned int weightIdx
Definition: MVATrainer.cc:99
static std::string const input
Definition: EdmProvDump.cc:44
Definition: looper.py:1
bool doAutoSave
Definition: MVATrainer.cc:137
static const AtomicId kWeightId
Definition: MVATrainer.h:60
Calibration::MVAComputer * makeTrainCalibration(const AtomicId *compute, const AtomicId *train) const
Definition: MVATrainer.cc:974
MVATrainerComputer * calib
Definition: MVATrainer.cc:64
Main interface class to the generic discriminator computer framework.
Definition: MVAComputer.h:39
const AtomicId getName() const
Definition: Variable.h:143
std::vector< AtomicId > processors
Definition: MVATrainer.h:102
Flags getFlags() const
Definition: Variable.h:144
const SourceVariableSet & getOutputs() const
Definition: Source.h:27
std::vector< AtomicId > findFinalProcessors() const
Definition: MVATrainer.cc:1115
A compact container for storing single bits.
Definition: BitSet.h:29
std::vector< SourceVariable * > get(bool withMagic=false) const
void addProcessor(const VarProcessor *proc)
Definition: MVAComputer.cc:183
#define end
Definition: vmac.h:39
unsigned int targetIdx
Definition: MVATrainer.cc:98
std::string trainFileName(const TrainProcessor *proc, const std::string &ext, const std::string &arg="") const
Definition: MVATrainer.cc:700
SourceVariableSet inputs
Definition: Source.h:39
void connectProcessors(Calibration::MVAComputer *calib, const std::vector< CalibratedProcessor > &procs, bool withTarget) const
Definition: MVATrainer.cc:885
def elem(elemtype, innerHTML='', html_class='', kwargs)
Definition: HTMLExport.py:18
std::vector< T * > clean
Definition: MVATrainer.cc:156
std::map< AtomicId, Source * > sources
Definition: MVATrainer.h:100
static Base_t * create(const char *name, const CalibBase_t *calib, Parent_t *parent=0)
void add(std::map< std::string, TH1 * > &h, TH1 *hist)
bool splitResult
Definition: MVATrainer.cc:140
void doneTraining(Calibration::MVAComputer *trainCalibration) const
Definition: MVATrainer.cc:1102
std::unique_ptr< TrainerMonitoring > monitoring
Definition: MVATrainer.h:106
TrainerMonitoring::Module * bookMonitor(const std::string &name)
Definition: MVATrainer.cc:710
std::vector< Interceptor > interceptors
Definition: MVATrainer.cc:134
size_type size(bool withMagic=false) const
bool append(SourceVariable *var, Magic magic=kRegular, int offset=-1)
std::unique_ptr< XMLDocument > xml
Definition: MVATrainer.h:107
Calibration::MVAComputer * getCalibration() const
Definition: MVATrainer.cc:1146
def compute(min, max)
std::vector< SourceVariable * > variables
Definition: MVATrainer.h:101
std::vector< Variable > inputSet
Definition: MVAComputer.h:234
void fillInputVars(SourceVariableSet &vars, XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
Definition: MVATrainer.cc:746
SourceVariable * find(AtomicId name) const
std::vector< std::vector< double > > tmp
Definition: MVATrainer.cc:100
MVATrainer(const std::string &fileName, bool useXSLT=false, const char *styleSheet=0)
Definition: MVATrainer.cc:422
#define begin
Definition: vmac.h:32
AtomicId getName() const
Definition: Source.h:19
#define foreach
void makeProcessor(XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *elem, AtomicId id, const char *name)
Definition: MVATrainer.cc:622
static std::vector< std::string > split(const std::string line, char delim)
Definition: MLP.cc:18
std::string fullPath() const
Definition: FileInPath.cc:197
std::string trainFileMask
Definition: MVATrainer.h:108
State
Definition: hltDiff.cc:288
PhysicsTools::BitSet convert(const BitSet &bitSet)
constructs BitSet container from persistent representation
Definition: BitSet.cc:38
void findUntrainedComputers(std::vector< AtomicId > &compute, std::vector< AtomicId > &train) const
Definition: MVATrainer.cc:1195
TrainProcessor * output
Definition: MVATrainer.h:104
Calibration::MVAComputer * getTrainCalibration() const
Definition: MVATrainer.cc:1241
Definition: memstream.h:15
long double T
static std::string const source
Definition: EdmProvDump.cc:43
save
Definition: cuy.py:1163
static const AtomicId kOutputId("__OUTPUT__")
static std::string stdStringVPrintf(const char *format, std::va_list va)
Definition: MVATrainer.cc:160
static std::string stdStringPrintf(const char *format,...)
Definition: MVATrainer.cc:183
static bool isMagic(AtomicId id)
Definition: MVATrainer.cc:398
SourceVariable * createVariable(Source *source, AtomicId name, Variable::Flags flags)
Definition: MVATrainer.cc:734