CMS 3D CMS Logo

List of all members | Classes | Public Member Functions | Static Public Attributes | Private Member Functions | Private Attributes
PhysicsTools::MVATrainer Class Reference

#include <MVATrainer.h>

Classes

struct  CalibratedProcessor
 

Public Member Functions

TrainerMonitoring::Module * bookMonitor (const std::string &name)
 
void doneTraining (Calibration::MVAComputer *trainCalibration) const
 
Calibration::MVAComputer * getCalibration () const
 
const std::string & getName () const
 
Calibration::MVAComputer * getTrainCalibration () const
 
void loadState ()
 
 MVATrainer (const std::string &fileName, bool useXSLT=false, const char *styleSheet=0)
 
void saveState ()
 
void setAutoSave (bool autoSave)
 
void setCleanup (bool cleanup)
 
void setCrossValidation (double split)
 
void setMonitoring (bool monitoring)
 
void setRandomSeed (UInt_t seed)
 
std::string trainFileName (const TrainProcessor *proc, const std::string &ext, const std::string &arg="") const
 
 ~MVATrainer ()
 

Static Public Attributes

static const AtomicId kTargetId
 
static const AtomicId kWeightId
 

Private Member Functions

void connectProcessors (Calibration::MVAComputer *calib, const std::vector< CalibratedProcessor > &procs, bool withTarget) const
 
SourceVariable * createVariable (Source *source, AtomicId name, Variable::Flags flags)
 
