00001 #include <assert.h>
00002 #include <functional>
00003 #include <ext/functional>
00004 #include <algorithm>
00005 #include <iostream>
00006 #include <cstdarg>
00007 #include <cstring>
00008 #include <cstdio>
00009 #include <string>
00010 #include <vector>
00011 #include <memory>
00012 #include <map>
00013 #include <set>
00014
00015 #include <xercesc/dom/DOM.hpp>
00016
00017 #include <TRandom.h>
00018
00019 #include "FWCore/Utilities/interface/Exception.h"
00020 #include "FWCore/ParameterSet/interface/FileInPath.h"
00021 #include "FWCore/MessageLogger/interface/MessageLogger.h"
00022
00023 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00024 #include "PhysicsTools/MVAComputer/interface/BitSet.h"
00025 #include "PhysicsTools/MVAComputer/interface/Calibration.h"
00026 #include "PhysicsTools/MVAComputer/interface/Variable.h"
00027
00028 #include "PhysicsTools/MVATrainer/interface/Interceptor.h"
00029 #include "PhysicsTools/MVATrainer/interface/XMLDocument.h"
00030 #include "PhysicsTools/MVATrainer/interface/XMLSimpleStr.h"
00031 #include "PhysicsTools/MVATrainer/interface/XMLUniStr.h"
00032 #include "PhysicsTools/MVATrainer/interface/Source.h"
00033 #include "PhysicsTools/MVATrainer/interface/SourceVariable.h"
00034 #include "PhysicsTools/MVATrainer/interface/TrainProcessor.h"
00035 #include "PhysicsTools/MVATrainer/interface/TrainerMonitoring.h"
00036 #include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
00037
00038 XERCES_CPP_NAMESPACE_USE
00039
00040 namespace PhysicsTools {
00041
00042 namespace {
00043 class MVATrainerComputer;
00044
// Common base class of the interceptors used while training. On top of the
// configure()/intercept() hooks it stores a back-pointer to the
// MVATrainerComputer it belongs to and offers optional init()/finish()
// callbacks, which MVATrainerComputer::configured() and
// MVATrainerComputer::done() invoke on every registered interceptor.
class BaseInterceptor : public Calibration::Interceptor {
    public:
    // The back-pointer starts out null and is filled in via setCalibration().
    BaseInterceptor() : calib(0) {}
    virtual ~BaseInterceptor() {}

    // Registers the owning calibration object (not owned by the interceptor).
    inline void setCalibration(MVATrainerComputer *calib)
    { this->calib = calib; }

    // Receives the number of inputs (n) and their flags and returns a vector
    // of flags; see the concrete implementations for what they return.
    virtual std::vector<Variable::Flags>
    configure(const MVAComputer *computer, unsigned int n,
              const std::vector<Variable::Flags> &flags) = 0;

    // Called with an array of value vectors, one vector per input variable.
    virtual double
    intercept(const std::vector<double> *values) const = 0;

    // Optional hooks; the defaults do nothing.
    virtual void init() {}
    virtual void finish(bool save) {}

    protected:
    MVATrainerComputer *calib;
};
00066
// Interceptor without an associated TrainProcessor. Its intercept() only
// calls MVATrainerComputer::next(), i.e. it draws the training/testing split
// decision; makeTrainCalibration() registers it at position 0.
class InitInterceptor : public BaseInterceptor {
    public:
    InitInterceptor() {}
    virtual ~InitInterceptor() {}

    virtual std::vector<Variable::Flags>
    configure(const MVAComputer *computer, unsigned int n,
              const std::vector<Variable::Flags> &flags);

    virtual double
    intercept(const std::vector<double> *values) const;
};
00079
00080 class TrainInterceptor : public BaseInterceptor {
00081 public:
00082 TrainInterceptor(TrainProcessor *proc) : proc(proc) {}
00083 virtual ~TrainInterceptor() {}
00084
00085 inline TrainProcessor *getProcessor() const { return proc; }
00086
00087 virtual std::vector<Variable::Flags>
00088 configure(const MVAComputer *computer, unsigned int n,
00089 const std::vector<Variable::Flags> &flags);
00090
00091 virtual double
00092 intercept(const std::vector<double> *values) const;
00093
00094 virtual void init();
00095 virtual void finish(bool save);
00096
00097 private:
00098 unsigned int targetIdx;
00099 unsigned int weightIdx;
00100 mutable std::vector<std::vector<double> > tmp;
00101 TrainProcessor *const proc;
00102 };
00103
// Calibration object used during training. It owns a list of interceptors
// (deleted in the destructor), inserts them into the processor list returned
// by getProcessors() and decides per call of next() whether the current
// entry is used for training and/or testing.
class MVATrainerComputer : public TrainMVAComputerCalibration {
    public:
    // (insertion index into the processor list, interceptor)
    typedef std::pair<unsigned int, BaseInterceptor*> Interceptor;

    MVATrainerComputer(const std::vector<Interceptor>
                                                &interceptors,
                       bool autoSave, UInt_t seed, double split);

    virtual ~MVATrainerComputer();

    virtual std::vector<Calibration::VarProcessor*>
                                            getProcessors() const;
    virtual void initFlags(std::vector<Variable::Flags>
                                                &flags) const;

    // Called by each interceptor from its configure(); once all have
    // reported in, init() is invoked on every interceptor.
    void configured(BaseInterceptor *interceptor) const;
    // Draws a new random training/testing decision.
    void next();
    // Calls finish() on all interceptors if they had all been configured.
    void done();

    inline void addFlag(Variable::Flags flag)
    { flags.push_back(flag); }

    inline bool useForTraining() const { return splitResult; }
    // With split <= 0 every entry is also used for testing; otherwise only
    // the entries not selected for training are.
    inline bool useForTesting() const
    { return split <= 0.0 || !splitResult; }

    inline bool isConfigured() const
    { return nConfigured == interceptors.size(); }

    private:
    std::vector<Interceptor>        interceptors;
    std::vector<Variable::Flags>    flags;
    // mutable because configured() is const but counts the callers
    mutable unsigned int            nConfigured;
    bool                            doAutoSave;
    TRandom                         random;
    double                          split;
    bool                            splitResult;
};
00142
00143
00144
// Function object that deletes the pointer it is called with; intended for
// use with algorithms such as std::for_each over containers of owning raw
// pointers.
//
// The adaptable-functor typedefs are spelled out directly instead of being
// inherited from std::unary_function, which was deprecated in C++11 and
// removed from the standard in C++17.
template<typename T>
struct deleter {
    typedef T       *argument_type;
    typedef void    result_type;

    inline void operator() (T *ptr) const { delete ptr; }
};
00149
// Scope guard owning a list of heap objects: every pointer handed to add()
// is deleted when the auto_cleaner goes out of scope.
//
// Copying is disabled (declared private, never defined): a copy would hold
// the same pointers and delete them a second time.
template<typename T>
struct auto_cleaner {
    inline auto_cleaner() {}

    inline ~auto_cleaner()
    {
        for(typename std::vector<T*>::const_iterator iter = clean.begin();
            iter != clean.end(); ++iter)
            delete *iter;
    }

    // Takes ownership of ptr (a null pointer is harmless).
    inline void add(T *ptr) { clean.push_back(ptr); }
    std::vector<T*> clean;

    private:
    auto_cleaner(const auto_cleaner &);
    auto_cleaner &operator = (const auto_cleaner &);
};
00158 }
00159
// vsnprintf() into a std::string.
//
// format: printf-style format string (must not be null)
// va:     the arguments; the caller remains responsible for va_end(va)
// Returns the formatted text.
//
// The buffer is grown and the formatting repeated until the output fits.
// Every attempt works on its own va_copy(), because a va_list may not be
// reused after vsnprintf() has consumed it.
static std::string stdStringVPrintf(const char *format, std::va_list va)
{
    // Start with room for typical messages, but never less than the
    // format string itself.
    std::vector<char> buffer(
        std::max<std::size_t>(128, std::strlen(format) + 1));

    for(;;) {
        std::va_list attempt;
        va_copy(attempt, va);
        const int n = std::vsnprintf(&buffer[0], buffer.size(),
                                     format, attempt);
        va_end(attempt);

        if (n >= 0 && static_cast<std::size_t>(n) < buffer.size())
            return std::string(&buffer[0],
                               static_cast<std::size_t>(n));

        // A non-negative n is the exact length required; a negative
        // value (as returned by pre-C99 libraries on truncation) gives
        // no hint, so just double the buffer.
        if (n >= 0)
            buffer.resize(static_cast<std::size_t>(n) + 1);
        else
            buffer.resize(buffer.size() * 2);
    }
}
00182
00183 static std::string stdStringPrintf(const char *format, ...)
00184 {
00185 std::va_list va;
00186 va_start(va, format);
00187 std::string result = stdStringVPrintf(format, va);
00188 va_end(va);
00189 return result;
00190 }
00191
00192
00193
00194 std::vector<Variable::Flags>
00195 InitInterceptor::configure(const MVAComputer *computer, unsigned int n,
00196 const std::vector<Variable::Flags> &flags)
00197 {
00198 calib->configured(this);
00199 return std::vector<Variable::Flags>(n, Variable::FLAG_ALL);
00200 }
00201
00202 double
00203 InitInterceptor::intercept(const std::vector<double> *values) const
00204 {
00205 calib->next();
00206 return 0.0;
00207 }
00208
00209
00210
00211 std::vector<Variable::Flags>
00212 TrainInterceptor::configure(const MVAComputer *computer, unsigned int n,
00213 const std::vector<Variable::Flags> &flags)
00214 {
00215 const SourceVariableSet &inputSet =
00216 const_cast<const TrainProcessor*>(proc)->getInputs();
00217 SourceVariable *target = inputSet.find(SourceVariableSet::kTarget);
00218 SourceVariable *weight = inputSet.find(SourceVariableSet::kWeight);
00219
00220 std::vector<SourceVariable*> inputs = inputSet.get(true);
00221
00222 std::vector<SourceVariable*>::const_iterator pos;
00223 pos = std::find(inputs.begin(), inputs.end(), target);
00224 assert(pos != inputs.end());
00225 targetIdx = pos - inputs.begin();
00226 pos = std::find(inputs.begin(), inputs.end(), weight);
00227 assert(pos != inputs.end());
00228 weightIdx = pos - inputs.begin();
00229
00230 calib->configured(this);
00231
00232 std::vector<Variable::Flags> result = flags;
00233 if (targetIdx < weightIdx) {
00234 result.erase(result.begin() + weightIdx);
00235 result.erase(result.begin() + targetIdx);
00236 } else {
00237 result.erase(result.begin() + targetIdx);
00238 result.erase(result.begin() + weightIdx);
00239 }
00240
00241 proc->passFlags(result);
00242
00243 result.clear();
00244 result.resize(n, proc->getDefaultFlags());
00245 result[targetIdx] = Variable::FLAG_NONE;
00246 result[weightIdx] = Variable::FLAG_OPTIONAL;
00247
00248 if (targetIdx >= 2 || weightIdx >= 2)
00249 tmp.resize(n - 2);
00250
00251 return result;
00252 }
00253
00254 void TrainInterceptor::init()
00255 {
00256 edm::LogInfo("MVATrainer")
00257 << "TrainProcessor \"" << (const char*)proc->getName()
00258 << "\" training iteration starting...";
00259
00260 proc->doTrainBegin();
00261 }
00262
00263 double
00264 TrainInterceptor::intercept(const std::vector<double> *values) const
00265 {
00266 if (values[targetIdx].size() != 1) {
00267 if (values[targetIdx].size() == 0)
00268 throw cms::Exception("MVATrainer")
00269 << "Trainer input lacks target variable."
00270 << std::endl;
00271 else
00272 throw cms::Exception("MVATrainer")
00273 << "Multiple targets supplied in input."
00274 << std::endl;
00275 }
00276 double target = values[targetIdx].front();
00277
00278 double weight = 1.0;
00279 if (values[weightIdx].size() > 1)
00280 throw cms::Exception("MVATrainer")
00281 << "Multiple weights supplied in input."
00282 << std::endl;
00283 else if (values[weightIdx].size() == 1)
00284 weight = values[weightIdx].front();
00285
00286 if (tmp.empty())
00287 proc->doTrainData(values + 2, target > 0.5, weight,
00288 calib->useForTraining(),
00289 calib->useForTesting());
00290 else {
00291 std::vector<std::vector<double> >::iterator pos = tmp.begin();
00292 for(unsigned int i = 0; pos != tmp.end(); i++)
00293 if (i != targetIdx && i != weightIdx)
00294 *pos++ = values[i];
00295
00296 proc->doTrainData(&tmp.front(), target > 0.5, weight,
00297 calib->useForTraining(),
00298 calib->useForTesting());
00299 }
00300
00301 return target;
00302 }
00303
00304 void TrainInterceptor::finish(bool save)
00305 {
00306 proc->doTrainEnd();
00307
00308 edm::LogInfo("MVATrainer")
00309 << "... processor \"" << (const char*)proc->getName()
00310 << "\" training iteration done.";
00311
00312 if (proc->isTrained()) {
00313 edm::LogInfo("MVATrainer")
00314 << "* Completed training of \""
00315 << (const char*)proc->getName() << "\".";
00316
00317 if (save)
00318 proc->save();
00319 }
00320 }
00321
00322
00323
00324 MVATrainerComputer::MVATrainerComputer(const std::vector<Interceptor>
00325 &interceptors, bool autoSave,
00326 UInt_t seed, double split) :
00327 interceptors(interceptors), nConfigured(0), doAutoSave(autoSave),
00328 random(seed), split(split)
00329 {
00330 for(std::vector<Interceptor>::const_iterator iter =
00331 interceptors.begin(); iter != interceptors.end(); ++iter)
00332 iter->second->setCalibration(this);
00333 }
00334
00335 MVATrainerComputer::~MVATrainerComputer()
00336 {
00337 done();
00338
00339 for(std::vector<Interceptor>::const_iterator iter =
00340 interceptors.begin(); iter != interceptors.end(); ++iter)
00341 delete iter->second;
00342 }
00343
00344 std::vector<Calibration::VarProcessor*>
00345 MVATrainerComputer::getProcessors() const
00346 {
00347 std::vector<Calibration::VarProcessor*> processors =
00348 Calibration::MVAComputer::getProcessors();
00349
00350 for(std::vector<Interceptor>::const_iterator iter =
00351 interceptors.begin(); iter != interceptors.end(); ++iter)
00352
00353 processors.insert(processors.begin() + iter->first,
00354 1, iter->second);
00355
00356 return processors;
00357 }
00358
00359 void MVATrainerComputer::initFlags(std::vector<Variable::Flags> &flags) const
00360 {
00361 assert(flags.size() == this->flags.size());
00362 flags = this->flags;
00363 }
00364
00365 void MVATrainerComputer::configured(BaseInterceptor *interceptor) const
00366 {
00367 nConfigured++;
00368 if (isConfigured())
00369 for(std::vector<Interceptor>::const_iterator iter =
00370 interceptors.begin();
00371 iter != interceptors.end(); ++iter)
00372 iter->second->init();
00373 }
00374
00375 void MVATrainerComputer::next()
00376 {
00377 splitResult = random.Uniform(1.0) >= split;
00378 }
00379
00380 void MVATrainerComputer::done()
00381 {
00382 if (isConfigured()) {
00383 for(std::vector<Interceptor>::const_iterator iter =
00384 interceptors.begin();
00385 iter != interceptors.end(); ++iter)
00386 iter->second->finish(doAutoSave);
00387 nConfigured = 0;
00388 }
00389 }
00390
00391
00392
// Names of the target and weight variables that the constructor adds to
// the outputs of the input source.
const AtomicId MVATrainer::kTargetId("__TARGET__");
const AtomicId MVATrainer::kWeightId("__WEIGHT__");

// Id under which makeTrainCalibration() recognizes the output processor.
static const AtomicId kOutputId("__OUTPUT__");
00397
00398 static bool isMagic(AtomicId id)
00399 {
00400 return id == MVATrainer::kTargetId ||
00401 id == MVATrainer::kWeightId ||
00402 id == kOutputId;
00403 }
00404
// Wraps in into single quotes for use as one shell word. Every embedded
// single quote is replaced by the sequence '\'' (close the quoting, an
// escaped quote, reopen the quoting).
static std::string escape(const std::string &in)
{
    static const char quote = '\'';

    std::string quoted(1, quote);

    std::string::size_type from = 0;
    for(;;) {
        const std::string::size_type hit = in.find(quote, from);
        if (hit == std::string::npos)
            break;

        quoted.append(in, from, hit - from);
        quoted += "'\\''";
        from = hit + 1;
    }
    quoted.append(in, from, std::string::npos);

    quoted += quote;
    return quoted;
}
00421
// Builds the trainer from an XML description.
//
// fileName:   the XML file
// useXSLT:    if set, the file is piped through "xsltproc --xinclude"
//             with a style sheet before being parsed
// styleSheet: style sheet for the preprocessing; a null pointer selects
//             PhysicsTools/MVATrainer/data/MVATrainer.xsl
//
// The children of the root node <MVATrainer> must appear in this order:
// one <general>, one <input>, any number of <processor>, one <output>.
// Throws cms::Exception on any violation of that structure.
//
// NOTE(review): objects created with new in here (input, output, the
// processors stored in sources) are released in ~MVATrainer(), which does
// not run when this constructor exits via an exception.
MVATrainer::MVATrainer(const std::string &fileName, bool useXSLT,
                       const char *styleSheet) :
    input(0), output(0), name("MVATrainer"),
    doAutoSave(true), doCleanup(false),
    doMonitoring(false), randomSeed(65539), crossValidation(0.0)
{
    if (useXSLT) {
        std::string sheet;
        if (!styleSheet)
            sheet = edm::FileInPath(
                "PhysicsTools/MVATrainer/data/MVATrainer.xsl")
                .fullPath();
        else
            sheet = styleSheet;

        // Both paths are shell-quoted by escape().
        std::string preproc = "xsltproc --xinclude " + escape(sheet) +
                              " " + escape(fileName);
        xml.reset(new XMLDocument(fileName, preproc));
    } else
        xml.reset(new XMLDocument(fileName));

    DOMNode *node = xml->getRootNode();

    if (std::strcmp(XMLSimpleStr(node->getNodeName()), "MVATrainer") != 0)
        throw cms::Exception("MVATrainer")
            << "Invalid XML root node." << std::endl;

    // Which top-level tag is expected next.
    enum State {
        STATE_GENERAL,
        STATE_FIRST,
        STATE_MIDDLE,
        STATE_LAST
    } state = STATE_GENERAL;

    for(node = node->getFirstChild();
        node; node = node->getNextSibling()) {
        // Only element nodes are of interest.
        if (node->getNodeType() != DOMNode::ELEMENT_NODE)
            continue;

        std::string name = XMLSimpleStr(node->getNodeName());
        DOMElement *elem = static_cast<DOMElement*>(node);

        switch(state) {
            case STATE_GENERAL: {
            if (name != "general")
                throw cms::Exception("MVATrainer")
                    << "Expected general config as first "
                       "tag." << std::endl;

            // <general> holds <option name="..."> tags; known options
            // are "id" and "trainfiles".
            for(DOMNode *subNode = elem->getFirstChild();
                subNode; subNode = subNode->getNextSibling()) {
                if (subNode->getNodeType() !=
                    DOMNode::ELEMENT_NODE)
                    continue;

                if (std::strcmp(XMLSimpleStr(
                    subNode->getNodeName()), "option") != 0)
                    throw cms::Exception("MVATrainer")
                        << "Expected option tag."
                        << std::endl;

                // elem and name are reused for the option here.
                elem = static_cast<DOMElement*>(subNode);
                name = XMLDocument::readAttribute<std::string>(
                                elem, "name");
                std::string content = XMLSimpleStr(
                            elem->getTextContent());

                if (name == "id")
                    this->name = content;
                else if (name == "trainfiles")
                    trainFileMask = content;
                else
                    throw cms::Exception("MVATrainer")
                        << "Unknown option \""
                        << name << "\"." << std::endl;
            }

            state = STATE_FIRST;
            }   break;
            case STATE_FIRST: {
            if (name != "input")
                throw cms::Exception("MVATrainer")
                    << "Expected input config as second "
                       "tag." << std::endl;

            // The input source always gets the target and weight
            // variables first, followed by the variables listed in
            // the XML.
            AtomicId id = XMLDocument::readAttribute<std::string>(
                                elem, "id");
            input = new Source(id, true);
            input->getOutputs().append(
                createVariable(input, kTargetId,
                               Variable::FLAG_NONE),
                SourceVariableSet::kTarget);
            input->getOutputs().append(
                createVariable(input, kWeightId,
                               Variable::FLAG_OPTIONAL),
                SourceVariableSet::kWeight);
            sources.insert(std::make_pair(id, input));
            fillOutputVars(input->getOutputs(), input, elem);

            state = STATE_MIDDLE;
            }   break;
            case STATE_MIDDLE: {
            if (name == "output") {
                AtomicId zero;
                output = new TrainProcessor("output",
                                            &zero, this);
                fillInputVars(output->getInputs(), elem);
                state = STATE_LAST;
                continue;
            } else if (name != "processor")
                throw cms::Exception("MVATrainer")
                    << "Unexpected tag after input "
                       "config." << std::endl;

            AtomicId id = XMLDocument::readAttribute<std::string>(
                                elem, "id");
            // This name is the "name" attribute of the processor tag
            // and hides the tag name read above.
            std::string name =
                XMLDocument::readAttribute<std::string>(
                                elem, "name");

            makeProcessor(elem, id, name.c_str());
            }   break;
            case STATE_LAST:
            throw cms::Exception("MVATrainer")
                << "Unexpected tag found after output."
                << std::endl;
            break;
        }
    }

    // Complain if the document ended before <input> or <output>.
    if (state == STATE_FIRST)
        throw cms::Exception("MVATrainer")
            << "Expected input variable config." << std::endl;
    else if (state == STATE_MIDDLE)
        throw cms::Exception("MVATrainer")
            << "Expected output variable config." << std::endl;

    // Default file name mask; the three %s are filled in by
    // trainFileName() and bookMonitor().
    if (trainFileMask.empty())
        trainFileMask = this->name + "_%s%s.%s";
}
00562
00563 MVATrainer::~MVATrainer()
00564 {
00565 if (monitoring.get())
00566 monitoring->write();
00567
00568 for(std::map<AtomicId, Source*>::const_iterator iter = sources.begin();
00569 iter != sources.end(); iter++) {
00570 TrainProcessor *proc =
00571 dynamic_cast<TrainProcessor*>(iter->second);
00572
00573 if (proc && doCleanup)
00574 proc->cleanup();
00575
00576 delete iter->second;
00577 }
00578 delete output;
00579 std::for_each(variables.begin(), variables.end(),
00580 deleter<SourceVariable>());
00581 }
00582
00583 void MVATrainer::loadState()
00584 {
00585 for(std::vector<AtomicId>::const_iterator iter =
00586 this->processors.begin();
00587 iter != this->processors.end(); iter++) {
00588 std::map<AtomicId, Source*>::const_iterator pos =
00589 sources.find(*iter);
00590 assert(pos != sources.end());
00591 TrainProcessor *source =
00592 dynamic_cast<TrainProcessor*>(pos->second);
00593 assert(source);
00594
00595 if (source->load())
00596 edm::LogInfo("MVATrainer")
00597 << source->getId() << " configuration for \""
00598 << (const char*)source->getName()
00599 << "\" loaded from file.";
00600 }
00601 }
00602
00603 void MVATrainer::saveState()
00604 {
00605 doCleanup = false;
00606
00607 for(std::vector<AtomicId>::const_iterator iter =
00608 this->processors.begin();
00609 iter != this->processors.end(); iter++) {
00610 std::map<AtomicId, Source*>::const_iterator pos =
00611 sources.find(*iter);
00612 assert(pos != sources.end());
00613 TrainProcessor *source =
00614 dynamic_cast<TrainProcessor*>(pos->second);
00615 assert(source);
00616
00617 if (source->isTrained())
00618 source->save();
00619 }
00620 }
00621
00622 void MVATrainer::makeProcessor(DOMElement *elem, AtomicId id, const char *name)
00623 {
00624 DOMElement *xmlInput = 0;
00625 DOMElement *xmlConfig = 0;
00626 DOMElement *xmlOutput = 0;
00627 DOMElement *xmlData = 0;
00628
00629 static struct NameExpect {
00630 const char *tag;
00631 bool mandatory;
00632 DOMElement **elem;
00633 } const expect[] = {
00634 { "input", true, &xmlInput },
00635 { "config", true, &xmlConfig },
00636 { "output", true, &xmlOutput },
00637 { "data", false, &xmlData },
00638 { 0, }
00639 };
00640
00641 const NameExpect *cur = expect;
00642 for(DOMNode *node = elem->getFirstChild();
00643 node; node = node->getNextSibling()) {
00644 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00645 continue;
00646
00647 std::string tag = XMLSimpleStr(node->getNodeName());
00648 DOMElement *elem = static_cast<DOMElement*>(node);
00649
00650 if (!cur->tag)
00651 throw cms::Exception("MVATrainer")
00652 << "Superfluous tag " << tag
00653 << "encountered in processor." << std::endl;
00654 else if (tag != cur->tag && cur->mandatory)
00655 throw cms::Exception("MVATrainer")
00656 << "Expected tag " << cur->tag << ", got "
00657 << tag << " instead in processor."
00658 << std::endl;
00659 else if (tag != cur->tag) {
00660 cur++;
00661 continue;
00662 }
00663 *(cur++)->elem = elem;
00664 }
00665
00666 while(cur->tag && !cur->mandatory)
00667 cur++;
00668 if (cur->tag)
00669 throw cms::Exception("MVATrainer")
00670 << "Unexpected end of processor configuration, "
00671 << "expected tag " << cur->tag << "." << std::endl;
00672
00673 std::auto_ptr<TrainProcessor> proc(
00674 TrainProcessor::create(name, &id, this));
00675 if (!proc.get())
00676 throw cms::Exception("MVATrainer")
00677 << "Variable processor trainer " << name
00678 << " could not be instantiated. Most likely because"
00679 " the trainer plugin for \"" << name << "\""
00680 " does not exist." << std::endl;
00681
00682 if (sources.find(id) != sources.end())
00683 throw cms::Exception("MVATrainer")
00684 << "Duplicate variable processor id "
00685 << (const char*)id << "."
00686 << std::endl;
00687
00688 fillInputVars(proc->getInputs(), xmlInput);
00689 fillOutputVars(proc->getOutputs(), proc.get(), xmlOutput);
00690
00691 edm::LogInfo("MVATrainer")
00692 << "Configuring " << (const char*)proc->getId()
00693 << " \"" << (const char*)proc->getName() << "\".";
00694 proc->configure(xmlConfig);
00695
00696 sources.insert(std::make_pair(id, proc.release()));
00697 processors.push_back(id);
00698 }
00699
00700 std::string MVATrainer::trainFileName(const TrainProcessor *proc,
00701 const std::string &ext,
00702 const std::string &arg) const
00703 {
00704 std::string arg_ = arg.size() > 0 ? ("_" + arg) : "";
00705 return stdStringPrintf(trainFileMask.c_str(),
00706 (const char*)proc->getName(),
00707 arg_.c_str(), ext.c_str());
00708 }
00709
00710 TrainerMonitoring::Module *MVATrainer::bookMonitor(const std::string &name)
00711 {
00712 if (!doMonitoring)
00713 return 0;
00714
00715 if (!monitoring.get()) {
00716 std::string fileName =
00717 stdStringPrintf(trainFileMask.c_str(),
00718 "monitoring", "", "root");
00719 monitoring.reset(new TrainerMonitoring(fileName));
00720 }
00721
00722 return monitoring->book(name);
00723 }
00724
00725 SourceVariable *MVATrainer::getVariable(AtomicId source, AtomicId name) const
00726 {
00727 std::map<AtomicId, Source*>::const_iterator pos = sources.find(source);
00728 if (pos == sources.end())
00729 return 0;
00730
00731 return pos->second->getOutput(name);
00732 }
00733
00734 SourceVariable *MVATrainer::createVariable(Source *source, AtomicId name,
00735 Variable::Flags flags)
00736 {
00737 SourceVariable *var = getVariable(source->getName(), name);
00738 if (var)
00739 return 0;
00740
00741 var = new SourceVariable(source, name, flags);
00742 variables.push_back(var);
00743 return var;
00744 }
00745
00746 void MVATrainer::fillInputVars(SourceVariableSet &vars,
00747 XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
00748 {
00749 std::vector<SourceVariable*> tmp;
00750 SourceVariable *target = 0;
00751 SourceVariable *weight = 0;
00752
00753 for(DOMNode *node = xml->getFirstChild(); node;
00754 node = node->getNextSibling()) {
00755 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00756 continue;
00757
00758 if (std::strcmp(XMLSimpleStr(node->getNodeName()), "var") != 0)
00759 throw cms::Exception("MVATrainer")
00760 << "Invalid input variable node." << std::endl;
00761
00762 DOMElement *elem = static_cast<DOMElement*>(node);
00763
00764 AtomicId source = XMLDocument::readAttribute<std::string>(
00765 elem, "source");
00766 AtomicId name = XMLDocument::readAttribute<std::string>(
00767 elem, "name");
00768
00769 SourceVariable *var = getVariable(source, name);
00770 if (!var)
00771 throw cms::Exception("MVATrainer")
00772 << "Input variable " << (const char*)source
00773 << ":" << (const char*)name
00774 << " not found." << std::endl;
00775
00776 if (XMLDocument::readAttribute<bool>(elem, "target", false)) {
00777 if (target)
00778 throw cms::Exception("MVATrainer")
00779 << "Target variable defined twice"
00780 << std::endl;
00781 target = var;
00782 }
00783 if (XMLDocument::readAttribute<bool>(elem, "weight", false)) {
00784 if (weight)
00785 throw cms::Exception("MVATrainer")
00786 << "Weight variable defined twice"
00787 << std::endl;
00788 weight = var;
00789 }
00790
00791 tmp.push_back(var);
00792 }
00793
00794 if (!weight) {
00795 weight = input->getOutput(kWeightId);
00796 assert(weight);
00797 tmp.insert(tmp.begin() +
00798 (target == input->getOutput(kTargetId)),
00799 1, weight);
00800 }
00801 if (!target) {
00802 target = input->getOutput(kTargetId);
00803 assert(target);
00804 tmp.insert(tmp.begin(), 1, target);
00805 }
00806
00807 unsigned int n = 0;
00808 for(std::vector<SourceVariable*>::const_iterator iter = variables.begin();
00809 iter != variables.end(); iter++) {
00810 std::vector<SourceVariable*>::const_iterator pos =
00811 std::find(tmp.begin(), tmp.end(), *iter);
00812 if (pos == tmp.end())
00813 continue;
00814
00815 SourceVariableSet::Magic magic;
00816 if (*iter == target)
00817 magic = SourceVariableSet::kTarget;
00818 else if (*iter == weight)
00819 magic = SourceVariableSet::kWeight;
00820 else
00821 magic = SourceVariableSet::kRegular;
00822
00823 if (vars.append(*iter, magic, pos - tmp.begin())) {
00824 AtomicId source = (*iter)->getSource()->getName();
00825 AtomicId name = (*iter)->getName();
00826 throw cms::Exception("MVATrainer")
00827 << "Input variable " << (const char*)source
00828 << ":" << (const char*)name
00829 << " defined twice." << std::endl;
00830 }
00831
00832 n++;
00833 }
00834
00835 assert(tmp.size() == n);
00836 }
00837
00838 void MVATrainer::fillOutputVars(SourceVariableSet &vars, Source *source,
00839 XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
00840 {
00841 for(DOMNode *node = xml->getFirstChild(); node;
00842 node = node->getNextSibling()) {
00843 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00844 continue;
00845
00846 if (std::strcmp(XMLSimpleStr(node->getNodeName()), "var") != 0)
00847 throw cms::Exception("MVATrainer")
00848 << "Invalid output variable node."
00849 << std::endl;
00850
00851 DOMElement *elem = static_cast<DOMElement*>(node);
00852
00853 AtomicId name = XMLDocument::readAttribute<std::string>(
00854 elem, "name");
00855 if (!name)
00856 throw cms::Exception("MVATrainer")
00857 << "Output variable tag missing name."
00858 << std::endl;
00859 if (isMagic(name))
00860 throw cms::Exception("MVATrainer")
00861 << "Cannot use magic variable names in output."
00862 << std::endl;
00863
00864 Variable::Flags flags = Variable::FLAG_NONE;
00865
00866 if (XMLDocument::readAttribute<bool>(elem, "optional", true))
00867 flags = (PhysicsTools::Variable::Flags)
00868 (flags | Variable::FLAG_OPTIONAL);
00869
00870 if (XMLDocument::readAttribute<bool>(elem, "multiple", true))
00871 flags = (PhysicsTools::Variable::Flags)
00872 (flags | Variable::FLAG_MULTIPLE);
00873
00874 SourceVariable *var = createVariable(source, name, flags);
00875 if (!var || vars.append(var))
00876 throw cms::Exception("MVATrainer")
00877 << "Output variable "
00878 << (const char*)source->getName()
00879 << ":" << (const char*)name
00880 << " defined twice." << std::endl;
00881 }
00882 }
00883
// Wires the given processors into calib: declares the input variables,
// assigns to every processor the bit set of variables it reads and
// finally selects the output variable.
//
// calib:      calibration to fill; if it is an MVATrainerComputer the
//             flags of the input variables are registered with it too
// procs:      (train processor, calibration processor) pairs in the
//             order of evaluation
// withTarget: when false, the first two outputs of the input source are
//             left out
//
// Throws cms::Exception if a processor lists its inputs in an order
// differing from the order of their appearance, or if the output does
// not consist of exactly one variable.
void
MVATrainer::connectProcessors(Calibration::MVAComputer *calib,
                              const std::vector<CalibratedProcessor> &procs,
                              bool withTarget) const
{
    // Index of every variable known so far; size is the next free index.
    std::map<SourceVariable*, unsigned int> vars;
    unsigned int size = 0;

    MVATrainerComputer *trainCalib =
            dynamic_cast<MVATrainerComputer*>(calib);

    // The outputs of the input source become the calibration inputs.
    for(unsigned int i = 0;
        i < input->getOutputs().size(true); i++) {
        if (i < 2 && !withTarget)
            continue;

        SourceVariable *var = variables[i];
        vars[var] = size++;

        Calibration::Variable calibVar;
        calibVar.name = (const char*)var->getName();
        calib->inputSet.push_back(calibVar);
        if (trainCalib)
            trainCalib->addFlag(var->getFlags());
    }

    for(std::vector<CalibratedProcessor>::const_iterator iter =
                procs.begin(); iter != procs.end(); iter++) {
        bool isInterceptor = dynamic_cast<BaseInterceptor*>(
                                iter->calib) != 0;

        BitSet inputSet(size);

        // Mark the inputs of this processor; entries without a train
        // processor have none.
        unsigned int last = 0;
        std::vector<SourceVariable*> inoutVars;
        if (iter->processor)
            inoutVars = iter->processor->getInputs().get(
                                isInterceptor);
        for(std::vector<SourceVariable*>::const_iterator iter2 =
            inoutVars.begin(); iter2 != inoutVars.end(); iter2++) {
            std::map<SourceVariable*,
                     unsigned int>::const_iterator pos =
                            vars.find(*iter2);

            assert(pos != vars.end());

            if (pos->second < last)
                throw cms::Exception("MVATrainer")
                    << "Input variables not declared "
                       "in order of appearance in \""
                    << (const char*)iter->processor->getName()
                    << "\"." << std::endl;

            inputSet[last = pos->second] = true;
        }

        assert(!isInterceptor || withTarget);

        iter->calib->inputVars = Calibration::convert(inputSet);

        // Provisional output: the index following the variables known
        // so far. It is replaced at the end if the configured output
        // variable is found.
        calib->output = size;

        // An interceptor takes up one index and is not added to the
        // calibration here.
        if (isInterceptor) {
            size++;
            continue;
        }

        calib->addProcessor(iter->calib);

        // The outputs of the processor get the next indices.
        inoutVars = iter->processor->getOutputs().get();
        for(std::vector<SourceVariable*>::const_iterator iter =
            inoutVars.begin(); iter != inoutVars.end(); iter++) {

            vars[*iter] = size++;
        }
    }

    if (output->getInputs().size() != 1)
        throw cms::Exception("MVATrainer")
            << "Exactly one output variable has to be specified."
            << std::endl;

    SourceVariable *outVar = output->getInputs().get()[0];
    std::map<SourceVariable*, unsigned int>::const_iterator pos =
                            vars.find(outVar);
    if (pos != vars.end())
        calib->output = pos->second;
}
00972
// Builds the calibration used for one training pass.
//
// compute: null-terminated array of ids of already trained processors
//          that are to be evaluated
// train:   null-terminated array of ids of the processors to be trained
//          (kOutputId selects the output processor)
//
// Returns a newly allocated MVATrainerComputer; the caller owns it.
Calibration::MVAComputer *
MVATrainer::makeTrainCalibration(const AtomicId *compute,
                                 const AtomicId *train) const
{
    // Interceptors not yet placed in the processor list, by processor id.
    std::map<AtomicId, TrainInterceptor*> interceptors;
    std::vector<MVATrainerComputer::Interceptor> baseInterceptors;
    std::vector<CalibratedProcessor> processors;

    // The InitInterceptor always comes first.
    BaseInterceptor *interceptor = new InitInterceptor;
    baseInterceptors.push_back(std::make_pair(0, interceptor));
    processors.push_back(CalibratedProcessor(0, interceptor));

    // One TrainInterceptor per processor to be trained.
    for(const AtomicId *iter = train; *iter; iter++) {
        TrainProcessor *source;
        if (*iter == kOutputId)
            source = output;
        else {
            std::map<AtomicId, Source*>::const_iterator pos =
                            sources.find(*iter);
            assert(pos != sources.end());
            source = dynamic_cast<TrainProcessor*>(pos->second);
        }
        assert(source);

        interceptors[*iter] = new TrainInterceptor(source);
    }

    // Deletes the calibration objects obtained from getCalibration()
    // when this function returns.
    auto_cleaner<Calibration::VarProcessor> autoClean;

    // Ids already handled as part of a ProcForeach loop below.
    std::set<AtomicId> done;
    for(const AtomicId *iter = compute; *iter; iter++) {
        if (done.erase(*iter))
            continue;

        std::map<AtomicId, Source*>::const_iterator pos =
                            sources.find(*iter);
        assert(pos != sources.end());
        TrainProcessor *source =
                dynamic_cast<TrainProcessor*>(pos->second);
        assert(source);
        assert(source->isTrained());

        Calibration::VarProcessor *proc = source->getCalibration();
        if (!proc)
            continue;

        autoClean.add(proc);
        processors.push_back(CalibratedProcessor(source, proc));

        // A ProcForeach covers the nProcs processors configured right
        // after it. Those among them that are computed or trained are
        // placed directly behind it, and nProcs is set to the number
        // placed.
        Calibration::ProcForeach *looper =
                dynamic_cast<Calibration::ProcForeach*>(proc);
        if (looper) {
            std::vector<AtomicId>::const_iterator pos2 =
                std::find(this->processors.begin(),
                          this->processors.end(), *iter);
            assert(pos2 != this->processors.end());
            ++pos2;
            unsigned int n = 0;
            for(int i = 0; i < (int)looper->nProcs; ++i, ++pos2) {
                assert(pos2 != this->processors.end());

                // Is this processor in the compute list?
                const AtomicId *iter2 = compute;
                while(*iter2) {
                    if (*iter2 == *pos2)
                        break;
                    iter2++;
                }

                if (*iter2) {
                    n++;
                    done.insert(*iter2);
                    pos = sources.find(*iter2);
                    assert(pos != sources.end());
                    TrainProcessor *source =
                        dynamic_cast<TrainProcessor*>(
                                pos->second);
                    assert(source);
                    assert(source->isTrained());

                    proc = source->getCalibration();
                    if (proc) {
                        autoClean.add(proc);
                        processors.push_back(
                            CalibratedProcessor(
                                source, proc));
                    }
                }

                // Is this processor to be trained?
                std::map<AtomicId, TrainInterceptor*>::iterator
                        pos3 = interceptors.find(*pos2);
                if (pos3 != interceptors.end()) {
                    n++;
                    baseInterceptors.push_back(
                        std::make_pair(processors.size(),
                                       pos3->second));
                    processors.push_back(
                        CalibratedProcessor(
                            pos3->second->getProcessor(),
                            pos3->second));
                    interceptors.erase(pos3);
                }
            }

            looper->nProcs = n;
            // NOTE(review): with n == 0 nothing was added to
            // baseInterceptors in the loop above, so this pop_back()
            // looks like it removes an interceptor registered earlier
            // (possibly the InitInterceptor) - please confirm.
            if (!n) {
                baseInterceptors.pop_back();
                processors.pop_back();
            }
        }
    }

    // The remaining interceptors go to the end of the list.
    for(std::map<AtomicId, TrainInterceptor*>::const_iterator iter =
        interceptors.begin(); iter != interceptors.end(); ++iter) {

        TrainProcessor *proc = iter->second->getProcessor();
        baseInterceptors.push_back(std::make_pair(processors.size(),
                                                  iter->second));
        processors.push_back(CalibratedProcessor(proc, iter->second));
    }

    std::auto_ptr<Calibration::MVAComputer> calib(
        new MVATrainerComputer(baseInterceptors, doAutoSave,
                               randomSeed, crossValidation));

    connectProcessors(calib.get(), processors, true);

    return calib.release();
}
01101
01102 void MVATrainer::doneTraining(Calibration::MVAComputer *trainCalibration) const
01103 {
01104 MVATrainerComputer *calib =
01105 dynamic_cast<MVATrainerComputer*>(trainCalibration);
01106
01107 if (!calib)
01108 throw cms::Exception("MVATrainer")
01109 << "Invalid training calibration passed to "
01110 "doneTraining()" << std::endl;
01111
01112 calib->done();
01113 }
01114
01115 std::vector<AtomicId> MVATrainer::findFinalProcessors() const
01116 {
01117 std::set<Source*> toCheck;
01118 toCheck.insert(output);
01119
01120 std::set<Source*> done;
01121 while(!toCheck.empty()) {
01122 Source *source = *toCheck.begin();
01123 toCheck.erase(toCheck.begin());
01124
01125 std::vector<SourceVariable*> inputs = source->inputs.get();
01126 for(std::vector<SourceVariable*>::const_iterator iter =
01127 inputs.begin(); iter != inputs.end(); ++iter) {
01128 source = (*iter)->getSource();
01129 if (done.insert(source).second)
01130 toCheck.insert(source);
01131 }
01132 }
01133
01134 std::vector<AtomicId> result;
01135 for(std::vector<AtomicId>::const_iterator iter = processors.begin();
01136 iter != processors.end(); ++iter) {
01137 std::map<AtomicId, Source*>::const_iterator pos =
01138 sources.find(*iter);
01139 if (pos != sources.end() && done.count(pos->second))
01140 result.push_back(*iter);
01141 }
01142
01143 return result;
01144 }
01145
01146 Calibration::MVAComputer *MVATrainer::getCalibration() const
01147 {
01148 std::vector<CalibratedProcessor> processors;
01149
01150 std::auto_ptr<Calibration::MVAComputer> calib(
01151 new Calibration::MVAComputer);
01152
01153 std::vector<AtomicId> used = findFinalProcessors();
01154 for(std::vector<AtomicId>::const_iterator iter = used.begin();
01155 iter != used.end(); iter++) {
01156 std::map<AtomicId, Source*>::const_iterator pos =
01157 sources.find(*iter);
01158 assert(pos != sources.end());
01159 TrainProcessor *source =
01160 dynamic_cast<TrainProcessor*>(pos->second);
01161 assert(source);
01162 if (!source->isTrained())
01163 return 0;
01164
01165 Calibration::VarProcessor *proc = source->getCalibration();
01166 if (!proc)
01167 continue;
01168
01169 Calibration::ProcForeach *foreach =
01170 dynamic_cast<Calibration::ProcForeach*>(proc);
01171 if (foreach) {
01172 std::vector<AtomicId>::const_iterator begin =
01173 std::find(this->processors.begin(),
01174 this->processors.end(), *iter);
01175 assert(this->processors.end() - begin >
01176 (int)(foreach->nProcs + 1));
01177 ++begin;
01178 std::vector<AtomicId>::const_iterator end =
01179 begin + foreach->nProcs;
01180 foreach->nProcs = 0;
01181 for(std::vector<AtomicId>::const_iterator iter2 =
01182 iter; iter2 != used.end(); ++iter2)
01183 if (std::find(begin, end, *iter2) != end)
01184 foreach->nProcs++;
01185 }
01186
01187 processors.push_back(CalibratedProcessor(source, proc));
01188 }
01189
01190 connectProcessors(calib.get(), processors, false);
01191
01192 return calib.release();
01193 }
01194
01195 void MVATrainer::findUntrainedComputers(std::vector<AtomicId> &compute,
01196 std::vector<AtomicId> &train) const
01197 {
01198 compute.clear();
01199 train.clear();
01200
01201 std::set<Source*> trainedSources;
01202 trainedSources.insert(input);
01203
01204 for(std::vector<AtomicId>::const_iterator iter =
01205 processors.begin(); iter != processors.end(); iter++) {
01206 std::map<AtomicId, Source*>::const_iterator pos =
01207 sources.find(*iter);
01208 assert(pos != sources.end());
01209 TrainProcessor *proc =
01210 dynamic_cast<TrainProcessor*>(pos->second);
01211 assert(proc);
01212
01213 bool trainedDeps = true;
01214 std::vector<SourceVariable*> inputVars =
01215 proc->getInputs().get();
01216 for(std::vector<SourceVariable*>::const_iterator iter2 =
01217 inputVars.begin(); iter2 != inputVars.end(); iter2++) {
01218 if (trainedSources.find((*iter2)->getSource())
01219 == trainedSources.end()) {
01220 trainedDeps = false;
01221 break;
01222 }
01223 }
01224
01225 if (!trainedDeps)
01226 continue;
01227
01228 if (proc->isTrained()) {
01229 trainedSources.insert(proc);
01230 compute.push_back(proc->getName());
01231 } else
01232 train.push_back(proc->getName());
01233 }
01234
01235 if (doMonitoring && !output->isTrained() &&
01236 trainedSources.find(output->getInputs().get()[0]->getSource())
01237 != trainedSources.end())
01238 train.push_back(kOutputId);
01239 }
01240
01241 Calibration::MVAComputer *MVATrainer::getTrainCalibration() const
01242 {
01243 std::vector<AtomicId> compute, train;
01244 findUntrainedComputers(compute, train);
01245
01246 if (train.empty())
01247 return 0;
01248
01249 compute.push_back(0);
01250 train.push_back(0);
01251
01252 return makeTrainCalibration(&compute.front(), &train.front());
01253 }
01254
01255 }