3 #include <ext/functional> 15 #include <xercesc/dom/DOM.hpp> 43 class MVATrainerComputer;
45 class BaseInterceptor :
public Calibration::Interceptor {
47 BaseInterceptor() :
calib(0) {}
48 virtual ~BaseInterceptor() {}
50 inline void setCalibration(MVATrainerComputer *
calib)
51 { this->calib =
calib; }
53 virtual std::vector<Variable::Flags>
54 configure(
const MVAComputer *
computer,
unsigned int n,
55 const std::vector<Variable::Flags> &
flags) = 0;
58 intercept(
const std::vector<double> *
values)
const = 0;
60 virtual void init() {}
61 virtual void finish(
bool save) {}
67 class InitInterceptor :
public BaseInterceptor {
70 virtual ~InitInterceptor() {}
72 virtual std::vector<Variable::Flags>
73 configure(
const MVAComputer *
computer,
unsigned int n,
74 const std::vector<Variable::Flags> &
flags)
override;
77 intercept(
const std::vector<double> *
values)
const override;
80 class TrainInterceptor :
public BaseInterceptor {
82 TrainInterceptor(TrainProcessor *
proc) :
proc(proc) {}
83 virtual ~TrainInterceptor() {}
85 inline TrainProcessor *getProcessor()
const {
return proc; }
87 virtual std::vector<Variable::Flags>
88 configure(
const MVAComputer *
computer,
unsigned int n,
89 const std::vector<Variable::Flags> &
flags)
override;
92 intercept(
const std::vector<double> *
values)
const override;
94 virtual void init()
override;
95 virtual void finish(
bool save)
override;
100 mutable std::vector<std::vector<double> >
tmp;
104 class MVATrainerComputer :
public TrainMVAComputerCalibration {
106 typedef std::pair<unsigned int, BaseInterceptor*> Interceptor;
108 MVATrainerComputer(
const std::vector<Interceptor>
110 bool autoSave, UInt_t
seed,
double split);
112 virtual ~MVATrainerComputer();
114 virtual std::vector<Calibration::VarProcessor*>
115 getProcessors()
const override;
116 virtual void initFlags(std::vector<Variable::Flags>
117 &
flags)
const override;
119 void configured(BaseInterceptor *interceptor)
const;
124 { flags.push_back(flag); }
126 inline bool useForTraining()
const {
return splitResult; }
127 inline bool useForTesting()
const 130 inline bool isConfigured()
const 146 struct deleter :
public std::unary_function<T*, void> {
147 inline void operator() (
T *ptr)
const {
delete ptr; }
151 struct auto_cleaner {
152 inline ~auto_cleaner()
155 inline void add(
T *ptr) {
clean.push_back(ptr); }
162 unsigned int size = std::min<unsigned int>(128, std::strlen(format));
165 int n = std::vsnprintf(buffer, size, format, va);
166 if (n >= 0 && (
unsigned int)n < size)
175 buffer =
new char[
size];
186 va_start(va, format);
194 std::vector<Variable::Flags>
196 const std::vector<Variable::Flags> &
flags)
198 calib->configured(
this);
203 InitInterceptor::intercept(
const std::vector<double> *
values)
const 211 std::vector<Variable::Flags>
212 TrainInterceptor::configure(
const MVAComputer *computer,
unsigned int n,
213 const std::vector<Variable::Flags> &flags)
220 std::vector<SourceVariable*>
inputs = inputSet.
get(
true);
222 std::vector<SourceVariable*>::const_iterator
pos;
224 assert(pos != inputs.end());
227 assert(pos != inputs.end());
230 calib->configured(
this);
233 if (targetIdx < weightIdx) {
234 result.erase(result.begin() +
weightIdx);
235 result.erase(result.begin() +
targetIdx);
237 result.erase(result.begin() +
targetIdx);
238 result.erase(result.begin() +
weightIdx);
241 proc->passFlags(result);
244 result.resize(n,
proc->getDefaultFlags());
248 if (targetIdx >= 2 || weightIdx >= 2)
257 <<
"TrainProcessor \"" << (
const char*)
proc->getName()
258 <<
"\" training iteration starting...";
260 proc->doTrainBegin();
264 TrainInterceptor::intercept(
const std::vector<double> *values)
const 269 <<
"Trainer input lacks target variable." 273 <<
"Multiple targets supplied in input." 281 <<
"Multiple weights supplied in input." 287 proc->doTrainData(values + 2, target > 0.5, weight,
288 calib->useForTraining(),
289 calib->useForTesting());
291 std::vector<std::vector<double> >::iterator
pos =
tmp.begin();
292 for(
unsigned int i = 0; pos !=
tmp.end();
i++)
297 calib->useForTraining(),
298 calib->useForTesting());
304 void TrainInterceptor::finish(
bool save)
309 <<
"... processor \"" << (
const char*)
proc->getName()
310 <<
"\" training iteration done.";
312 if (
proc->isTrained()) {
314 <<
"* Completed training of \"" 315 << (
const char*)
proc->getName() <<
"\".";
324 MVATrainerComputer::MVATrainerComputer(
const std::vector<Interceptor>
330 for(std::vector<Interceptor>::const_iterator iter =
331 interceptors.begin(); iter != interceptors.end(); ++iter)
332 iter->second->setCalibration(
this);
335 MVATrainerComputer::~MVATrainerComputer()
339 for(std::vector<Interceptor>::const_iterator iter =
340 interceptors.begin(); iter != interceptors.end(); ++iter)
344 std::vector<Calibration::VarProcessor*>
345 MVATrainerComputer::getProcessors()
const 347 std::vector<Calibration::VarProcessor*> processors =
350 for(std::vector<Interceptor>::const_iterator iter =
351 interceptors.begin(); iter != interceptors.end(); ++iter)
353 processors.insert(processors.begin() + iter->first,
359 void MVATrainerComputer::initFlags(std::vector<Variable::Flags> &flags)
const 361 assert(flags.size() == this->flags.size());
365 void MVATrainerComputer::configured(BaseInterceptor *interceptor)
const 369 for(std::vector<Interceptor>::const_iterator iter =
370 interceptors.begin();
371 iter != interceptors.end(); ++iter)
372 iter->second->init();
380 void MVATrainerComputer::done()
382 if (isConfigured()) {
383 for(std::vector<Interceptor>::const_iterator iter =
384 interceptors.begin();
385 iter != interceptors.end(); ++iter)
400 return id == MVATrainer::kTargetId ||
401 id == MVATrainer::kWeightId ||
408 for(std::string::const_iterator iter = in.begin();
409 iter != in.end(); ++iter) {
423 const char *styleSheet) :
426 doMonitoring(
false), randomSeed(65539), crossValidation(0.0)
432 "PhysicsTools/MVATrainer/data/MVATrainer.xsl")
443 DOMNode *node =
xml->getRootNode();
445 if (std::strcmp(
XMLSimpleStr(node->getNodeName()),
"MVATrainer") != 0)
447 <<
"Invalid XML root node." << std::endl;
454 } state = STATE_GENERAL;
456 for(node = node->getFirstChild();
457 node; node = node->getNextSibling()) {
458 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
462 DOMElement *
elem =
static_cast<DOMElement*
>(node);
465 case STATE_GENERAL: {
466 if (name !=
"general")
468 <<
"Expected general config as first " 471 for(DOMNode *subNode = elem->getFirstChild();
472 subNode; subNode = subNode->getNextSibling()) {
473 if (subNode->getNodeType() !=
474 DOMNode::ELEMENT_NODE)
478 subNode->getNodeName()),
"option") != 0)
480 <<
"Expected option tag." 483 elem =
static_cast<DOMElement*
>(subNode);
484 name = XMLDocument::readAttribute<std::string>(
487 elem->getTextContent());
491 else if (name ==
"trainfiles")
495 <<
"Unknown option \"" 496 << name <<
"\"." << std::endl;
504 <<
"Expected input config as second " 507 AtomicId id = XMLDocument::readAttribute<std::string>(
510 input->getOutputs().append(
514 input->getOutputs().append(
518 sources.insert(std::make_pair(
id, input));
521 state = STATE_MIDDLE;
524 if (name ==
"output") {
531 }
else if (name !=
"processor")
533 <<
"Unexpected tag after input " 534 "config." << std::endl;
536 AtomicId id = XMLDocument::readAttribute<std::string>(
539 XMLDocument::readAttribute<std::string>(
546 <<
"Unexpected tag found after output." 552 if (state == STATE_FIRST)
554 <<
"Expected input variable config." << std::endl;
555 else if (state == STATE_MIDDLE)
557 <<
"Expected output variable config." << std::endl;
568 for(std::map<AtomicId, Source*>::const_iterator iter =
sources.begin();
569 iter !=
sources.end(); iter++) {
580 deleter<SourceVariable>());
585 for(std::vector<AtomicId>::const_iterator iter =
588 std::map<AtomicId, Source*>::const_iterator
pos =
597 << source->getId() <<
" configuration for \"" 598 << (
const char*)source->getName()
599 <<
"\" loaded from file.";
607 for(std::vector<AtomicId>::const_iterator iter =
610 std::map<AtomicId, Source*>::const_iterator
pos =
617 if (source->isTrained())
624 DOMElement *xmlInput = 0;
625 DOMElement *xmlConfig = 0;
626 DOMElement *xmlOutput = 0;
627 DOMElement *xmlData = 0;
629 static struct NameExpect {
634 {
"input",
true, &xmlInput },
635 {
"config",
true, &xmlConfig },
636 {
"output",
true, &xmlOutput },
637 {
"data",
false, &xmlData },
641 const NameExpect *cur = expect;
642 for(DOMNode *node = elem->getFirstChild();
643 node; node = node->getNextSibling()) {
644 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
648 DOMElement *elem =
static_cast<DOMElement*
>(node);
652 <<
"Superfluous tag " << tag
653 <<
"encountered in processor." << std::endl;
654 else if (tag != cur->tag && cur->mandatory)
656 <<
"Expected tag " << cur->tag <<
", got " 657 << tag <<
" instead in processor." 659 else if (tag != cur->tag) {
663 *(cur++)->elem = elem;
666 while(cur->tag && !cur->mandatory)
670 <<
"Unexpected end of processor configuration, " 671 <<
"expected tag " << cur->tag <<
"." << std::endl;
673 std::unique_ptr<TrainProcessor>
proc(
677 <<
"Variable processor trainer " << name
678 <<
" could not be instantiated. Most likely because" 679 " the trainer plugin for \"" << name <<
"\"" 680 " does not exist." << std::endl;
684 <<
"Duplicate variable processor id " 685 << (
const char*)
id <<
"." 692 <<
"Configuring " << (
const char*)proc->getId()
693 <<
" \"" << (
const char*)proc->getName() <<
"\".";
694 proc->configure(xmlConfig);
696 sources.insert(std::make_pair(
id, proc.release()));
707 arg_.c_str(), ext.c_str());
718 "monitoring",
"",
"root");
727 std::map<AtomicId, Source*>::const_iterator
pos =
sources.find(source);
731 return pos->second->getOutput(name);
749 std::vector<SourceVariable*>
tmp;
753 for(DOMNode *node = xml->getFirstChild(); node;
754 node = node->getNextSibling()) {
755 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
758 if (std::strcmp(
XMLSimpleStr(node->getNodeName()),
"var") != 0)
760 <<
"Invalid input variable node." << std::endl;
762 DOMElement *
elem =
static_cast<DOMElement*
>(node);
766 AtomicId name = XMLDocument::readAttribute<std::string>(
772 <<
"Input variable " << (
const char*)source
773 <<
":" << (
const char*)name
774 <<
" not found." << std::endl;
776 if (XMLDocument::readAttribute<bool>(elem,
"target",
false)) {
779 <<
"Target variable defined twice" 783 if (XMLDocument::readAttribute<bool>(elem,
"weight",
false)) {
786 <<
"Weight variable defined twice" 797 tmp.insert(tmp.begin() +
804 tmp.insert(tmp.begin(), 1,
target);
808 for(std::vector<SourceVariable*>::const_iterator iter =
variables.begin();
810 std::vector<SourceVariable*>::const_iterator
pos =
811 std::find(tmp.begin(), tmp.end(), *iter);
812 if (pos == tmp.end())
818 else if (*iter == weight)
823 if (vars.
append(*iter, magic, pos - tmp.begin())) {
827 <<
"Input variable " << (
const char*)source
828 <<
":" << (
const char*)name
829 <<
" defined twice." << std::endl;
835 assert(tmp.size() ==
n);
841 for(DOMNode *node = xml->getFirstChild(); node;
842 node = node->getNextSibling()) {
843 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
846 if (std::strcmp(
XMLSimpleStr(node->getNodeName()),
"var") != 0)
848 <<
"Invalid output variable node." 851 DOMElement *
elem =
static_cast<DOMElement*
>(node);
853 AtomicId name = XMLDocument::readAttribute<std::string>(
857 <<
"Output variable tag missing name." 861 <<
"Cannot use magic variable names in output." 866 if (XMLDocument::readAttribute<bool>(elem,
"optional",
true))
870 if (XMLDocument::readAttribute<bool>(elem,
"multiple",
true))
875 if (!var || vars.
append(var))
877 <<
"Output variable " 878 << (
const char*)source->
getName()
879 <<
":" << (
const char*)name
880 <<
" defined twice." << std::endl;
886 const std::vector<CalibratedProcessor> &procs,
887 bool withTarget)
const 889 std::map<SourceVariable*, unsigned int> vars;
890 unsigned int size = 0;
892 MVATrainerComputer *trainCalib =
893 dynamic_cast<MVATrainerComputer*
>(
calib);
895 for(
unsigned int i = 0;
897 if (
i < 2 && !withTarget)
905 calib->
inputSet.push_back(calibVar);
907 trainCalib->addFlag(var->
getFlags());
910 for(std::vector<CalibratedProcessor>::const_iterator iter =
911 procs.begin(); iter != procs.end(); iter++) {
912 bool isInterceptor =
dynamic_cast<BaseInterceptor*
>(
917 unsigned int last = 0;
918 std::vector<SourceVariable*> inoutVars;
920 inoutVars = iter->processor->getInputs().get(
922 for(std::vector<SourceVariable*>::const_iterator iter2 =
923 inoutVars.begin(); iter2 != inoutVars.end(); iter2++) {
925 unsigned int>::const_iterator
pos =
928 assert(pos != vars.end());
930 if (pos->second < last)
932 <<
"Input variables not declared " 933 "in order of appearance in \"" 934 << (
const char*)iter->processor->getName()
935 <<
"\"." << std::endl;
937 inputSet[last = pos->second] =
true;
940 assert(!isInterceptor || withTarget);
953 inoutVars = iter->processor->getOutputs().get();
954 for(std::vector<SourceVariable*>::const_iterator iter =
955 inoutVars.begin(); iter != inoutVars.end(); iter++) {
957 vars[*iter] = size++;
963 <<
"Exactly one output variable has to be specified." 967 std::map<SourceVariable*, unsigned int>::const_iterator
pos =
969 if (pos != vars.end())
970 calib->
output = pos->second;
978 std::vector<MVATrainerComputer::Interceptor> baseInterceptors;
981 BaseInterceptor *interceptor =
new InitInterceptor;
982 baseInterceptors.push_back(std::make_pair(0, interceptor));
985 for(
const AtomicId *iter = train; *iter; iter++) {
987 if (*iter == kOutputId)
990 std::map<AtomicId, Source*>::const_iterator
pos =
997 interceptors[*iter] =
new TrainInterceptor(source);
1000 auto_cleaner<Calibration::VarProcessor> autoClean;
1002 std::set<AtomicId> done;
1003 for(
const AtomicId *iter = compute; *iter; iter++) {
1004 if (done.erase(*iter))
1007 std::map<AtomicId, Source*>::const_iterator
pos =
1013 assert(source->isTrained());
1019 autoClean.add(proc);
1025 std::vector<AtomicId>::const_iterator pos2 =
1027 this->processors.end(), *iter);
1028 assert(pos2 != this->processors.end());
1032 assert(pos2 != this->processors.end());
1036 if (*iter2 == *pos2)
1043 done.insert(*iter2);
1050 assert(source->isTrained());
1052 proc = source->getCalibration();
1054 autoClean.add(proc);
1055 processors.push_back(
1061 std::map<AtomicId, TrainInterceptor*>::iterator
1062 pos3 = interceptors.find(*pos2);
1063 if (pos3 != interceptors.end()) {
1065 baseInterceptors.push_back(
1066 std::make_pair(processors.size(),
1068 processors.push_back(
1070 pos3->second->getProcessor(),
1072 interceptors.erase(pos3);
1078 baseInterceptors.pop_back();
1079 processors.pop_back();
1084 for(std::map<AtomicId, TrainInterceptor*>::const_iterator iter =
1085 interceptors.begin(); iter != interceptors.end(); ++iter) {
1088 baseInterceptors.push_back(std::make_pair(processors.size(),
1093 std::unique_ptr<Calibration::MVAComputer>
calib(
1094 new MVATrainerComputer(baseInterceptors,
doAutoSave,
1099 return calib.release();
1104 MVATrainerComputer *
calib =
1105 dynamic_cast<MVATrainerComputer*
>(trainCalibration);
1109 <<
"Invalid training calibration passed to " 1110 "doneTraining()" << std::endl;
1117 std::set<Source*> toCheck;
1120 std::set<Source*> done;
1121 while(!toCheck.empty()) {
1123 toCheck.erase(toCheck.begin());
1126 for(std::vector<SourceVariable*>::const_iterator iter =
1127 inputs.begin(); iter != inputs.end(); ++iter) {
1128 source = (*iter)->getSource();
1129 if (done.insert(source).second)
1130 toCheck.insert(source);
1134 std::vector<AtomicId>
result;
1135 for(std::vector<AtomicId>::const_iterator iter =
processors.begin();
1137 std::map<AtomicId, Source*>::const_iterator
pos =
1139 if (pos !=
sources.end() && done.count(pos->second))
1140 result.push_back(*iter);
1150 std::unique_ptr<Calibration::MVAComputer>
calib(
1154 for(std::vector<AtomicId>::const_iterator iter = used.begin();
1155 iter != used.end(); iter++) {
1156 std::map<AtomicId, Source*>::const_iterator
pos =
1162 if (!source->isTrained())
1172 std::vector<AtomicId>::const_iterator
begin =
1174 this->processors.end(), *iter);
1175 assert(this->processors.end() - begin >
1178 std::vector<AtomicId>::const_iterator
end =
1179 begin +
foreach->nProcs;
1180 foreach->nProcs = 0;
1181 for(std::vector<AtomicId>::const_iterator iter2 =
1182 iter; iter2 != used.end(); ++iter2)
1192 return calib.release();
1196 std::vector<AtomicId> &train)
const 1201 std::set<Source*> trainedSources;
1202 trainedSources.insert(
input);
1204 for(std::vector<AtomicId>::const_iterator iter =
1206 std::map<AtomicId, Source*>::const_iterator
pos =
1213 bool trainedDeps =
true;
1214 std::vector<SourceVariable*> inputVars =
1215 proc->getInputs().get();
1216 for(std::vector<SourceVariable*>::const_iterator iter2 =
1217 inputVars.begin(); iter2 != inputVars.end(); iter2++) {
1218 if (trainedSources.find((*iter2)->getSource())
1219 == trainedSources.end()) {
1220 trainedDeps =
false;
1228 if (proc->isTrained()) {
1229 trainedSources.insert(proc);
1230 compute.push_back(proc->getName());
1232 train.push_back(proc->getName());
1237 != trainedSources.end())
1238 train.push_back(kOutputId);
1243 std::vector<AtomicId>
compute, train;
1249 compute.push_back(0);
TrainProcessor *const proc
std::vector< Variable::Flags > flags
void add(const std::vector< const T * > &source, std::vector< const T * > &dest)
#define XERCES_CPP_NAMESPACE_QUALIFIER
void find(edm::Handle< EcalRecHitCollection > &hits, DetId thisDet, std::vector< EcalRecHitCollection::const_iterator > &hit, bool debug=false)
static std::string const input
MVATrainerComputer * calib
def elem(elemtype, innerHTML='', html_class='', kwargs)
std::vector< Interceptor > interceptors
std::vector< std::vector< double > > tmp
std::string fullPath() const
static std::string const source