void fillInputVars (SourceVariableSet &vars, XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
 
void fillOutputVars (SourceVariableSet &vars, Source *source, XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
 
std::vector< AtomicId > findFinalProcessors () const
 
void findUntrainedComputers (std::vector< AtomicId > &compute, std::vector< AtomicId > &train) const
 
SourceVariable * getVariable (AtomicId source, AtomicId name) const
 
void makeProcessor (XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *elem, AtomicId id, const char *name)
 
Calibration::MVAComputer * makeTrainCalibration (const AtomicId *compute, const AtomicId *train) const
 

Private Attributes

double crossValidation
 
bool doAutoSave
 
bool doCleanup
 
bool doMonitoring
 
Source * input
 
std::unique_ptr< TrainerMonitoring > monitoring
 
std::string name
 
TrainProcessor * output
 
std::vector< AtomicId > processors
 
UInt_t randomSeed
 
std::map< AtomicId, Source * > sources
 
std::string trainFileMask
 
std::vector< SourceVariable * > variables
 
std::unique_ptr< XMLDocument > xml
 

Detailed Description

Definition at line 27 of file MVATrainer.h.

Constructor & Destructor Documentation

PhysicsTools::MVATrainer::MVATrainer ( const std::string &  fileName,
bool  useXSLT = false,
const char *  styleSheet = 0 
)

Definition at line 420 of file MVATrainer.cc.

References ws_sso_content_reader::content, createVariable(), HTMLExport::elem(), PhysicsTools::escape(), Exception, fillInputVars(), fillOutputVars(), PhysicsTools::Variable::FLAG_NONE, PhysicsTools::Variable::FLAG_OPTIONAL, edm::FileInPath::fullPath(), PhysicsTools::Source::getInputs(), input, PhysicsTools::SourceVariableSet::kTarget, kTargetId, PhysicsTools::SourceVariableSet::kWeight, kWeightId, makeProcessor(), name, output, PFJetToCaloProducer_cfi::Source, sources, AlCaHLTBitMon_QueryRunRegistry::string, trainFileMask, and xml.

421  :
422  input(nullptr), output(nullptr), name("MVATrainer"),
423  doAutoSave(true), doCleanup(false),
424  doMonitoring(false), randomSeed(65539), crossValidation(0.0)
425 {
426  if (useXSLT) {
427  std::string sheet;
428  if (!styleSheet)
429  sheet = edm::FileInPath(
430  "PhysicsTools/MVATrainer/data/MVATrainer.xsl")
431  .fullPath();
432  else
433  sheet = styleSheet;
434 
435  std::string preproc = "xsltproc --xinclude " + escape(sheet) +
436  " " + escape(fileName);
437  xml.reset(new XMLDocument(fileName, preproc));
438  } else
439  xml.reset(new XMLDocument(fileName));
440 
441  DOMNode *node = xml->getRootNode();
442 
443  if (std::strcmp(XMLSimpleStr(node->getNodeName()), "MVATrainer") != 0)
444  throw cms::Exception("MVATrainer")
445  << "Invalid XML root node." << std::endl;
446 
447  enum State {
448  STATE_GENERAL,
449  STATE_FIRST,
450  STATE_MIDDLE,
451  STATE_LAST
452  } state = STATE_GENERAL;
453 
454  for(node = node->getFirstChild();
455  node; node = node->getNextSibling()) {
456  if (node->getNodeType() != DOMNode::ELEMENT_NODE)
457  continue;
458 
459  std::string name = XMLSimpleStr(node->getNodeName());
460  DOMElement *elem = static_cast<DOMElement*>(node);
461 
462  switch(state) {
463  case STATE_GENERAL: {
464  if (name != "general")
465  throw cms::Exception("MVATrainer")
466  << "Expected general config as first "
467  "tag." << std::endl;
468 
469  for(DOMNode *subNode = elem->getFirstChild();
470  subNode; subNode = subNode->getNextSibling()) {
471  if (subNode->getNodeType() !=
472  DOMNode::ELEMENT_NODE)
473  continue;
474 
475  if (std::strcmp(XMLSimpleStr(
476  subNode->getNodeName()), "option") != 0)
477  throw cms::Exception("MVATrainer")
478  << "Expected option tag."
479  << std::endl;
480 
481  elem = static_cast<DOMElement*>(subNode);
482  name = XMLDocument::readAttribute<std::string>(
483  elem, "name");
485  elem->getTextContent());
486 
487  if (name == "id")
488  this->name = content;
489  else if (name == "trainfiles")
491  else
492  throw cms::Exception("MVATrainer")
493  << "Unknown option \""
494  << name << "\"." << std::endl;
495  }
496 
497  state = STATE_FIRST;
498  } break;
499  case STATE_FIRST: {
500  if (name != "input")
501  throw cms::Exception("MVATrainer")
502  << "Expected input config as second "
503  "tag." << std::endl;
504 
505  AtomicId id = XMLDocument::readAttribute<std::string>(
506  elem, "id");
507  input = new Source(id, true);
508  input->getOutputs().append(
509  createVariable(input, kTargetId,
512  input->getOutputs().append(
513  createVariable(input, kWeightId,
516  sources.insert(std::make_pair(id, input));
517  fillOutputVars(input->getOutputs(), input, elem);
518 
519  state = STATE_MIDDLE;
520  } break;
521  case STATE_MIDDLE: {
522  if (name == "output") {
523  AtomicId zero;
524  output = new TrainProcessor("output",
525  &zero, this);
527  state = STATE_LAST;
528  continue;
529  } else if (name != "processor")
530  throw cms::Exception("MVATrainer")
531  << "Unexpected tag after input "
532  "config." << std::endl;
533 
534  AtomicId id = XMLDocument::readAttribute<std::string>(
535  elem, "id");
536  std::string name =
537  XMLDocument::readAttribute<std::string>(
538  elem, "name");
539 
540  makeProcessor(elem, id, name.c_str());
541  } break;
542  case STATE_LAST:
543  throw cms::Exception("MVATrainer")
544  << "Unexpected tag found after output."
545  << std::endl;
546  break;
547  }
548  }
549 
550  if (state == STATE_FIRST)
551  throw cms::Exception("MVATrainer")
552  << "Expected input variable config." << std::endl;
553  else if (state == STATE_MIDDLE)
554  throw cms::Exception("MVATrainer")
555  << "Expected output variable config." << std::endl;
556 
557  if (trainFileMask.empty())
558  trainFileMask = this->name + "_%s%s.%s";
559 }
static std::string escape(const std::string &in)
Definition: MVATrainer.cc:403
static const AtomicId kTargetId
Definition: MVATrainer.h:59
const SourceVariableSet & getInputs() const
Definition: Source.h:26
void fillOutputVars(SourceVariableSet &vars, Source *source, XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
Definition: MVATrainer.cc:836
static const AtomicId kWeightId
Definition: MVATrainer.h:60
def elem(elemtype, innerHTML='', html_class='', kwargs)
Definition: HTMLExport.py:18
std::map< AtomicId, Source * > sources
Definition: MVATrainer.h:100
std::unique_ptr< XMLDocument > xml
Definition: MVATrainer.h:107
void fillInputVars(SourceVariableSet &vars, XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
Definition: MVATrainer.cc:744
void makeProcessor(XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *elem, AtomicId id, const char *name)
Definition: MVATrainer.cc:620
std::string fullPath() const
Definition: FileInPath.cc:197
std::string trainFileMask
Definition: MVATrainer.h:108
TrainProcessor * output
Definition: MVATrainer.h:104
SourceVariable * createVariable(Source *source, AtomicId name, Variable::Flags flags)
Definition: MVATrainer.cc:732
PhysicsTools::MVATrainer::~MVATrainer ( )

Definition at line 561 of file MVATrainer.cc.

References PhysicsTools::TrainProcessor::cleanup(), doCleanup, monitoring, output, proc, sources, and variables.

562 {
563  if (monitoring.get())
564  monitoring->write();
565 
566  for(std::map<AtomicId, Source*>::const_iterator iter = sources.begin();
567  iter != sources.end(); iter++) {
568  TrainProcessor *proc =
569  dynamic_cast<TrainProcessor*>(iter->second);
570 
571  if (proc && doCleanup)
572  proc->cleanup();
573 
574  delete iter->second;
575  }
576  delete output;
577  std::for_each(variables.begin(), variables.end(),
579 }
TrainProcessor *const proc
Definition: MVATrainer.cc:101
std::map< AtomicId, Source * > sources
Definition: MVATrainer.h:100
std::unique_ptr< TrainerMonitoring > monitoring
Definition: MVATrainer.h:106
std::vector< SourceVariable * > variables
Definition: MVATrainer.h:101
TrainProcessor * output
Definition: MVATrainer.h:104

Member Function Documentation

TrainerMonitoring::Module * PhysicsTools::MVATrainer::bookMonitor ( const std::string &  name)

Definition at line 708 of file MVATrainer.cc.

References doMonitoring, MillePedeFileConverter_cfg::fileName, monitoring, PhysicsTools::stdStringPrintf(), AlCaHLTBitMon_QueryRunRegistry::string, and trainFileMask.

Referenced by PhysicsTools::TrainProcessor::doTrainBegin(), and getName().

709 {
710  if (!doMonitoring)
711  return nullptr;
712 
713  if (!monitoring.get()) {
716  "monitoring", "", "root");
717  monitoring.reset(new TrainerMonitoring(fileName));
718  }
719 
720  return monitoring->book(name);
721 }
std::unique_ptr< TrainerMonitoring > monitoring
Definition: MVATrainer.h:106
std::string trainFileMask
Definition: MVATrainer.h:108
static std::string stdStringPrintf(const char *format,...)
Definition: MVATrainer.cc:181
void PhysicsTools::MVATrainer::connectProcessors ( Calibration::MVAComputer *  calib,
const std::vector< CalibratedProcessor > &  procs,
bool  withTarget 
) const
private

Definition at line 883 of file MVATrainer.cc.

References PhysicsTools::Calibration::MVAComputer::addProcessor(), calib, PhysicsTools::Calibration::convert(), Exception, PhysicsTools::SourceVariableSet::get(), PhysicsTools::Variable::getFlags(), PhysicsTools::Source::getInputs(), PhysicsTools::Variable::getName(), PhysicsTools::Source::getOutputs(), mps_fire::i, input, PhysicsTools::Calibration::MVAComputer::inputSet, plotBeamSpotDB::last, genParticles_cff::map, PhysicsTools::Calibration::Variable::name, output, PhysicsTools::Calibration::MVAComputer::output, PhysicsTools::SourceVariableSet::size(), findQualityFiles::size, JetChargeProducer_cfi::var, and variables.

Referenced by getCalibration(), and makeTrainCalibration().

886 {
887  std::map<SourceVariable*, unsigned int> vars;
888  unsigned int size = 0;
889 
890  MVATrainerComputer *trainCalib =
891  dynamic_cast<MVATrainerComputer*>(calib);
892 
893  for(unsigned int i = 0;
894  i < input->getOutputs().size(true); i++) {
895  if (i < 2 && !withTarget)
896  continue;
897 
898  SourceVariable *var = variables[i];
899  vars[var] = size++;
900 
901  Calibration::Variable calibVar;
902  calibVar.name = (const char*)var->getName();
903  calib->inputSet.push_back(calibVar);
904  if (trainCalib)
905  trainCalib->addFlag(var->getFlags());
906  }
907 
908  for(std::vector<CalibratedProcessor>::const_iterator iter =
909  procs.begin(); iter != procs.end(); iter++) {
910  bool isInterceptor = dynamic_cast<BaseInterceptor*>(
911  iter->calib) != nullptr;
912 
913  BitSet inputSet(size);
914 
915  unsigned int last = 0;
916  std::vector<SourceVariable*> inoutVars;
917  if (iter->processor)
918  inoutVars = iter->processor->getInputs().get(
919  isInterceptor);
920  for(std::vector<SourceVariable*>::const_iterator iter2 =
921  inoutVars.begin(); iter2 != inoutVars.end(); iter2++) {
922  std::map<SourceVariable*,
923  unsigned int>::const_iterator pos =
924  vars.find(*iter2);
925 
926  assert(pos != vars.end());
927 
928  if (pos->second < last)
929  throw cms::Exception("MVATrainer")
930  << "Input variables not declared "
931  "in order of appearance in \""
932  << (const char*)iter->processor->getName()
933  << "\"." << std::endl;
934 
935  inputSet[last = pos->second] = true;
936  }
937 
938  assert(!isInterceptor || withTarget);
939 
940  iter->calib->inputVars = Calibration::convert(inputSet);
941 
942  calib->output = size;
943 
944  if (isInterceptor) {
945  size++;
946  continue;
947  }
948 
949  calib->addProcessor(iter->calib);
950 
951  inoutVars = iter->processor->getOutputs().get();
952  for(std::vector<SourceVariable*>::const_iterator iter =
953  inoutVars.begin(); iter != inoutVars.end(); iter++) {
954 
955  vars[*iter] = size++;
956  }
957  }
958 
959  if (output->getInputs().size() != 1)
960  throw cms::Exception("MVATrainer")
961  << "Exactly one output variable has to be specified."
962  << std::endl;
963 
964  SourceVariable *outVar = output->getInputs().get()[0];
965  std::map<SourceVariable*, unsigned int>::const_iterator pos =
966  vars.find(outVar);
967  if (pos != vars.end())
968  calib->output = pos->second;
969 }
size
Write out results.
const SourceVariableSet & getInputs() const
Definition: Source.h:26
MVATrainerComputer * calib
Definition: MVATrainer.cc:64
const SourceVariableSet & getOutputs() const
Definition: Source.h:27
std::vector< SourceVariable * > get(bool withMagic=false) const
size_type size(bool withMagic=false) const
std::vector< SourceVariable * > variables
Definition: MVATrainer.h:101
PhysicsTools::BitSet convert(const BitSet &bitSet)
constructs BitSet container from persistent representation
Definition: BitSet.cc:38
TrainProcessor * output
Definition: MVATrainer.h:104
SourceVariable * PhysicsTools::MVATrainer::createVariable ( Source *  source,
AtomicId  name,
Variable::Flags  flags 
)
private

Definition at line 732 of file MVATrainer.cc.

References PhysicsTools::Source::getName(), getVariable(), name, JetChargeProducer_cfi::var, and variables.

Referenced by fillOutputVars(), and MVATrainer().

734 {
735  SourceVariable *var = getVariable(source->getName(), name);
736  if (var)
737  return nullptr;
738 
739  var = new SourceVariable(source, name, flags);
740  variables.push_back(var);
741  return var;
742 }
std::vector< Variable::Flags > flags
Definition: MVATrainer.cc:135
SourceVariable * getVariable(AtomicId source, AtomicId name) const
Definition: MVATrainer.cc:723
std::vector< SourceVariable * > variables
Definition: MVATrainer.h:101
static std::string const source
Definition: EdmProvDump.cc:44
void PhysicsTools::MVATrainer::doneTraining ( Calibration::MVAComputer *  trainCalibration) const

Definition at line 1100 of file MVATrainer.cc.

References Exception.

Referenced by setCrossValidation().

1101 {
1102  MVATrainerComputer *calib =
1103  dynamic_cast<MVATrainerComputer*>(trainCalibration);
1104 
1105  if (!calib)
1106  throw cms::Exception("MVATrainer")
1107  << "Invalid training calibration passed to "
1108  "doneTraining()" << std::endl;
1109 
1110  calib->done();
1111 }
void PhysicsTools::MVATrainer::fillInputVars ( SourceVariableSet &  vars,
XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *  xml 
)
private

Definition at line 744 of file MVATrainer.cc.

References PhysicsTools::SourceVariableSet::append(), HTMLExport::elem(), Exception, spr::find(), PhysicsTools::Source::getOutput(), getVariable(), input, PhysicsTools::SourceVariableSet::kRegular, PhysicsTools::SourceVariableSet::kTarget, kTargetId, PhysicsTools::SourceVariableSet::kWeight, kWeightId, gen::n, name, source, edmPickEvents::target, tmp, JetChargeProducer_cfi::var, and variables.

Referenced by makeProcessor(), and MVATrainer().

746 {
747  std::vector<SourceVariable*> tmp;
748  SourceVariable *target = nullptr;
749  SourceVariable *weight = nullptr;
750 
751  for(DOMNode *node = xml->getFirstChild(); node;
752  node = node->getNextSibling()) {
753  if (node->getNodeType() != DOMNode::ELEMENT_NODE)
754  continue;
755 
756  if (std::strcmp(XMLSimpleStr(node->getNodeName()), "var") != 0)
757  throw cms::Exception("MVATrainer")
758  << "Invalid input variable node." << std::endl;
759 
760  DOMElement *elem = static_cast<DOMElement*>(node);
761 
762  AtomicId source = XMLDocument::readAttribute<std::string>(
763  elem, "source");
764  AtomicId name = XMLDocument::readAttribute<std::string>(
765  elem, "name");
766 
767  SourceVariable *var = getVariable(source, name);
768  if (!var)
769  throw cms::Exception("MVATrainer")
770  << "Input variable " << (const char*)source
771  << ":" << (const char*)name
772  << " not found." << std::endl;
773 
774  if (XMLDocument::readAttribute<bool>(elem, "target", false)) {
775  if (target)
776  throw cms::Exception("MVATrainer")
777  << "Target variable defined twice"
778  << std::endl;
779  target = var;
780  }
781  if (XMLDocument::readAttribute<bool>(elem, "weight", false)) {
782  if (weight)
783  throw cms::Exception("MVATrainer")
784  << "Weight variable defined twice"
785  << std::endl;
786  weight = var;
787  }
788 
789  tmp.push_back(var);
790  }
791 
792  if (!weight) {
793  weight = input->getOutput(kWeightId);
794  assert(weight);
795  tmp.insert(tmp.begin() +
796  (target == input->getOutput(kTargetId)),
797  1, weight);
798  }
799  if (!target) {
800  target = input->getOutput(kTargetId);
801  assert(target);
802  tmp.insert(tmp.begin(), 1, target);
803  }
804 
805  unsigned int n = 0;
806  for(std::vector<SourceVariable*>::const_iterator iter = variables.begin();
807  iter != variables.end(); iter++) {
808  std::vector<SourceVariable*>::const_iterator pos =
809  std::find(tmp.begin(), tmp.end(), *iter);
810  if (pos == tmp.end())
811  continue;
812 
814  if (*iter == target)
816  else if (*iter == weight)
818  else
820 
821  if (vars.append(*iter, magic, pos - tmp.begin())) {
822  AtomicId source = (*iter)->getSource()->getName();
823  AtomicId name = (*iter)->getName();
824  throw cms::Exception("MVATrainer")
825  << "Input variable " << (const char*)source
826  << ":" << (const char*)name
827  << " defined twice." << std::endl;
828  }
829 
830  n++;
831  }
832 
833  assert(tmp.size() == n);
834 }
Definition: weight.py:1
SourceVariable * getVariable(AtomicId source, AtomicId name) const
Definition: MVATrainer.cc:723
void find(edm::Handle< EcalRecHitCollection > &hits, DetId thisDet, std::vector< EcalRecHitCollection::const_iterator > &hit, bool debug=false)
Definition: FindCaloHit.cc:20
static const AtomicId kTargetId
Definition: MVATrainer.h:59
SourceVariable * getOutput(AtomicId name) const
Definition: Source.h:21
static const AtomicId kWeightId
Definition: MVATrainer.h:60
def elem(elemtype, innerHTML='', html_class='', kwargs)
Definition: HTMLExport.py:18
std::unique_ptr< XMLDocument > xml
Definition: MVATrainer.h:107
std::vector< SourceVariable * > variables
Definition: MVATrainer.h:101
std::vector< std::vector< double > > tmp
Definition: MVATrainer.cc:100
static std::string const source
Definition: EdmProvDump.cc:44
void PhysicsTools::MVATrainer::fillOutputVars ( SourceVariableSet &  vars,
Source *  source,
XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *  xml 
)
private

Definition at line 836 of file MVATrainer.cc.

References PhysicsTools::SourceVariableSet::append(), createVariable(), HTMLExport::elem(), Exception, PhysicsTools::Variable::FLAG_MULTIPLE, PhysicsTools::Variable::FLAG_NONE, PhysicsTools::Variable::FLAG_OPTIONAL, PhysicsTools::Source::getName(), PhysicsTools::isMagic(), name, and JetChargeProducer_cfi::var.

Referenced by makeProcessor(), and MVATrainer().

838 {
839  for(DOMNode *node = xml->getFirstChild(); node;
840  node = node->getNextSibling()) {
841  if (node->getNodeType() != DOMNode::ELEMENT_NODE)
842  continue;
843 
844  if (std::strcmp(XMLSimpleStr(node->getNodeName()), "var") != 0)
845  throw cms::Exception("MVATrainer")
846  << "Invalid output variable node."
847  << std::endl;
848 
849  DOMElement *elem = static_cast<DOMElement*>(node);
850 
851  AtomicId name = XMLDocument::readAttribute<std::string>(
852  elem, "name");
853  if (!name)
854  throw cms::Exception("MVATrainer")
855  << "Output variable tag missing name."
856  << std::endl;
857  if (isMagic(name))
858  throw cms::Exception("MVATrainer")
859  << "Cannot use magic variable names in output."
860  << std::endl;
861 
863 
864  if (XMLDocument::readAttribute<bool>(elem, "optional", true))
866  (flags | Variable::FLAG_OPTIONAL);
867 
868  if (XMLDocument::readAttribute<bool>(elem, "multiple", true))
870  (flags | Variable::FLAG_MULTIPLE);
871 
872  SourceVariable *var = createVariable(source, name, flags);
873  if (!var || vars.append(var))
874  throw cms::Exception("MVATrainer")
875  << "Output variable "
876  << (const char*)source->getName()
877  << ":" << (const char*)name
878  << " defined twice." << std::endl;
879  }
880 }
std::vector< Variable::Flags > flags
Definition: MVATrainer.cc:135
def elem(elemtype, innerHTML='', html_class='', kwargs)
Definition: HTMLExport.py:18
std::unique_ptr< XMLDocument > xml
Definition: MVATrainer.h:107
static std::string const source
Definition: EdmProvDump.cc:44
static bool isMagic(AtomicId id)
Definition: MVATrainer.cc:396
SourceVariable * createVariable(Source *source, AtomicId name, Variable::Flags flags)
Definition: MVATrainer.cc:732
std::vector< AtomicId > PhysicsTools::MVATrainer::findFinalProcessors ( ) const
private

Definition at line 1113 of file MVATrainer.cc.

References PhysicsTools::SourceVariableSet::get(), PhysicsTools::Source::inputs, PatBasicFWLiteJetAnalyzer_Selector_cfg::inputs, output, processors, mps_fire::result, source, and sources.

Referenced by getCalibration().

1114 {
1115  std::set<Source*> toCheck;
1116  toCheck.insert(output);
1117 
1118  std::set<Source*> done;
1119  while(!toCheck.empty()) {
1120  Source *source = *toCheck.begin();
1121  toCheck.erase(toCheck.begin());
1122 
1123  std::vector<SourceVariable*> inputs = source->inputs.get();
1124  for(std::vector<SourceVariable*>::const_iterator iter =
1125  inputs.begin(); iter != inputs.end(); ++iter) {
1126  source = (*iter)->getSource();
1127  if (done.insert(source).second)
1128  toCheck.insert(source);
1129  }
1130  }
1131 
1132  std::vector<AtomicId> result;
1133  for(std::vector<AtomicId>::const_iterator iter = processors.begin();
1134  iter != processors.end(); ++iter) {
1135  std::map<AtomicId, Source*>::const_iterator pos =
1136  sources.find(*iter);
1137  if (pos != sources.end() && done.count(pos->second))
1138  result.push_back(*iter);
1139  }
1140 
1141  return result;
1142 }
std::vector< AtomicId > processors
Definition: MVATrainer.h:102
std::map< AtomicId, Source * > sources
Definition: MVATrainer.h:100
TrainProcessor * output
Definition: MVATrainer.h:104
static std::string const source
Definition: EdmProvDump.cc:44
void PhysicsTools::MVATrainer::findUntrainedComputers ( std::vector< AtomicId > &  compute,
std::vector< AtomicId > &  train 
) const
private

Definition at line 1193 of file MVATrainer.cc.

References doMonitoring, PhysicsTools::SourceVariableSet::get(), PhysicsTools::Source::getInputs(), input, PhysicsTools::Source::isTrained(), output, proc, processors, and sources.

Referenced by getTrainCalibration().

1195 {
1196  compute.clear();
1197  train.clear();
1198 
1199  std::set<Source*> trainedSources;
1200  trainedSources.insert(input);
1201 
1202  for(std::vector<AtomicId>::const_iterator iter =
1203  processors.begin(); iter != processors.end(); iter++) {
1204  std::map<AtomicId, Source*>::const_iterator pos =
1205  sources.find(*iter);
1206  assert(pos != sources.end());
1207  TrainProcessor *proc =
1208  dynamic_cast<TrainProcessor*>(pos->second);
1209  assert(proc);
1210 
1211  bool trainedDeps = true;
1212  std::vector<SourceVariable*> inputVars =
1213  proc->getInputs().get();
1214  for(std::vector<SourceVariable*>::const_iterator iter2 =
1215  inputVars.begin(); iter2 != inputVars.end(); iter2++) {
1216  if (trainedSources.find((*iter2)->getSource())
1217  == trainedSources.end()) {
1218  trainedDeps = false;
1219  break;
1220  }
1221  }
1222 
1223  if (!trainedDeps)
1224  continue;
1225 
1226  if (proc->isTrained()) {
1227  trainedSources.insert(proc);
1228  compute.push_back(proc->getName());
1229  } else
1230  train.push_back(proc->getName());
1231  }
1232 
1233  if (doMonitoring && !output->isTrained() &&
1234  trainedSources.find(output->getInputs().get()[0]->getSource())
1235  != trainedSources.end())
1236  train.push_back(kOutputId);
1237 }
bool isTrained() const
Definition: Source.h:24
TrainProcessor *const proc
Definition: MVATrainer.cc:101
const SourceVariableSet & getInputs() const
Definition: Source.h:26
std::vector< AtomicId > processors
Definition: MVATrainer.h:102
std::vector< SourceVariable * > get(bool withMagic=false) const
std::map< AtomicId, Source * > sources
Definition: MVATrainer.h:100
def compute(min, max)
TrainProcessor * output
Definition: MVATrainer.h:104
static const AtomicId kOutputId("__OUTPUT__")
Calibration::MVAComputer * PhysicsTools::MVATrainer::getCalibration ( ) const

Definition at line 1144 of file MVATrainer.cc.

References begin, calib, connectProcessors(), end, spr::find(), findFinalProcessors(), createfilelist::int, proc, processors, source, and sources.

Referenced by PhysicsTools::MVATrainerContainerLooperImpl< Record_t >::produce(), setCrossValidation(), and PhysicsTools::TreeTrainer::train().

1145 {
1146  std::vector<CalibratedProcessor> processors;
1147 
1148  std::unique_ptr<Calibration::MVAComputer> calib(
1149  new Calibration::MVAComputer);
1150 
1151  std::vector<AtomicId> used = findFinalProcessors();
1152  for(std::vector<AtomicId>::const_iterator iter = used.begin();
1153  iter != used.end(); iter++) {
1154  std::map<AtomicId, Source*>::const_iterator pos =
1155  sources.find(*iter);
1156  assert(pos != sources.end());
1157  TrainProcessor *source =
1158  dynamic_cast<TrainProcessor*>(pos->second);
1159  assert(source);
1160  if (!source->isTrained())
1161  return nullptr;
1162 
1163  Calibration::VarProcessor *proc = source->getCalibration();
1164  if (!proc)
1165  continue;
1166 
1167  Calibration::ProcForeach *foreach =
1168  dynamic_cast<Calibration::ProcForeach*>(proc);
1169  if (foreach) {
1170  std::vector<AtomicId>::const_iterator begin =
1171  std::find(this->processors.begin(),
1172  this->processors.end(), *iter);
1173  assert(this->processors.end() - begin >
1174  (int)(foreach->nProcs + 1));
1175  ++begin;
1176  std::vector<AtomicId>::const_iterator end =
1177  begin + foreach->nProcs;
1178  foreach->nProcs = 0;
1179  for(std::vector<AtomicId>::const_iterator iter2 =
1180  iter; iter2 != used.end(); ++iter2)
1181  if (std::find(begin, end, *iter2) != end)
1182  foreach->nProcs++;
1183  }
1184 
1185  processors.push_back(CalibratedProcessor(source, proc));
1186  }
1187 
1188  connectProcessors(calib.get(), processors, false);
1189 
1190  return calib.release();
1191 }
TrainProcessor *const proc
Definition: MVATrainer.cc:101
void find(edm::Handle< EcalRecHitCollection > &hits, DetId thisDet, std::vector< EcalRecHitCollection::const_iterator > &hit, bool debug=false)
Definition: FindCaloHit.cc:20
MVATrainerComputer * calib
Definition: MVATrainer.cc:64
std::vector< AtomicId > processors
Definition: MVATrainer.h:102
std::vector< AtomicId > findFinalProcessors() const
Definition: MVATrainer.cc:1113
#define end
Definition: vmac.h:39
void connectProcessors(Calibration::MVAComputer *calib, const std::vector< CalibratedProcessor > &procs, bool withTarget) const
Definition: MVATrainer.cc:883
std::map< AtomicId, Source * > sources
Definition: MVATrainer.h:100
#define begin
Definition: vmac.h:32
static std::string const source
Definition: EdmProvDump.cc:44
const std::string& PhysicsTools::MVATrainer::getName ( void  ) const
inline

Definition at line 53 of file MVATrainer.h.

References bookMonitor(), name, and AlCaHLTBitMon_QueryRunRegistry::string.

Referenced by plotting.Plot::draw().

53 { return name; }
Calibration::MVAComputer * PhysicsTools::MVATrainer::getTrainCalibration ( ) const

Definition at line 1239 of file MVATrainer.cc.

References bookConverter::compute(), findUntrainedComputers(), and makeTrainCalibration().

Referenced by PhysicsTools::TreeTrainer::iteration(), and setCrossValidation().

1240 {
1241  std::vector<AtomicId> compute, train;
1242  findUntrainedComputers(compute, train);
1243 
1244  if (train.empty())
1245  return nullptr;
1246 
1247  compute.push_back(nullptr);
1248  train.push_back(nullptr);
1249 
1250  return makeTrainCalibration(&compute.front(), &train.front());
1251 }
Calibration::MVAComputer * makeTrainCalibration(const AtomicId *compute, const AtomicId *train) const
Definition: MVATrainer.cc:972
def compute(min, max)
void findUntrainedComputers(std::vector< AtomicId > &compute, std::vector< AtomicId > &train) const
Definition: MVATrainer.cc:1193
SourceVariable * PhysicsTools::MVATrainer::getVariable ( AtomicId  source,
AtomicId  name 
) const
private

Definition at line 723 of file MVATrainer.cc.

References sources.

Referenced by createVariable(), and fillInputVars().

724 {
725  std::map<AtomicId, Source*>::const_iterator pos = sources.find(source);
726  if (pos == sources.end())
727  return nullptr;
728 
729  return pos->second->getOutput(name);
730 }
std::map< AtomicId, Source * > sources
Definition: MVATrainer.h:100
static std::string const source
Definition: EdmProvDump.cc:44
void PhysicsTools::MVATrainer::loadState ( )

Definition at line 581 of file MVATrainer.cc.

References processors, source, and sources.

Referenced by setCrossValidation().

582 {
583  for(std::vector<AtomicId>::const_iterator iter =
584  this->processors.begin();
585  iter != this->processors.end(); iter++) {
586  std::map<AtomicId, Source*>::const_iterator pos =
587  sources.find(*iter);
588  assert(pos != sources.end());
589  TrainProcessor *source =
590  dynamic_cast<TrainProcessor*>(pos->second);
591  assert(source);
592 
593  if (source->load())
594  edm::LogInfo("MVATrainer")
595  << source->getId() << " configuration for \""
596  << (const char*)source->getName()
597  << "\" loaded from file.";
598  }
599 }
std::vector< AtomicId > processors
Definition: MVATrainer.h:102
std::map< AtomicId, Source * > sources
Definition: MVATrainer.h:100
static std::string const source
Definition: EdmProvDump.cc:44
void PhysicsTools::MVATrainer::makeProcessor ( XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *  elem,
AtomicId  id,
const char *  name 
)
private

Definition at line 620 of file MVATrainer.cc.

References PhysicsTools::ProcessRegistry< Base_t, CalibBase_t, Parent_t >::Factory::create(), HTMLExport::elem(), Exception, fillInputVars(), fillOutputVars(), proc, processors, sources, AlCaHLTBitMon_QueryRunRegistry::string, and GlobalPosition_Frontier_DevDB_cff::tag.

Referenced by MVATrainer().

621 {
622  DOMElement *xmlInput = nullptr;
623  DOMElement *xmlConfig = nullptr;
624  DOMElement *xmlOutput = nullptr;
625  DOMElement *xmlData = nullptr;
626 
627  static struct NameExpect {
628  const char *tag;
629  bool mandatory;
630  DOMElement **elem;
631  } const expect[] = {
632  { "input", true, &xmlInput },
633  { "config", true, &xmlConfig },
634  { "output", true, &xmlOutput },
635  { "data", false, &xmlData },
636  { nullptr, }
637  };
638 
639  const NameExpect *cur = expect;
640  for(DOMNode *node = elem->getFirstChild();
641  node; node = node->getNextSibling()) {
642  if (node->getNodeType() != DOMNode::ELEMENT_NODE)
643  continue;
644 
645  std::string tag = XMLSimpleStr(node->getNodeName());
646  DOMElement *elem = static_cast<DOMElement*>(node);
647 
648  if (!cur->tag)
649  throw cms::Exception("MVATrainer")
650  << "Superfluous tag " << tag
651  << "encountered in processor." << std::endl;
652  else if (tag != cur->tag && cur->mandatory)
653  throw cms::Exception("MVATrainer")
654  << "Expected tag " << cur->tag << ", got "
655  << tag << " instead in processor."
656  << std::endl;
657  else if (tag != cur->tag) {
658  cur++;
659  continue;
660  }
661  *(cur++)->elem = elem;
662  }
663 
664  while(cur->tag && !cur->mandatory)
665  cur++;
666  if (cur->tag)
667  throw cms::Exception("MVATrainer")
668  << "Unexpected end of processor configuration, "
669  << "expected tag " << cur->tag << "." << std::endl;
670 
671  std::unique_ptr<TrainProcessor> proc(
672  TrainProcessor::create(name, &id, this));
673  if (!proc.get())
674  throw cms::Exception("MVATrainer")
675  << "Variable processor trainer " << name
676  << " could not be instantiated. Most likely because"
677  " the trainer plugin for \"" << name << "\""
678  " does not exist." << std::endl;
679 
680  if (sources.find(id) != sources.end())
681  throw cms::Exception("MVATrainer")
682  << "Duplicate variable processor id "
683  << (const char*)id << "."
684  << std::endl;
685 
686  fillInputVars(proc->getInputs(), xmlInput);
687  fillOutputVars(proc->getOutputs(), proc.get(), xmlOutput);
688 
689  edm::LogInfo("MVATrainer")
690  << "Configuring " << (const char*)proc->getId()
691  << " \"" << (const char*)proc->getName() << "\".";
692  proc->configure(xmlConfig);
693 
694  sources.insert(std::make_pair(id, proc.release()));
695  processors.push_back(id);
696 }
TrainProcessor *const proc
Definition: MVATrainer.cc:101
void fillOutputVars(SourceVariableSet &vars, Source *source, XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
Definition: MVATrainer.cc:836
std::vector< AtomicId > processors
Definition: MVATrainer.h:102
def elem(elemtype, innerHTML='', html_class='', kwargs)
Definition: HTMLExport.py:18
std::map< AtomicId, Source * > sources
Definition: MVATrainer.h:100
static Base_t * create(const char *name, const CalibBase_t *calib, Parent_t *parent=0)
void fillInputVars(SourceVariableSet &vars, XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
Definition: MVATrainer.cc:744
Calibration::MVAComputer * PhysicsTools::MVATrainer::makeTrainCalibration ( const AtomicId *  compute,
const AtomicId *  train 
) const
private

Definition at line 972 of file MVATrainer.cc.

References calib, bookConverter::compute(), connectProcessors(), crossValidation, doAutoSave, spr::find(), mps_fire::i, createfilelist::int, interceptors, gen::n, PhysicsTools::Calibration::ProcForeach::nProcs, output, proc, processors, randomSeed, source, and sources.

Referenced by getTrainCalibration().

974 {
975  std::map<AtomicId, TrainInterceptor*> interceptors;
976  std::vector<MVATrainerComputer::Interceptor> baseInterceptors;
977  std::vector<CalibratedProcessor> processors;
978 
979  BaseInterceptor *interceptor = new InitInterceptor;
980  baseInterceptors.push_back(std::make_pair(0, interceptor));
981  processors.push_back(CalibratedProcessor(nullptr, interceptor));
982 
983  for(const AtomicId *iter = train; *iter; iter++) {
984  TrainProcessor *source;
985  if (*iter == kOutputId)
986  source = output;
987  else {
988  std::map<AtomicId, Source*>::const_iterator pos =
989  sources.find(*iter);
990  assert(pos != sources.end());
991  source = dynamic_cast<TrainProcessor*>(pos->second);
992  }
993  assert(source);
994 
995  interceptors[*iter] = new TrainInterceptor(source);
996  }
997 
998  auto_cleaner<Calibration::VarProcessor> autoClean;
999 
1000  std::set<AtomicId> done;
1001  for(const AtomicId *iter = compute; *iter; iter++) {
1002  if (done.erase(*iter))
1003  continue;
1004 
1005  std::map<AtomicId, Source*>::const_iterator pos =
1006  sources.find(*iter);
1007  assert(pos != sources.end());
1008  TrainProcessor *source =
1009  dynamic_cast<TrainProcessor*>(pos->second);
1010  assert(source);
1011  assert(source->isTrained());
1012 
1013  Calibration::VarProcessor *proc = source->getCalibration();
1014  if (!proc)
1015  continue;
1016 
1017  autoClean.add(proc);
1018  processors.push_back(CalibratedProcessor(source, proc));
1019 
1020  Calibration::ProcForeach *looper =
1021  dynamic_cast<Calibration::ProcForeach*>(proc);
1022  if (looper) {
1023  std::vector<AtomicId>::const_iterator pos2 =
1024  std::find(this->processors.begin(),
1025  this->processors.end(), *iter);
1026  assert(pos2 != this->processors.end());
1027  ++pos2;
1028  unsigned int n = 0;
1029  for(int i = 0; i < (int)looper->nProcs; ++i, ++pos2) {
1030  assert(pos2 != this->processors.end());
1031 
1032  const AtomicId *iter2 = compute;
1033  while(*iter2) {
1034  if (*iter2 == *pos2)
1035  break;
1036  iter2++;
1037  }
1038 
1039  if (*iter2) {
1040  n++;
1041  done.insert(*iter2);
1042  pos = sources.find(*iter2);
1043  assert(pos != sources.end());
1044  TrainProcessor *source =
1045  dynamic_cast<TrainProcessor*>(
1046  pos->second);
1047  assert(source);
1048  assert(source->isTrained());
1049 
1050  proc = source->getCalibration();
1051  if (proc) {
1052  autoClean.add(proc);
1053  processors.push_back(
1054  CalibratedProcessor(
1055  source, proc));
1056  }
1057  }
1058 
1059  std::map<AtomicId, TrainInterceptor*>::iterator
1060  pos3 = interceptors.find(*pos2);
1061  if (pos3 != interceptors.end()) {
1062  n++;
1063  baseInterceptors.push_back(
1064  std::make_pair(processors.size(),
1065  pos3->second));
1066  processors.push_back(
1067  CalibratedProcessor(
1068  pos3->second->getProcessor(),
1069  pos3->second));
1070  interceptors.erase(pos3);
1071  }
1072  }
1073 
1074  looper->nProcs = n;
1075  if (!n) {
1076  baseInterceptors.pop_back();
1077  processors.pop_back();
1078  }
1079  }
1080  }
1081 
1082  for(std::map<AtomicId, TrainInterceptor*>::const_iterator iter =
1083  interceptors.begin(); iter != interceptors.end(); ++iter) {
1084 
1085  TrainProcessor *proc = iter->second->getProcessor();
1086  baseInterceptors.push_back(std::make_pair(processors.size(),
1087  iter->second));
1088  processors.push_back(CalibratedProcessor(proc, iter->second));
1089  }
1090 
1091  std::unique_ptr<Calibration::MVAComputer> calib(
1092  new MVATrainerComputer(baseInterceptors, doAutoSave,
1094 
1095  connectProcessors(calib.get(), processors, true);
1096 
1097  return calib.release();
1098 }
TrainProcessor *const proc
Definition: MVATrainer.cc:101
void find(edm::Handle< EcalRecHitCollection > &hits, DetId thisDet, std::vector< EcalRecHitCollection::const_iterator > &hit, bool debug=false)
Definition: FindCaloHit.cc:20
Definition: looper.py:1
MVATrainerComputer * calib
Definition: MVATrainer.cc:64
std::vector< AtomicId > processors
Definition: MVATrainer.h:102
void connectProcessors(Calibration::MVAComputer *calib, const std::vector< CalibratedProcessor > &procs, bool withTarget) const
Definition: MVATrainer.cc:883
std::map< AtomicId, Source * > sources
Definition: MVATrainer.h:100
std::vector< Interceptor > interceptors
Definition: MVATrainer.cc:134
def compute(min, max)
TrainProcessor * output
Definition: MVATrainer.h:104
static std::string const source
Definition: EdmProvDump.cc:44
static const AtomicId kOutputId("__OUTPUT__")
void PhysicsTools::MVATrainer::saveState ( )

Definition at line 601 of file MVATrainer.cc.

References doCleanup, processors, source, and sources.

Referenced by setCrossValidation().

602 {
603  doCleanup = false;
604 
605  for(std::vector<AtomicId>::const_iterator iter =
606  this->processors.begin();
607  iter != this->processors.end(); iter++) {
608  std::map<AtomicId, Source*>::const_iterator pos =
609  sources.find(*iter);
610  assert(pos != sources.end());
611  TrainProcessor *source =
612  dynamic_cast<TrainProcessor*>(pos->second);
613  assert(source);
614 
615  if (source->isTrained())
616  source->save();
617  }
618 }
std::vector< AtomicId > processors
Definition: MVATrainer.h:102
std::map< AtomicId, Source * > sources
Definition: MVATrainer.h:100
static std::string const source
Definition: EdmProvDump.cc:44
void PhysicsTools::MVATrainer::setAutoSave ( bool  autoSave)
inline

Definition at line 33 of file MVATrainer.h.

References doAutoSave.

33 { doAutoSave = autoSave; }
void PhysicsTools::MVATrainer::setCleanup ( bool  cleanup)
inline

Definition at line 34 of file MVATrainer.h.

References doCleanup.

34 { doCleanup = cleanup; }
static void cleanup(const Factory::MakerMap::value_type &v)
Definition: Factory.cc:12
void PhysicsTools::MVATrainer::setCrossValidation ( double  split)
inline

Definition at line 37 of file MVATrainer.h.

References crossValidation, doneTraining(), getCalibration(), getTrainCalibration(), loadState(), proc, saveState(), PhysicsTools::split(), AlCaHLTBitMon_QueryRunRegistry::string, and trainFileName().

Referenced by PhysicsTools::TreeTrainer::train().

static std::vector< std::string > split(const std::string line, char delim)
Definition: MLP.cc:18
void PhysicsTools::MVATrainer::setMonitoring ( bool  monitoring)
inline

Definition at line 35 of file MVATrainer.h.

References doMonitoring, and monitoring.

Referenced by PhysicsTools::TreeTrainer::train().

std::unique_ptr< TrainerMonitoring > monitoring
Definition: MVATrainer.h:106
void PhysicsTools::MVATrainer::setRandomSeed ( UInt_t  seed)
inline

Definition at line 36 of file MVATrainer.h.

References randomSeed.

std::string PhysicsTools::MVATrainer::trainFileName ( const TrainProcessor *  proc,
const std::string &  ext,
const std::string &  arg = "" 
) const

Definition at line 698 of file MVATrainer.cc.

References PhysicsTools::Source::getName(), PhysicsTools::stdStringPrintf(), AlCaHLTBitMon_QueryRunRegistry::string, and trainFileMask.

Referenced by setCrossValidation().

701 {
702  std::string arg_ = !arg.empty() ? ("_" + arg) : "";
703  return stdStringPrintf(trainFileMask.c_str(),
704  (const char*)proc->getName(),
705  arg_.c_str(), ext.c_str());
706 }
TrainProcessor *const proc
Definition: MVATrainer.cc:101
A arg
Definition: Factorize.h:37
std::string trainFileMask
Definition: MVATrainer.h:108
Definition: memstream.h:15
static std::string stdStringPrintf(const char *format,...)
Definition: MVATrainer.cc:181

Member Data Documentation

double PhysicsTools::MVATrainer::crossValidation
private

Definition at line 115 of file MVATrainer.h.

Referenced by makeTrainCalibration(), and setCrossValidation().

bool PhysicsTools::MVATrainer::doAutoSave
private

Definition at line 110 of file MVATrainer.h.

Referenced by makeTrainCalibration(), and setAutoSave().

bool PhysicsTools::MVATrainer::doCleanup
private

Definition at line 111 of file MVATrainer.h.

Referenced by saveState(), setCleanup(), and ~MVATrainer().

bool PhysicsTools::MVATrainer::doMonitoring
private

Definition at line 112 of file MVATrainer.h.

Referenced by bookMonitor(), findUntrainedComputers(), and setMonitoring().

Source* PhysicsTools::MVATrainer::input
private

Definition at line 103 of file MVATrainer.h.

Referenced by connectProcessors(), fillInputVars(), findUntrainedComputers(), and MVATrainer().

const AtomicId PhysicsTools::MVATrainer::kTargetId
static
const AtomicId PhysicsTools::MVATrainer::kWeightId
static
std::unique_ptr<TrainerMonitoring> PhysicsTools::MVATrainer::monitoring
private

Definition at line 106 of file MVATrainer.h.

Referenced by bookMonitor(), setMonitoring(), and ~MVATrainer().

std::string PhysicsTools::MVATrainer::name
private
TrainProcessor* PhysicsTools::MVATrainer::output
private

Definition at line 104 of file MVATrainer.h.

std::vector<AtomicId> PhysicsTools::MVATrainer::processors
private

Definition at line 102 of file MVATrainer.h.

UInt_t PhysicsTools::MVATrainer::randomSeed
private

Definition at line 114 of file MVATrainer.h.

Referenced by makeTrainCalibration(), and setRandomSeed().

std::map<AtomicId, Source*> PhysicsTools::MVATrainer::sources
private

Definition at line 100 of file MVATrainer.h.

std::string PhysicsTools::MVATrainer::trainFileMask
private

Definition at line 108 of file MVATrainer.h.

Referenced by bookMonitor(), MVATrainer(), and trainFileName().

std::vector<SourceVariable*> PhysicsTools::MVATrainer::variables
private

Definition at line 101 of file MVATrainer.h.

Referenced by connectProcessors(), createVariable(), fillInputVars(), and ~MVATrainer().

std::unique_ptr<XMLDocument> PhysicsTools::MVATrainer::xml
private

Definition at line 107 of file MVATrainer.h.

Referenced by MVATrainer().