CMS 3D CMS Logo

MVATrainer.cc
Go to the documentation of this file.
1 #include <cassert>
2 #include <functional>
3 #include <ext/functional>
4 #include <algorithm>
5 #include <iostream>
6 #include <cstdarg>
7 #include <cstring>
8 #include <cstdio>
9 #include <string>
10 #include <vector>
11 #include <memory>
12 #include <map>
13 #include <set>
14 
15 #include <xercesc/dom/DOM.hpp>
16 
17 #include <TRandom.h>
18 
22 
27 
37 
39 
40 namespace PhysicsTools {
41 
42 namespace { // anonymous
43  class MVATrainerComputer;
44 
45  class BaseInterceptor : public Calibration::Interceptor {
46  public:
47  BaseInterceptor() : calib(nullptr) {}
48  ~BaseInterceptor() override {}
49 
50  inline void setCalibration(MVATrainerComputer *calib)
51  { this->calib = calib; }
52 
53  std::vector<Variable::Flags>
54  configure(const MVAComputer *computer, unsigned int n,
55  const std::vector<Variable::Flags> &flags) override = 0;
56 
57  double
58  intercept(const std::vector<double> *values) const override = 0;
59 
60  virtual void init() {}
61  virtual void finish(bool save) {}
62 
63  protected:
64  MVATrainerComputer *calib;
65  };
66 
67  class InitInterceptor : public BaseInterceptor {
68  public:
69  InitInterceptor() {}
70  ~InitInterceptor() override {}
71 
72  std::vector<Variable::Flags>
73  configure(const MVAComputer *computer, unsigned int n,
74  const std::vector<Variable::Flags> &flags) override;
75 
76  double
77  intercept(const std::vector<double> *values) const override;
78  };
79 
80  class TrainInterceptor : public BaseInterceptor {
81  public:
82  TrainInterceptor(TrainProcessor *proc) : proc(proc) {}
83  ~TrainInterceptor() override {}
84 
85  inline TrainProcessor *getProcessor() const { return proc; }
86 
87  std::vector<Variable::Flags>
88  configure(const MVAComputer *computer, unsigned int n,
89  const std::vector<Variable::Flags> &flags) override;
90 
91  double
92  intercept(const std::vector<double> *values) const override;
93 
94  void init() override;
95  void finish(bool save) override;
96 
97  private:
98  unsigned int targetIdx;
99  unsigned int weightIdx;
100  mutable std::vector<std::vector<double> > tmp;
101  TrainProcessor *const proc;
102  };
103 
104  class MVATrainerComputer : public TrainMVAComputerCalibration {
105  public:
106  typedef std::pair<unsigned int, BaseInterceptor*> Interceptor;
107 
108  MVATrainerComputer(const std::vector<Interceptor>
109  &interceptors,
110  bool autoSave, UInt_t seed, double split);
111 
112  ~MVATrainerComputer() override;
113 
114  std::vector<Calibration::VarProcessor*>
115  getProcessors() const override;
116  void initFlags(std::vector<Variable::Flags>
117  &flags) const override;
118 
119  void configured(BaseInterceptor *interceptor) const;
120  void next();
121  void done();
122 
123  inline void addFlag(Variable::Flags flag)
124  { flags.push_back(flag); }
125 
126  inline bool useForTraining() const { return splitResult; }
127  inline bool useForTesting() const
128  { return split <= 0.0 || !splitResult; }
129 
130  inline bool isConfigured() const
131  { return nConfigured == interceptors.size(); }
132 
133  private:
134  std::vector<Interceptor> interceptors;
135  std::vector<Variable::Flags> flags;
136  mutable unsigned int nConfigured;
138  TRandom random;
139  double split;
141  };
142 
143  // useful little helpers
144 
// Free function that deletes a single heap object; usable as a
// std::for_each callback.
template<typename T>
static inline void deleter(T *ptr)
{
	delete ptr;
}

// Collects raw pointers through add() and deletes every one of them when
// the cleaner itself is destroyed.
template<typename T>
struct auto_cleaner {
	inline ~auto_cleaner()
	{
		for (T *ptr : clean)
			deleter<T>(ptr);
	}

	// Register a pointer for deletion at destruction time.
	inline void add(T *ptr)
	{
		clean.push_back(ptr);
	}

	std::vector<T*> clean;
};
156 } // anonymous namespace
157 
// Format a va_list into a std::string, growing the buffer until the
// whole output fits.
//
// The caller's va_list is never consumed directly: each vsnprintf()
// attempt works on its own va_copy, because a va_list is indeterminate
// after it has been passed to a v*printf function and must not be reused
// for the retry.
static std::string stdStringVPrintf(const char *format, std::va_list va)
{
	// Initial guess only; the loop below enlarges the buffer as needed.
	std::vector<char> buffer(
		std::max<std::size_t>(128, std::strlen(format) + 1));

	for(;;) {
		std::va_list attempt;
		va_copy(attempt, va);
		const int n = std::vsnprintf(buffer.data(), buffer.size(),
		                             format, attempt);
		va_end(attempt);

		if (n >= 0 && static_cast<std::size_t>(n) < buffer.size())
			break;

		if (n >= 0)
			// vsnprintf reported the exact length required
			buffer.resize(static_cast<std::size_t>(n) + 1);
		else
			// no length reported: keep doubling
			buffer.resize(buffer.size() * 2);
	}

	return std::string(buffer.data());
}

// printf-style convenience wrapper around stdStringVPrintf().
static std::string stdStringPrintf(const char *format, ...)
{
	std::va_list va;
	va_start(va, format);
	std::string result = stdStringVPrintf(format, va);
	va_end(va);
	return result;
}
189 
190 // implementation for InitInterceptor
191 
192 std::vector<Variable::Flags>
193 InitInterceptor::configure(const MVAComputer *computer, unsigned int n,
194  const std::vector<Variable::Flags> &flags)
195 {
196  calib->configured(this);
197  return std::vector<Variable::Flags>(n, Variable::FLAG_ALL);
198 }
199 
200 double
201 InitInterceptor::intercept(const std::vector<double> *values) const
202 {
203  calib->next();
204  return 0.0;
205 }
206 
207 // implementation for TrainInterceptor
208 
209 std::vector<Variable::Flags>
210 TrainInterceptor::configure(const MVAComputer *computer, unsigned int n,
211  const std::vector<Variable::Flags> &flags)
212 {
213  const SourceVariableSet &inputSet =
214  const_cast<const TrainProcessor*>(proc)->getInputs();
217 
218  std::vector<SourceVariable*> inputs = inputSet.get(true);
219 
220  std::vector<SourceVariable*>::const_iterator pos;
221  pos = std::find(inputs.begin(), inputs.end(), target);
222  assert(pos != inputs.end());
223  targetIdx = pos - inputs.begin();
224  pos = std::find(inputs.begin(), inputs.end(), weight);
225  assert(pos != inputs.end());
226  weightIdx = pos - inputs.begin();
227 
228  calib->configured(this);
229 
230  std::vector<Variable::Flags> result = flags;
231  if (targetIdx < weightIdx) {
232  result.erase(result.begin() + weightIdx);
233  result.erase(result.begin() + targetIdx);
234  } else {
235  result.erase(result.begin() + targetIdx);
236  result.erase(result.begin() + weightIdx);
237  }
238 
239  proc->passFlags(result);
240 
241  result.clear();
242  result.resize(n, proc->getDefaultFlags());
243  result[targetIdx] = Variable::FLAG_NONE;
245 
246  if (targetIdx >= 2 || weightIdx >= 2)
247  tmp.resize(n - 2);
248 
249  return result;
250 }
251 
253 {
254  edm::LogInfo("MVATrainer")
255  << "TrainProcessor \"" << (const char*)proc->getName()
256  << "\" training iteration starting...";
257 
258  proc->doTrainBegin();
259 }
260 
261 double
262 TrainInterceptor::intercept(const std::vector<double> *values) const
263 {
264  if (values[targetIdx].size() != 1) {
265  if (values[targetIdx].empty())
266  throw cms::Exception("MVATrainer")
267  << "Trainer input lacks target variable."
268  << std::endl;
269  else
270  throw cms::Exception("MVATrainer")
271  << "Multiple targets supplied in input."
272  << std::endl;
273  }
274  double target = values[targetIdx].front();
275 
276  double weight = 1.0;
277  if (values[weightIdx].size() > 1)
278  throw cms::Exception("MVATrainer")
279  << "Multiple weights supplied in input."
280  << std::endl;
281  else if (values[weightIdx].size() == 1)
282  weight = values[weightIdx].front();
283 
284  if (tmp.empty())
285  proc->doTrainData(values + 2, target > 0.5, weight,
286  calib->useForTraining(),
287  calib->useForTesting());
288  else {
289  std::vector<std::vector<double> >::iterator pos = tmp.begin();
290  for(unsigned int i = 0; pos != tmp.end(); i++)
291  if (i != targetIdx && i != weightIdx)
292  *pos++ = values[i];
293 
294  proc->doTrainData(&tmp.front(), target > 0.5, weight,
295  calib->useForTraining(),
296  calib->useForTesting());
297  }
298 
299  return target;
300 }
301 
302 void TrainInterceptor::finish(bool save)
303 {
304  proc->doTrainEnd();
305 
306  edm::LogInfo("MVATrainer")
307  << "... processor \"" << (const char*)proc->getName()
308  << "\" training iteration done.";
309 
310  if (proc->isTrained()) {
311  edm::LogInfo("MVATrainer")
312  << "* Completed training of \""
313  << (const char*)proc->getName() << "\".";
314 
315  if (save)
316  proc->save();
317  }
318 }
319 
320 // implementation for MVATrainerComputer
321 
322 MVATrainerComputer::MVATrainerComputer(const std::vector<Interceptor>
323  &interceptors, bool autoSave,
324  UInt_t seed, double split) :
325  interceptors(interceptors), nConfigured(0), doAutoSave(autoSave),
326  random(seed), split(split)
327 {
328  for(std::vector<Interceptor>::const_iterator iter =
329  interceptors.begin(); iter != interceptors.end(); ++iter)
330  iter->second->setCalibration(this);
331 }
332 
333 MVATrainerComputer::~MVATrainerComputer()
334 {
335  done();
336 
337  for(std::vector<Interceptor>::const_iterator iter =
338  interceptors.begin(); iter != interceptors.end(); ++iter)
339  delete iter->second;
340 }
341 
342 std::vector<Calibration::VarProcessor*>
343 MVATrainerComputer::getProcessors() const
344 {
345  std::vector<Calibration::VarProcessor*> processors =
347 
348  for(std::vector<Interceptor>::const_iterator iter =
349  interceptors.begin(); iter != interceptors.end(); ++iter)
350 
351  processors.insert(processors.begin() + iter->first,
352  1, iter->second);
353 
354  return processors;
355 }
356 
357 void MVATrainerComputer::initFlags(std::vector<Variable::Flags> &flags) const
358 {
359  assert(flags.size() == this->flags.size());
360  flags = this->flags;
361 }
362 
363 void MVATrainerComputer::configured(BaseInterceptor *interceptor) const
364 {
365  nConfigured++;
366  if (isConfigured())
367  for(std::vector<Interceptor>::const_iterator iter =
368  interceptors.begin();
369  iter != interceptors.end(); ++iter)
370  iter->second->init();
371 }
372 
374 {
375  splitResult = random.Uniform(1.0) >= split;
376 }
377 
378 void MVATrainerComputer::done()
379 {
380  if (isConfigured()) {
381  for(std::vector<Interceptor>::const_iterator iter =
382  interceptors.begin();
383  iter != interceptors.end(); ++iter)
384  iter->second->finish(doAutoSave);
385  nConfigured = 0;
386  }
387 }
388 
389 // implementation for MVATrainer
390 
391 const AtomicId MVATrainer::kTargetId("__TARGET__");
392 const AtomicId MVATrainer::kWeightId("__WEIGHT__");
393 
394 static const AtomicId kOutputId("__OUTPUT__");
395 
396 static bool isMagic(AtomicId id)
397 {
398  return id == MVATrainer::kTargetId ||
399  id == MVATrainer::kWeightId ||
400  id == kOutputId;
401 }
402 
404 {
405  std::string result("'");
406  for(std::string::const_iterator iter = in.begin();
407  iter != in.end(); ++iter) {
408  switch(*iter) {
409  case '\'':
410  result += "'\\''";
411  break;
412  default:
413  result += *iter;
414  }
415  }
416  result += '\'';
417  return result;
418 }
419 
421  const char *styleSheet) :
422  input(nullptr), output(nullptr), name("MVATrainer"),
423  doAutoSave(true), doCleanup(false),
424  doMonitoring(false), randomSeed(65539), crossValidation(0.0)
425 {
426  if (useXSLT) {
427  std::string sheet;
428  if (!styleSheet)
429  sheet = edm::FileInPath(
430  "PhysicsTools/MVATrainer/data/MVATrainer.xsl")
431  .fullPath();
432  else
433  sheet = styleSheet;
434 
435  std::string preproc = "xsltproc --xinclude " + escape(sheet) +
436  " " + escape(fileName);
437  xml.reset(new XMLDocument(fileName, preproc));
438  } else
439  xml.reset(new XMLDocument(fileName));
440 
441  DOMNode *node = xml->getRootNode();
442 
443  if (std::strcmp(XMLSimpleStr(node->getNodeName()), "MVATrainer") != 0)
444  throw cms::Exception("MVATrainer")
445  << "Invalid XML root node." << std::endl;
446 
447  enum State {
448  STATE_GENERAL,
449  STATE_FIRST,
450  STATE_MIDDLE,
451  STATE_LAST
452  } state = STATE_GENERAL;
453 
454  for(node = node->getFirstChild();
455  node; node = node->getNextSibling()) {
456  if (node->getNodeType() != DOMNode::ELEMENT_NODE)
457  continue;
458 
459  std::string name = XMLSimpleStr(node->getNodeName());
460  DOMElement *elem = static_cast<DOMElement*>(node);
461 
462  switch(state) {
463  case STATE_GENERAL: {
464  if (name != "general")
465  throw cms::Exception("MVATrainer")
466  << "Expected general config as first "
467  "tag." << std::endl;
468 
469  for(DOMNode *subNode = elem->getFirstChild();
470  subNode; subNode = subNode->getNextSibling()) {
471  if (subNode->getNodeType() !=
472  DOMNode::ELEMENT_NODE)
473  continue;
474 
475  if (std::strcmp(XMLSimpleStr(
476  subNode->getNodeName()), "option") != 0)
477  throw cms::Exception("MVATrainer")
478  << "Expected option tag."
479  << std::endl;
480 
481  elem = static_cast<DOMElement*>(subNode);
482  name = XMLDocument::readAttribute<std::string>(
483  elem, "name");
485  elem->getTextContent());
486 
487  if (name == "id")
488  this->name = content;
489  else if (name == "trainfiles")
491  else
492  throw cms::Exception("MVATrainer")
493  << "Unknown option \""
494  << name << "\"." << std::endl;
495  }
496 
497  state = STATE_FIRST;
498  } break;
499  case STATE_FIRST: {
500  if (name != "input")
501  throw cms::Exception("MVATrainer")
502  << "Expected input config as second "
503  "tag." << std::endl;
504 
505  AtomicId id = XMLDocument::readAttribute<std::string>(
506  elem, "id");
507  input = new Source(id, true);
508  input->getOutputs().append(
509  createVariable(input, kTargetId,
512  input->getOutputs().append(
513  createVariable(input, kWeightId,
516  sources.insert(std::make_pair(id, input));
517  fillOutputVars(input->getOutputs(), input, elem);
518 
519  state = STATE_MIDDLE;
520  } break;
521  case STATE_MIDDLE: {
522  if (name == "output") {
523  AtomicId zero;
524  output = new TrainProcessor("output",
525  &zero, this);
527  state = STATE_LAST;
528  continue;
529  } else if (name != "processor")
530  throw cms::Exception("MVATrainer")
531  << "Unexpected tag after input "
532  "config." << std::endl;
533 
534  AtomicId id = XMLDocument::readAttribute<std::string>(
535  elem, "id");
536  std::string name =
537  XMLDocument::readAttribute<std::string>(
538  elem, "name");
539 
540  makeProcessor(elem, id, name.c_str());
541  } break;
542  case STATE_LAST:
543  throw cms::Exception("MVATrainer")
544  << "Unexpected tag found after output."
545  << std::endl;
546  break;
547  }
548  }
549 
550  if (state == STATE_FIRST)
551  throw cms::Exception("MVATrainer")
552  << "Expected input variable config." << std::endl;
553  else if (state == STATE_MIDDLE)
554  throw cms::Exception("MVATrainer")
555  << "Expected output variable config." << std::endl;
556 
557  if (trainFileMask.empty())
558  trainFileMask = this->name + "_%s%s.%s";
559 }
560 
562 {
563  if (monitoring.get())
564  monitoring->write();
565 
566  for(std::map<AtomicId, Source*>::const_iterator iter = sources.begin();
567  iter != sources.end(); iter++) {
569  dynamic_cast<TrainProcessor*>(iter->second);
570 
571  if (proc && doCleanup)
572  proc->cleanup();
573 
574  delete iter->second;
575  }
576  delete output;
577  std::for_each(variables.begin(), variables.end(),
579 }
580 
582 {
583  for(std::vector<AtomicId>::const_iterator iter =
584  this->processors.begin();
585  iter != this->processors.end(); iter++) {
586  std::map<AtomicId, Source*>::const_iterator pos =
587  sources.find(*iter);
588  assert(pos != sources.end());
590  dynamic_cast<TrainProcessor*>(pos->second);
591  assert(source);
592 
593  if (source->load())
594  edm::LogInfo("MVATrainer")
595  << source->getId() << " configuration for \""
596  << (const char*)source->getName()
597  << "\" loaded from file.";
598  }
599 }
600 
602 {
603  doCleanup = false;
604 
605  for(std::vector<AtomicId>::const_iterator iter =
606  this->processors.begin();
607  iter != this->processors.end(); iter++) {
608  std::map<AtomicId, Source*>::const_iterator pos =
609  sources.find(*iter);
610  assert(pos != sources.end());
612  dynamic_cast<TrainProcessor*>(pos->second);
613  assert(source);
614 
615  if (source->isTrained())
616  source->save();
617  }
618 }
619 
620 void MVATrainer::makeProcessor(DOMElement *elem, AtomicId id, const char *name)
621 {
622  DOMElement *xmlInput = nullptr;
623  DOMElement *xmlConfig = nullptr;
624  DOMElement *xmlOutput = nullptr;
625  DOMElement *xmlData = nullptr;
626 
627  static struct NameExpect {
628  const char *tag;
629  bool mandatory;
630  DOMElement **elem;
631  } const expect[] = {
632  { "input", true, &xmlInput },
633  { "config", true, &xmlConfig },
634  { "output", true, &xmlOutput },
635  { "data", false, &xmlData },
636  { nullptr, }
637  };
638 
639  const NameExpect *cur = expect;
640  for(DOMNode *node = elem->getFirstChild();
641  node; node = node->getNextSibling()) {
642  if (node->getNodeType() != DOMNode::ELEMENT_NODE)
643  continue;
644 
645  std::string tag = XMLSimpleStr(node->getNodeName());
646  DOMElement *elem = static_cast<DOMElement*>(node);
647 
648  if (!cur->tag)
649  throw cms::Exception("MVATrainer")
650  << "Superfluous tag " << tag
651  << "encountered in processor." << std::endl;
652  else if (tag != cur->tag && cur->mandatory)
653  throw cms::Exception("MVATrainer")
654  << "Expected tag " << cur->tag << ", got "
655  << tag << " instead in processor."
656  << std::endl;
657  else if (tag != cur->tag) {
658  cur++;
659  continue;
660  }
661  *(cur++)->elem = elem;
662  }
663 
664  while(cur->tag && !cur->mandatory)
665  cur++;
666  if (cur->tag)
667  throw cms::Exception("MVATrainer")
668  << "Unexpected end of processor configuration, "
669  << "expected tag " << cur->tag << "." << std::endl;
670 
671  std::unique_ptr<TrainProcessor> proc(
672  TrainProcessor::create(name, &id, this));
673  if (!proc.get())
674  throw cms::Exception("MVATrainer")
675  << "Variable processor trainer " << name
676  << " could not be instantiated. Most likely because"
677  " the trainer plugin for \"" << name << "\""
678  " does not exist." << std::endl;
679 
680  if (sources.find(id) != sources.end())
681  throw cms::Exception("MVATrainer")
682  << "Duplicate variable processor id "
683  << (const char*)id << "."
684  << std::endl;
685 
686  fillInputVars(proc->getInputs(), xmlInput);
687  fillOutputVars(proc->getOutputs(), proc.get(), xmlOutput);
688 
689  edm::LogInfo("MVATrainer")
690  << "Configuring " << (const char*)proc->getId()
691  << " \"" << (const char*)proc->getName() << "\".";
692  proc->configure(xmlConfig);
693 
694  sources.insert(std::make_pair(id, proc.release()));
695  processors.push_back(id);
696 }
697 
699  const std::string &ext,
700  const std::string &arg) const
701 {
702  std::string arg_ = !arg.empty() ? ("_" + arg) : "";
703  return stdStringPrintf(trainFileMask.c_str(),
704  (const char*)proc->getName(),
705  arg_.c_str(), ext.c_str());
706 }
707 
709 {
710  if (!doMonitoring)
711  return nullptr;
712 
713  if (!monitoring.get()) {
716  "monitoring", "", "root");
717  monitoring.reset(new TrainerMonitoring(fileName));
718  }
719 
720  return monitoring->book(name);
721 }
722 
724 {
725  std::map<AtomicId, Source*>::const_iterator pos = sources.find(source);
726  if (pos == sources.end())
727  return nullptr;
728 
729  return pos->second->getOutput(name);
730 }
731 
733  Variable::Flags flags)
734 {
735  SourceVariable *var = getVariable(source->getName(), name);
736  if (var)
737  return nullptr;
738 
739  var = new SourceVariable(source, name, flags);
740  variables.push_back(var);
741  return var;
742 }
743 
746 {
747  std::vector<SourceVariable*> tmp;
748  SourceVariable *target = nullptr;
749  SourceVariable *weight = nullptr;
750 
751  for(DOMNode *node = xml->getFirstChild(); node;
752  node = node->getNextSibling()) {
753  if (node->getNodeType() != DOMNode::ELEMENT_NODE)
754  continue;
755 
756  if (std::strcmp(XMLSimpleStr(node->getNodeName()), "var") != 0)
757  throw cms::Exception("MVATrainer")
758  << "Invalid input variable node." << std::endl;
759 
760  DOMElement *elem = static_cast<DOMElement*>(node);
761 
762  AtomicId source = XMLDocument::readAttribute<std::string>(
763  elem, "source");
764  AtomicId name = XMLDocument::readAttribute<std::string>(
765  elem, "name");
766 
767  SourceVariable *var = getVariable(source, name);
768  if (!var)
769  throw cms::Exception("MVATrainer")
770  << "Input variable " << (const char*)source
771  << ":" << (const char*)name
772  << " not found." << std::endl;
773 
774  if (XMLDocument::readAttribute<bool>(elem, "target", false)) {
775  if (target)
776  throw cms::Exception("MVATrainer")
777  << "Target variable defined twice"
778  << std::endl;
779  target = var;
780  }
781  if (XMLDocument::readAttribute<bool>(elem, "weight", false)) {
782  if (weight)
783  throw cms::Exception("MVATrainer")
784  << "Weight variable defined twice"
785  << std::endl;
786  weight = var;
787  }
788 
789  tmp.push_back(var);
790  }
791 
792  if (!weight) {
793  weight = input->getOutput(kWeightId);
794  assert(weight);
795  tmp.insert(tmp.begin() +
796  (target == input->getOutput(kTargetId)),
797  1, weight);
798  }
799  if (!target) {
800  target = input->getOutput(kTargetId);
801  assert(target);
802  tmp.insert(tmp.begin(), 1, target);
803  }
804 
805  unsigned int n = 0;
806  for(std::vector<SourceVariable*>::const_iterator iter = variables.begin();
807  iter != variables.end(); iter++) {
808  std::vector<SourceVariable*>::const_iterator pos =
809  std::find(tmp.begin(), tmp.end(), *iter);
810  if (pos == tmp.end())
811  continue;
812 
814  if (*iter == target)
816  else if (*iter == weight)
818  else
820 
821  if (vars.append(*iter, magic, pos - tmp.begin())) {
822  AtomicId source = (*iter)->getSource()->getName();
823  AtomicId name = (*iter)->getName();
824  throw cms::Exception("MVATrainer")
825  << "Input variable " << (const char*)source
826  << ":" << (const char*)name
827  << " defined twice." << std::endl;
828  }
829 
830  n++;
831  }
832 
833  assert(tmp.size() == n);
834 }
835 
838 {
839  for(DOMNode *node = xml->getFirstChild(); node;
840  node = node->getNextSibling()) {
841  if (node->getNodeType() != DOMNode::ELEMENT_NODE)
842  continue;
843 
844  if (std::strcmp(XMLSimpleStr(node->getNodeName()), "var") != 0)
845  throw cms::Exception("MVATrainer")
846  << "Invalid output variable node."
847  << std::endl;
848 
849  DOMElement *elem = static_cast<DOMElement*>(node);
850 
851  AtomicId name = XMLDocument::readAttribute<std::string>(
852  elem, "name");
853  if (!name)
854  throw cms::Exception("MVATrainer")
855  << "Output variable tag missing name."
856  << std::endl;
857  if (isMagic(name))
858  throw cms::Exception("MVATrainer")
859  << "Cannot use magic variable names in output."
860  << std::endl;
861 
863 
864  if (XMLDocument::readAttribute<bool>(elem, "optional", true))
866  (flags | Variable::FLAG_OPTIONAL);
867 
868  if (XMLDocument::readAttribute<bool>(elem, "multiple", true))
870  (flags | Variable::FLAG_MULTIPLE);
871 
872  SourceVariable *var = createVariable(source, name, flags);
873  if (!var || vars.append(var))
874  throw cms::Exception("MVATrainer")
875  << "Output variable "
876  << (const char*)source->getName()
877  << ":" << (const char*)name
878  << " defined twice." << std::endl;
879  }
880 }
881 
882 void
884  const std::vector<CalibratedProcessor> &procs,
885  bool withTarget) const
886 {
887  std::map<SourceVariable*, unsigned int> vars;
888  unsigned int size = 0;
889 
890  MVATrainerComputer *trainCalib =
891  dynamic_cast<MVATrainerComputer*>(calib);
892 
893  for(unsigned int i = 0;
894  i < input->getOutputs().size(true); i++) {
895  if (i < 2 && !withTarget)
896  continue;
897 
899  vars[var] = size++;
900 
901  Calibration::Variable calibVar;
902  calibVar.name = (const char*)var->getName();
903  calib->inputSet.push_back(calibVar);
904  if (trainCalib)
905  trainCalib->addFlag(var->getFlags());
906  }
907 
908  for(std::vector<CalibratedProcessor>::const_iterator iter =
909  procs.begin(); iter != procs.end(); iter++) {
910  bool isInterceptor = dynamic_cast<BaseInterceptor*>(
911  iter->calib) != nullptr;
912 
913  BitSet inputSet(size);
914 
915  unsigned int last = 0;
916  std::vector<SourceVariable*> inoutVars;
917  if (iter->processor)
918  inoutVars = iter->processor->getInputs().get(
919  isInterceptor);
920  for(std::vector<SourceVariable*>::const_iterator iter2 =
921  inoutVars.begin(); iter2 != inoutVars.end(); iter2++) {
923  unsigned int>::const_iterator pos =
924  vars.find(*iter2);
925 
926  assert(pos != vars.end());
927 
928  if (pos->second < last)
929  throw cms::Exception("MVATrainer")
930  << "Input variables not declared "
931  "in order of appearance in \""
932  << (const char*)iter->processor->getName()
933  << "\"." << std::endl;
934 
935  inputSet[last = pos->second] = true;
936  }
937 
938  assert(!isInterceptor || withTarget);
939 
940  iter->calib->inputVars = Calibration::convert(inputSet);
941 
942  calib->output = size;
943 
944  if (isInterceptor) {
945  size++;
946  continue;
947  }
948 
949  calib->addProcessor(iter->calib);
950 
951  inoutVars = iter->processor->getOutputs().get();
952  for(std::vector<SourceVariable*>::const_iterator iter =
953  inoutVars.begin(); iter != inoutVars.end(); iter++) {
954 
955  vars[*iter] = size++;
956  }
957  }
958 
959  if (output->getInputs().size() != 1)
960  throw cms::Exception("MVATrainer")
961  << "Exactly one output variable has to be specified."
962  << std::endl;
963 
964  SourceVariable *outVar = output->getInputs().get()[0];
965  std::map<SourceVariable*, unsigned int>::const_iterator pos =
966  vars.find(outVar);
967  if (pos != vars.end())
968  calib->output = pos->second;
969 }
970 
973  const AtomicId *train) const
974 {
975  std::map<AtomicId, TrainInterceptor*> interceptors;
976  std::vector<MVATrainerComputer::Interceptor> baseInterceptors;
977  std::vector<CalibratedProcessor> processors;
978 
979  BaseInterceptor *interceptor = new InitInterceptor;
980  baseInterceptors.push_back(std::make_pair(0, interceptor));
981  processors.push_back(CalibratedProcessor(nullptr, interceptor));
982 
983  for(const AtomicId *iter = train; *iter; iter++) {
985  if (*iter == kOutputId)
986  source = output;
987  else {
988  std::map<AtomicId, Source*>::const_iterator pos =
989  sources.find(*iter);
990  assert(pos != sources.end());
991  source = dynamic_cast<TrainProcessor*>(pos->second);
992  }
993  assert(source);
994 
995  interceptors[*iter] = new TrainInterceptor(source);
996  }
997 
998  auto_cleaner<Calibration::VarProcessor> autoClean;
999 
1000  std::set<AtomicId> done;
1001  for(const AtomicId *iter = compute; *iter; iter++) {
1002  if (done.erase(*iter))
1003  continue;
1004 
1005  std::map<AtomicId, Source*>::const_iterator pos =
1006  sources.find(*iter);
1007  assert(pos != sources.end());
1009  dynamic_cast<TrainProcessor*>(pos->second);
1010  assert(source);
1011  assert(source->isTrained());
1012 
1013  Calibration::VarProcessor *proc = source->getCalibration();
1014  if (!proc)
1015  continue;
1016 
1017  autoClean.add(proc);
1018  processors.push_back(CalibratedProcessor(source, proc));
1019 
1021  dynamic_cast<Calibration::ProcForeach*>(proc);
1022  if (looper) {
1023  std::vector<AtomicId>::const_iterator pos2 =
1024  std::find(this->processors.begin(),
1025  this->processors.end(), *iter);
1026  assert(pos2 != this->processors.end());
1027  ++pos2;
1028  unsigned int n = 0;
1029  for(int i = 0; i < (int)looper->nProcs; ++i, ++pos2) {
1030  assert(pos2 != this->processors.end());
1031 
1032  const AtomicId *iter2 = compute;
1033  while(*iter2) {
1034  if (*iter2 == *pos2)
1035  break;
1036  iter2++;
1037  }
1038 
1039  if (*iter2) {
1040  n++;
1041  done.insert(*iter2);
1042  pos = sources.find(*iter2);
1043  assert(pos != sources.end());
1044  TrainProcessor *source =
1045  dynamic_cast<TrainProcessor*>(
1046  pos->second);
1047  assert(source);
1048  assert(source->isTrained());
1049 
1050  proc = source->getCalibration();
1051  if (proc) {
1052  autoClean.add(proc);
1053  processors.push_back(
1055  source, proc));
1056  }
1057  }
1058 
1059  std::map<AtomicId, TrainInterceptor*>::iterator
1060  pos3 = interceptors.find(*pos2);
1061  if (pos3 != interceptors.end()) {
1062  n++;
1063  baseInterceptors.push_back(
1064  std::make_pair(processors.size(),
1065  pos3->second));
1066  processors.push_back(
1068  pos3->second->getProcessor(),
1069  pos3->second));
1070  interceptors.erase(pos3);
1071  }
1072  }
1073 
1074  looper->nProcs = n;
1075  if (!n) {
1076  baseInterceptors.pop_back();
1077  processors.pop_back();
1078  }
1079  }
1080  }
1081 
1082  for(std::map<AtomicId, TrainInterceptor*>::const_iterator iter =
1083  interceptors.begin(); iter != interceptors.end(); ++iter) {
1084 
1085  TrainProcessor *proc = iter->second->getProcessor();
1086  baseInterceptors.push_back(std::make_pair(processors.size(),
1087  iter->second));
1088  processors.push_back(CalibratedProcessor(proc, iter->second));
1089  }
1090 
1091  std::unique_ptr<Calibration::MVAComputer> calib(
1092  new MVATrainerComputer(baseInterceptors, doAutoSave,
1094 
1095  connectProcessors(calib.get(), processors, true);
1096 
1097  return calib.release();
1098 }
1099 
1101 {
1102  MVATrainerComputer *calib =
1103  dynamic_cast<MVATrainerComputer*>(trainCalibration);
1104 
1105  if (!calib)
1106  throw cms::Exception("MVATrainer")
1107  << "Invalid training calibration passed to "
1108  "doneTraining()" << std::endl;
1109 
1110  calib->done();
1111 }
1112 
1113 std::vector<AtomicId> MVATrainer::findFinalProcessors() const
1114 {
1115  std::set<Source*> toCheck;
1116  toCheck.insert(output);
1117 
1118  std::set<Source*> done;
1119  while(!toCheck.empty()) {
1120  Source *source = *toCheck.begin();
1121  toCheck.erase(toCheck.begin());
1122 
1123  std::vector<SourceVariable*> inputs = source->inputs.get();
1124  for(std::vector<SourceVariable*>::const_iterator iter =
1125  inputs.begin(); iter != inputs.end(); ++iter) {
1126  source = (*iter)->getSource();
1127  if (done.insert(source).second)
1128  toCheck.insert(source);
1129  }
1130  }
1131 
1132  std::vector<AtomicId> result;
1133  for(std::vector<AtomicId>::const_iterator iter = processors.begin();
1134  iter != processors.end(); ++iter) {
1135  std::map<AtomicId, Source*>::const_iterator pos =
1136  sources.find(*iter);
1137  if (pos != sources.end() && done.count(pos->second))
1138  result.push_back(*iter);
1139  }
1140 
1141  return result;
1142 }
1143 
1145 {
1146  std::vector<CalibratedProcessor> processors;
1147 
1148  std::unique_ptr<Calibration::MVAComputer> calib(
1150 
1151  std::vector<AtomicId> used = findFinalProcessors();
1152  for(std::vector<AtomicId>::const_iterator iter = used.begin();
1153  iter != used.end(); iter++) {
1154  std::map<AtomicId, Source*>::const_iterator pos =
1155  sources.find(*iter);
1156  assert(pos != sources.end());
1158  dynamic_cast<TrainProcessor*>(pos->second);
1159  assert(source);
1160  if (!source->isTrained())
1161  return nullptr;
1162 
1163  Calibration::VarProcessor *proc = source->getCalibration();
1164  if (!proc)
1165  continue;
1166 
1167  Calibration::ProcForeach *foreach =
1168  dynamic_cast<Calibration::ProcForeach*>(proc);
1169  if (foreach) {
1170  std::vector<AtomicId>::const_iterator begin =
1171  std::find(this->processors.begin(),
1172  this->processors.end(), *iter);
1173  assert(this->processors.end() - begin >
1174  (int)(foreach->nProcs + 1));
1175  ++begin;
1176  std::vector<AtomicId>::const_iterator end =
1177  begin + foreach->nProcs;
1178  foreach->nProcs = 0;
1179  for(std::vector<AtomicId>::const_iterator iter2 =
1180  iter; iter2 != used.end(); ++iter2)
1181  if (std::find(begin, end, *iter2) != end)
1182  foreach->nProcs++;
1183  }
1184 
1185  processors.push_back(CalibratedProcessor(source, proc));
1186  }
1187 
1188  connectProcessors(calib.get(), processors, false);
1189 
1190  return calib.release();
1191 }
1192 
1193 void MVATrainer::findUntrainedComputers(std::vector<AtomicId> &compute,
1194  std::vector<AtomicId> &train) const
1195 {
1196  compute.clear();
1197  train.clear();
1198 
1199  std::set<Source*> trainedSources;
1200  trainedSources.insert(input);
1201 
1202  for(std::vector<AtomicId>::const_iterator iter =
1203  processors.begin(); iter != processors.end(); iter++) {
1204  std::map<AtomicId, Source*>::const_iterator pos =
1205  sources.find(*iter);
1206  assert(pos != sources.end());
1207  TrainProcessor *proc =
1208  dynamic_cast<TrainProcessor*>(pos->second);
1209  assert(proc);
1210 
1211  bool trainedDeps = true;
1212  std::vector<SourceVariable*> inputVars =
1213  proc->getInputs().get();
1214  for(std::vector<SourceVariable*>::const_iterator iter2 =
1215  inputVars.begin(); iter2 != inputVars.end(); iter2++) {
1216  if (trainedSources.find((*iter2)->getSource())
1217  == trainedSources.end()) {
1218  trainedDeps = false;
1219  break;
1220  }
1221  }
1222 
1223  if (!trainedDeps)
1224  continue;
1225 
1226  if (proc->isTrained()) {
1227  trainedSources.insert(proc);
1228  compute.push_back(proc->getName());
1229  } else
1230  train.push_back(proc->getName());
1231  }
1232 
1233  if (doMonitoring && !output->isTrained() &&
1234  trainedSources.find(output->getInputs().get()[0]->getSource())
1235  != trainedSources.end())
1236  train.push_back(kOutputId);
1237 }
1238 
1240 {
1241  std::vector<AtomicId> compute, train;
1242  findUntrainedComputers(compute, train);
1243 
1244  if (train.empty())
1245  return nullptr;
1246 
1247  compute.push_back(nullptr);
1248  train.push_back(nullptr);
1249 
1250  return makeTrainCalibration(&compute.front(), &train.front());
1251 }
1252 
1253 } // namespace PhysicsTools
size
Write out results.
static std::string escape(const std::string &in)
Definition: MVATrainer.cc:403
unsigned int nConfigured
Definition: MVATrainer.cc:136
virtual std::vector< VarProcessor * > getProcessors() const
Definition: MVAComputer.cc:177
bool isTrained() const
Definition: Source.h:24
TrainProcessor *const proc
Definition: MVATrainer.cc:101
#define nullptr
int init
Definition: HydjetWrapper.h:67
std::vector< Variable::Flags > flags
Definition: MVATrainer.cc:135
Definition: weight.py:1
SourceVariable * getVariable(AtomicId source, AtomicId name) const
Definition: MVATrainer.cc:723
#define XERCES_CPP_NAMESPACE_QUALIFIER
Definition: LHERunInfo.h:16
TRandom random
Definition: MVATrainer.cc:138
void find(edm::Handle< EcalRecHitCollection > &hits, DetId thisDet, std::vector< EcalRecHitCollection::const_iterator > &hit, bool debug=false)
Definition: FindCaloHit.cc:20
A arg
Definition: Factorize.h:38
static const AtomicId kTargetId
Definition: MVATrainer.h:59
SourceVariable * getOutput(AtomicId name) const
Definition: Source.h:21
const SourceVariableSet & getInputs() const
Definition: Source.h:26
void fillOutputVars(SourceVariableSet &vars, Source *source, XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
Definition: MVATrainer.cc:836
Cheap generic unique keyword identifier class.
Definition: AtomicId.h:31
unsigned int weightIdx
Definition: MVATrainer.cc:99
static std::string const input
Definition: EdmProvDump.cc:48
Definition: looper.py:1
bool doAutoSave
Definition: MVATrainer.cc:137
static const AtomicId kWeightId
Definition: MVATrainer.h:60
Calibration::MVAComputer * makeTrainCalibration(const AtomicId *compute, const AtomicId *train) const
Definition: MVATrainer.cc:972
MVATrainerComputer * calib
Definition: MVATrainer.cc:64
Main interface class to the generic discriminator computer framework.
Definition: MVAComputer.h:39
const AtomicId getName() const
Definition: Variable.h:143
std::vector< AtomicId > processors
Definition: MVATrainer.h:102
Flags getFlags() const
Definition: Variable.h:144
const SourceVariableSet & getOutputs() const
Definition: Source.h:27
std::vector< AtomicId > findFinalProcessors() const
Definition: MVATrainer.cc:1113
A compact container for storing single bits.
Definition: BitSet.h:29
std::vector< SourceVariable * > get(bool withMagic=false) const
void addProcessor(const VarProcessor *proc)
Definition: MVAComputer.cc:182
#define end
Definition: vmac.h:39
unsigned int targetIdx
Definition: MVATrainer.cc:98
std::string trainFileName(const TrainProcessor *proc, const std::string &ext, const std::string &arg="") const
Definition: MVATrainer.cc:698
SourceVariableSet inputs
Definition: Source.h:39
void connectProcessors(Calibration::MVAComputer *calib, const std::vector< CalibratedProcessor > &procs, bool withTarget) const
Definition: MVATrainer.cc:883
def elem(elemtype, innerHTML='', html_class='', kwargs)
Definition: HTMLExport.py:19
std::vector< T * > clean
Definition: MVATrainer.cc:154
std::map< AtomicId, Source * > sources
Definition: MVATrainer.h:100
static Base_t * create(const char *name, const CalibBase_t *calib, Parent_t *parent=0)
void add(std::map< std::string, TH1 * > &h, TH1 *hist)
bool splitResult
Definition: MVATrainer.cc:140
void doneTraining(Calibration::MVAComputer *trainCalibration) const
Definition: MVATrainer.cc:1100
std::unique_ptr< TrainerMonitoring > monitoring
Definition: MVATrainer.h:106
TrainerMonitoring::Module * bookMonitor(const std::string &name)
Definition: MVATrainer.cc:708
std::vector< Interceptor > interceptors
Definition: MVATrainer.cc:134
size_type size(bool withMagic=false) const
bool append(SourceVariable *var, Magic magic=kRegular, int offset=-1)
std::unique_ptr< XMLDocument > xml
Definition: MVATrainer.h:107
Calibration::MVAComputer * getCalibration() const
Definition: MVATrainer.cc:1144
def compute(min, max)
std::vector< SourceVariable * > variables
Definition: MVATrainer.h:101
std::vector< Variable > inputSet
Definition: MVAComputer.h:234
void fillInputVars(SourceVariableSet &vars, XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
Definition: MVATrainer.cc:744
SourceVariable * find(AtomicId name) const
std::vector< std::vector< double > > tmp
Definition: MVATrainer.cc:100
MVATrainer(const std::string &fileName, bool useXSLT=false, const char *styleSheet=0)
Definition: MVATrainer.cc:420
#define begin
Definition: vmac.h:32
AtomicId getName() const
Definition: Source.h:19
void makeProcessor(XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *elem, AtomicId id, const char *name)
Definition: MVATrainer.cc:620
static std::vector< std::string > split(const std::string line, char delim)
Definition: MLP.cc:18
std::string fullPath() const
Definition: FileInPath.cc:163
std::string trainFileMask
Definition: MVATrainer.h:108
PhysicsTools::BitSet convert(const BitSet &bitSet)
constructs BitSet container from persistent representation
Definition: BitSet.cc:38
void findUntrainedComputers(std::vector< AtomicId > &compute, std::vector< AtomicId > &train) const
Definition: MVATrainer.cc:1193
TrainProcessor * output
Definition: MVATrainer.h:104
vars
Definition: DeepTauId.cc:77
Calibration::MVAComputer * getTrainCalibration() const
Definition: MVATrainer.cc:1239
Definition: memstream.h:15
long double T
static std::string const source
Definition: EdmProvDump.cc:47
save
Definition: cuy.py:1165
static const AtomicId kOutputId("__OUTPUT__")
static std::string stdStringVPrintf(const char *format, std::va_list va)
Definition: MVATrainer.cc:158
static std::string stdStringPrintf(const char *format,...)
Definition: MVATrainer.cc:181
static bool isMagic(AtomicId id)
Definition: MVATrainer.cc:396
SourceVariable * createVariable(Source *source, AtomicId name, Variable::Flags flags)
Definition: MVATrainer.cc:732