00001 #include <assert.h>
00002 #include <functional>
00003 #include <ext/functional>
00004 #include <algorithm>
00005 #include <iostream>
00006 #include <cstdarg>
00007 #include <cstring>
00008 #include <cstdio>
00009 #include <string>
00010 #include <vector>
00011 #include <memory>
00012 #include <map>
00013 #include <set>
00014
00015 #include <xercesc/dom/DOM.hpp>
00016
00017 #include <TRandom.h>
00018
00019 #include "FWCore/Utilities/interface/Exception.h"
00020 #include "FWCore/MessageLogger/interface/MessageLogger.h"
00021
00022 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00023 #include "PhysicsTools/MVAComputer/interface/BitSet.h"
00024 #include "PhysicsTools/MVAComputer/interface/Calibration.h"
00025 #include "PhysicsTools/MVAComputer/interface/Variable.h"
00026
00027 #include "PhysicsTools/MVATrainer/interface/Interceptor.h"
00028 #include "PhysicsTools/MVATrainer/interface/XMLDocument.h"
00029 #include "PhysicsTools/MVATrainer/interface/XMLSimpleStr.h"
00030 #include "PhysicsTools/MVATrainer/interface/XMLUniStr.h"
00031 #include "PhysicsTools/MVATrainer/interface/Source.h"
00032 #include "PhysicsTools/MVATrainer/interface/SourceVariable.h"
00033 #include "PhysicsTools/MVATrainer/interface/TrainProcessor.h"
00034 #include "PhysicsTools/MVATrainer/interface/TrainerMonitoring.h"
00035 #include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
00036
00037 XERCES_CPP_NAMESPACE_USE
00038
00039 namespace PhysicsTools {
00040
00041 namespace {
00042 class MVATrainerComputer;
00043
00044 class BaseInterceptor : public Calibration::Interceptor {
00045 public:
00046 BaseInterceptor() : calib(0) {}
00047 virtual ~BaseInterceptor() {}
00048
00049 inline void setCalibration(MVATrainerComputer *calib)
00050 { this->calib = calib; }
00051
00052 virtual std::vector<Variable::Flags>
00053 configure(const MVAComputer *computer, unsigned int n,
00054 const std::vector<Variable::Flags> &flags) = 0;
00055
00056 virtual double
00057 intercept(const std::vector<double> *values) const = 0;
00058
00059 virtual void init() {}
00060 virtual void finish(bool save) {}
00061
00062 protected:
00063 MVATrainerComputer *calib;
00064 };
00065
00066 class InitInterceptor : public BaseInterceptor {
00067 public:
00068 InitInterceptor() {}
00069 virtual ~InitInterceptor() {}
00070
00071 virtual std::vector<Variable::Flags>
00072 configure(const MVAComputer *computer, unsigned int n,
00073 const std::vector<Variable::Flags> &flags);
00074
00075 virtual double
00076 intercept(const std::vector<double> *values) const;
00077 };
00078
00079 class TrainInterceptor : public BaseInterceptor {
00080 public:
00081 TrainInterceptor(TrainProcessor *proc) : proc(proc) {}
00082 virtual ~TrainInterceptor() {}
00083
00084 inline TrainProcessor *getProcessor() const { return proc; }
00085
00086 virtual std::vector<Variable::Flags>
00087 configure(const MVAComputer *computer, unsigned int n,
00088 const std::vector<Variable::Flags> &flags);
00089
00090 virtual double
00091 intercept(const std::vector<double> *values) const;
00092
00093 virtual void init();
00094 virtual void finish(bool save);
00095
00096 private:
00097 unsigned int targetIdx;
00098 unsigned int weightIdx;
00099 mutable std::vector<std::vector<double> > tmp;
00100 TrainProcessor *const proc;
00101 };
00102
00103 class MVATrainerComputer : public TrainMVAComputerCalibration {
00104 public:
00105 typedef std::pair<unsigned int, BaseInterceptor*> Interceptor;
00106
00107 MVATrainerComputer(const std::vector<Interceptor>
00108 &interceptors,
00109 bool autoSave, UInt_t seed, double split);
00110
00111 virtual ~MVATrainerComputer();
00112
00113 virtual std::vector<Calibration::VarProcessor*>
00114 getProcessors() const;
00115 virtual void initFlags(std::vector<Variable::Flags>
00116 &flags) const;
00117
00118 void configured(BaseInterceptor *interceptor) const;
00119 void next();
00120 void done();
00121
00122 inline void addFlag(Variable::Flags flag)
00123 { flags.push_back(flag); }
00124
00125 inline bool useForTraining() const { return splitResult; }
00126 inline bool useForTesting() const
00127 { return split <= 0.0 || !splitResult; }
00128
00129 inline bool isConfigured() const
00130 { return nConfigured == interceptors.size(); }
00131
00132 private:
00133 std::vector<Interceptor> interceptors;
00134 std::vector<Variable::Flags> flags;
00135 mutable unsigned int nConfigured;
00136 bool doAutoSave;
00137 TRandom random;
00138 double split;
00139 bool splitResult;
00140 };
00141
00142
00143
// Function object that deletes the pointer passed to it; used with
// std::for_each to free containers of owning raw pointers.
//
// The former std::unary_function base was dropped: it only supplied
// typedefs that nothing here relies on, and it was deprecated in C++11 and
// removed from the standard in C++17.
template<typename T>
struct deleter {
    inline void operator() (T *ptr) const { delete ptr; }
};
00148
// Collects raw pointers and deletes every one of them when it goes out of
// scope.
//
// Copying is disabled (declared private, never defined): a copy would hold
// the same pointers and delete them a second time.
template<typename T>
struct auto_cleaner {
    inline auto_cleaner() {}

    inline ~auto_cleaner()
    {
        for(typename std::vector<T*>::iterator iter = clean.begin();
            iter != clean.end(); ++iter)
            delete *iter;
    }

    // Takes ownership of `ptr`.
    inline void add(T *ptr) { clean.push_back(ptr); }

    std::vector<T*> clean;

    private:
    auto_cleaner(const auto_cleaner &);
    auto_cleaner &operator = (const auto_cleaner &);
};
00157 }
00158
// vprintf-style formatting into a std::string.
//
// The caller's `va` is never consumed directly: each formatting attempt
// works on its own va_copy, because a va_list may not be reused after
// vsnprintf has walked it and the buffer may have to be grown and the
// formatting repeated.
static std::string stdStringVPrintf(const char *format, std::va_list va)
{
    // Start with room for the format itself, but never less than 128 bytes.
    std::vector<char> buffer(
        std::max<std::size_t>(128, std::strlen(format) + 1));

    for(;;) {
        std::va_list attempt;
        va_copy(attempt, va);
        const int n = std::vsnprintf(&buffer[0], buffer.size(),
                                     format, attempt);
        va_end(attempt);

        if (n >= 0 && static_cast<std::size_t>(n) < buffer.size())
            return std::string(&buffer[0]);

        // n >= 0: the required length is known, grow to exactly fit.
        // n <  0: length unknown (legacy vsnprintf behaviour), double.
        if (n >= 0)
            buffer.resize(static_cast<std::size_t>(n) + 1);
        else
            buffer.resize(buffer.size() * 2);
    }
}

// printf-style formatting into a std::string; thin variadic front end for
// stdStringVPrintf().
static std::string stdStringPrintf(const char *format, ...)
{
    std::va_list va;
    va_start(va, format);
    std::string result = stdStringVPrintf(format, va);
    va_end(va);
    return result;
}
00190
00191
00192
00193 std::vector<Variable::Flags>
00194 InitInterceptor::configure(const MVAComputer *computer, unsigned int n,
00195 const std::vector<Variable::Flags> &flags)
00196 {
00197 calib->configured(this);
00198 return std::vector<Variable::Flags>(n, Variable::FLAG_NONE);
00199 }
00200
00201 double
00202 InitInterceptor::intercept(const std::vector<double> *values) const
00203 {
00204 calib->next();
00205 return 0.0;
00206 }
00207
00208
00209
00210 std::vector<Variable::Flags>
00211 TrainInterceptor::configure(const MVAComputer *computer, unsigned int n,
00212 const std::vector<Variable::Flags> &flags)
00213 {
00214 const SourceVariableSet &inputSet =
00215 const_cast<const TrainProcessor*>(proc)->getInputs();
00216 SourceVariable *target = inputSet.find(SourceVariableSet::kTarget);
00217 SourceVariable *weight = inputSet.find(SourceVariableSet::kWeight);
00218
00219 std::vector<SourceVariable*> inputs = inputSet.get(true);
00220
00221 std::vector<SourceVariable*>::const_iterator pos;
00222 pos = std::find(inputs.begin(), inputs.end(), target);
00223 assert(pos != inputs.end());
00224 targetIdx = pos - inputs.begin();
00225 pos = std::find(inputs.begin(), inputs.end(), weight);
00226 assert(pos != inputs.end());
00227 weightIdx = pos - inputs.begin();
00228
00229 calib->configured(this);
00230
00231 std::vector<Variable::Flags> result = flags;
00232 if (targetIdx < weightIdx) {
00233 result.erase(result.begin() + weightIdx);
00234 result.erase(result.begin() + targetIdx);
00235 } else {
00236 result.erase(result.begin() + targetIdx);
00237 result.erase(result.begin() + weightIdx);
00238 }
00239
00240 proc->passFlags(result);
00241
00242 result.clear();
00243 result.resize(n, proc->getDefaultFlags());
00244 result[targetIdx] = Variable::FLAG_NONE;
00245 result[weightIdx] = Variable::FLAG_OPTIONAL;
00246
00247 if (targetIdx >= 2 || weightIdx >= 2)
00248 tmp.resize(n - 2);
00249
00250 return result;
00251 }
00252
00253 void TrainInterceptor::init()
00254 {
00255 edm::LogInfo("MVATrainer")
00256 << "TrainProcessor \"" << (const char*)proc->getName()
00257 << "\" training iteration starting...";
00258
00259 proc->doTrainBegin();
00260 }
00261
00262 double
00263 TrainInterceptor::intercept(const std::vector<double> *values) const
00264 {
00265 if (values[targetIdx].size() != 1) {
00266 if (values[targetIdx].size() == 0)
00267 throw cms::Exception("MVATrainer")
00268 << "Trainer input lacks target variable."
00269 << std::endl;
00270 else
00271 throw cms::Exception("MVATrainer")
00272 << "Multiple targets supplied in input."
00273 << std::endl;
00274 }
00275 double target = values[targetIdx].front();
00276
00277 double weight = 1.0;
00278 if (values[weightIdx].size() > 1)
00279 throw cms::Exception("MVATrainer")
00280 << "Multiple weights supplied in input."
00281 << std::endl;
00282 else if (values[weightIdx].size() == 1)
00283 weight = values[weightIdx].front();
00284
00285 if (tmp.empty())
00286 proc->doTrainData(values + 2, target > 0.5, weight,
00287 calib->useForTraining(),
00288 calib->useForTesting());
00289 else {
00290 std::vector<std::vector<double> >::iterator pos = tmp.begin();
00291 for(unsigned int i = 0; pos != tmp.end(); i++)
00292 if (i != targetIdx && i != weightIdx)
00293 *pos++ = values[i];
00294
00295 proc->doTrainData(&tmp.front(), target > 0.5, weight,
00296 calib->useForTraining(),
00297 calib->useForTesting());
00298 }
00299
00300 return target;
00301 }
00302
00303 void TrainInterceptor::finish(bool save)
00304 {
00305 proc->doTrainEnd();
00306
00307 edm::LogInfo("MVATrainer")
00308 << "... processor \"" << (const char*)proc->getName()
00309 << "\" training iteration done.";
00310
00311 if (proc->isTrained()) {
00312 edm::LogInfo("MVATrainer")
00313 << "* Completed training of \""
00314 << (const char*)proc->getName() << "\".";
00315
00316 if (save)
00317 proc->save();
00318 }
00319 }
00320
00321
00322
00323 MVATrainerComputer::MVATrainerComputer(const std::vector<Interceptor>
00324 &interceptors, bool autoSave,
00325 UInt_t seed, double split) :
00326 interceptors(interceptors), nConfigured(0), doAutoSave(autoSave),
00327 random(seed), split(split)
00328 {
00329 for(std::vector<Interceptor>::const_iterator iter =
00330 interceptors.begin(); iter != interceptors.end(); ++iter)
00331 iter->second->setCalibration(this);
00332 }
00333
00334 MVATrainerComputer::~MVATrainerComputer()
00335 {
00336 done();
00337
00338 for(std::vector<Interceptor>::const_iterator iter =
00339 interceptors.begin(); iter != interceptors.end(); ++iter)
00340 delete iter->second;
00341 }
00342
00343 std::vector<Calibration::VarProcessor*>
00344 MVATrainerComputer::getProcessors() const
00345 {
00346 std::vector<Calibration::VarProcessor*> processors =
00347 Calibration::MVAComputer::getProcessors();
00348
00349 for(std::vector<Interceptor>::const_iterator iter =
00350 interceptors.begin(); iter != interceptors.end(); ++iter)
00351
00352 processors.insert(processors.begin() + iter->first,
00353 1, iter->second);
00354
00355 return processors;
00356 }
00357
00358 void MVATrainerComputer::initFlags(std::vector<Variable::Flags> &flags) const
00359 {
00360 assert(flags.size() == this->flags.size());
00361 flags = this->flags;
00362 }
00363
00364 void MVATrainerComputer::configured(BaseInterceptor *interceptor) const
00365 {
00366 nConfigured++;
00367 if (isConfigured())
00368 for(std::vector<Interceptor>::const_iterator iter =
00369 interceptors.begin();
00370 iter != interceptors.end(); ++iter)
00371 iter->second->init();
00372 }
00373
00374 void MVATrainerComputer::next()
00375 {
00376 splitResult = random.Uniform(1.0) >= split;
00377 }
00378
00379 void MVATrainerComputer::done()
00380 {
00381 if (isConfigured()) {
00382 for(std::vector<Interceptor>::const_iterator iter =
00383 interceptors.begin();
00384 iter != interceptors.end(); ++iter)
00385 iter->second->finish(doAutoSave);
00386 nConfigured = 0;
00387 }
00388 }
00389
00390
00391
00392 const AtomicId MVATrainer::kTargetId("__TARGET__");
00393 const AtomicId MVATrainer::kWeightId("__WEIGHT__");
00394
00395 static const AtomicId kOutputId("__OUTPUT__");
00396
00397 static bool isMagic(AtomicId id)
00398 {
00399 return id == MVATrainer::kTargetId ||
00400 id == MVATrainer::kWeightId ||
00401 id == kOutputId;
00402 }
00403
00404 MVATrainer::MVATrainer(const std::string &fileName) :
00405 input(0), output(0), name("MVATrainer"),
00406 doAutoSave(true), doCleanup(false), doMonitoring(false),
00407 randomSeed(65539), crossValidation(0.0)
00408 {
00409 xml = std::auto_ptr<XMLDocument>(new XMLDocument(fileName));
00410
00411 DOMNode *node = xml->getRootNode();
00412
00413 if (std::strcmp(XMLSimpleStr(node->getNodeName()), "MVATrainer") != 0)
00414 throw cms::Exception("MVATrainer")
00415 << "Invalid XML root node." << std::endl;
00416
00417 enum State {
00418 STATE_GENERAL,
00419 STATE_FIRST,
00420 STATE_MIDDLE,
00421 STATE_LAST
00422 } state = STATE_GENERAL;
00423
00424 for(node = node->getFirstChild();
00425 node; node = node->getNextSibling()) {
00426 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00427 continue;
00428
00429 std::string name = XMLSimpleStr(node->getNodeName());
00430 DOMElement *elem = static_cast<DOMElement*>(node);
00431
00432 switch(state) {
00433 case STATE_GENERAL: {
00434 if (name != "general")
00435 throw cms::Exception("MVATrainer")
00436 << "Expected general config as first "
00437 "tag." << std::endl;
00438
00439 for(DOMNode *subNode = elem->getFirstChild();
00440 subNode; subNode = subNode->getNextSibling()) {
00441 if (subNode->getNodeType() !=
00442 DOMNode::ELEMENT_NODE)
00443 continue;
00444
00445 if (std::strcmp(XMLSimpleStr(
00446 subNode->getNodeName()), "option") != 0)
00447 throw cms::Exception("MVATrainer")
00448 << "Expected option tag."
00449 << std::endl;
00450
00451 elem = static_cast<DOMElement*>(subNode);
00452 name = XMLDocument::readAttribute<std::string>(
00453 elem, "name");
00454 std::string content = XMLSimpleStr(
00455 elem->getTextContent());
00456
00457 if (name == "id")
00458 this->name = content;
00459 else if (name == "trainfiles")
00460 trainFileMask = content;
00461 else
00462 throw cms::Exception("MVATrainer")
00463 << "Unknown option \""
00464 << name << "\"." << std::endl;
00465 }
00466
00467 state = STATE_FIRST;
00468 } break;
00469 case STATE_FIRST: {
00470 if (name != "input")
00471 throw cms::Exception("MVATrainer")
00472 << "Expected input config as second "
00473 "tag." << std::endl;
00474
00475 AtomicId id = XMLDocument::readAttribute<std::string>(
00476 elem, "id");
00477 input = new Source(id, true);
00478 input->getOutputs().append(
00479 createVariable(input, kTargetId,
00480 Variable::FLAG_NONE),
00481 SourceVariableSet::kTarget);
00482 input->getOutputs().append(
00483 createVariable(input, kWeightId,
00484 Variable::FLAG_OPTIONAL),
00485 SourceVariableSet::kWeight);
00486 sources.insert(std::make_pair(id, input));
00487 fillOutputVars(input->getOutputs(), input, elem);
00488
00489 state = STATE_MIDDLE;
00490 } break;
00491 case STATE_MIDDLE: {
00492 if (name == "output") {
00493 AtomicId zero;
00494 output = new TrainProcessor("output",
00495 &zero, this);
00496 fillInputVars(output->getInputs(), elem);
00497 state = STATE_LAST;
00498 continue;
00499 } else if (name != "processor")
00500 throw cms::Exception("MVATrainer")
00501 << "Unexpected tag after input "
00502 "config." << std::endl;
00503
00504 AtomicId id = XMLDocument::readAttribute<std::string>(
00505 elem, "id");
00506 std::string name =
00507 XMLDocument::readAttribute<std::string>(
00508 elem, "name");
00509
00510 makeProcessor(elem, id, name.c_str());
00511 } break;
00512 case STATE_LAST:
00513 throw cms::Exception("MVATrainer")
00514 << "Unexpected tag found after output."
00515 << std::endl;
00516 break;
00517 }
00518 }
00519
00520 if (state == STATE_FIRST)
00521 throw cms::Exception("MVATrainer")
00522 << "Expected input variable config." << std::endl;
00523 else if (state == STATE_MIDDLE)
00524 throw cms::Exception("MVATrainer")
00525 << "Expected output variable config." << std::endl;
00526
00527 if (trainFileMask.empty())
00528 trainFileMask = this->name + "_%s%s.%s";
00529 }
00530
00531 MVATrainer::~MVATrainer()
00532 {
00533 if (monitoring.get())
00534 monitoring->write();
00535
00536 for(std::map<AtomicId, Source*>::const_iterator iter = sources.begin();
00537 iter != sources.end(); iter++) {
00538 TrainProcessor *proc =
00539 dynamic_cast<TrainProcessor*>(iter->second);
00540
00541 if (proc && doCleanup)
00542 proc->cleanup();
00543
00544 delete iter->second;
00545 }
00546 delete output;
00547 std::for_each(variables.begin(), variables.end(),
00548 deleter<SourceVariable>());
00549 }
00550
00551 void MVATrainer::loadState()
00552 {
00553 for(std::vector<AtomicId>::const_iterator iter =
00554 this->processors.begin();
00555 iter != this->processors.end(); iter++) {
00556 std::map<AtomicId, Source*>::const_iterator pos =
00557 sources.find(*iter);
00558 assert(pos != sources.end());
00559 TrainProcessor *source =
00560 dynamic_cast<TrainProcessor*>(pos->second);
00561 assert(source);
00562
00563 if (source->load())
00564 edm::LogInfo("MVATrainer")
00565 << source->getId() << " configuration for \""
00566 << (const char*)source->getName()
00567 << "\" loaded from file.";
00568 }
00569 }
00570
00571 void MVATrainer::saveState()
00572 {
00573 doCleanup = false;
00574
00575 for(std::vector<AtomicId>::const_iterator iter =
00576 this->processors.begin();
00577 iter != this->processors.end(); iter++) {
00578 std::map<AtomicId, Source*>::const_iterator pos =
00579 sources.find(*iter);
00580 assert(pos != sources.end());
00581 TrainProcessor *source =
00582 dynamic_cast<TrainProcessor*>(pos->second);
00583 assert(source);
00584
00585 if (source->isTrained())
00586 source->save();
00587 }
00588 }
00589
00590 void MVATrainer::makeProcessor(DOMElement *elem, AtomicId id, const char *name)
00591 {
00592 DOMElement *xmlInput = 0;
00593 DOMElement *xmlConfig = 0;
00594 DOMElement *xmlOutput = 0;
00595 DOMElement *xmlData = 0;
00596
00597 static struct NameExpect {
00598 const char *tag;
00599 bool mandatory;
00600 DOMElement **elem;
00601 } const expect[] = {
00602 { "input", true, &xmlInput },
00603 { "config", true, &xmlConfig },
00604 { "output", true, &xmlOutput },
00605 { "data", false, &xmlData },
00606 { 0, }
00607 };
00608
00609 const NameExpect *cur = expect;
00610 for(DOMNode *node = elem->getFirstChild();
00611 node; node = node->getNextSibling()) {
00612 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00613 continue;
00614
00615 std::string tag = XMLSimpleStr(node->getNodeName());
00616 DOMElement *elem = static_cast<DOMElement*>(node);
00617
00618 if (!cur->tag)
00619 throw cms::Exception("MVATrainer")
00620 << "Superfluous tag " << tag
00621 << "encountered in processor." << std::endl;
00622 else if (tag != cur->tag && cur->mandatory)
00623 throw cms::Exception("MVATrainer")
00624 << "Expected tag " << cur->tag << ", got "
00625 << tag << " instead in processor."
00626 << std::endl;
00627 else if (tag != cur->tag) {
00628 cur++;
00629 continue;
00630 }
00631 *(cur++)->elem = elem;
00632 }
00633
00634 while(cur->tag && !cur->mandatory)
00635 cur++;
00636 if (cur->tag)
00637 throw cms::Exception("MVATrainer")
00638 << "Unexpected end of processor configuration, "
00639 << "expected tag " << cur->tag << "." << std::endl;
00640
00641 std::auto_ptr<TrainProcessor> proc(
00642 TrainProcessor::create(name, &id, this));
00643 if (!proc.get())
00644 throw cms::Exception("MVATrainer")
00645 << "Variable processor trainer " << name
00646 << " could not be instantiated. Most likely because"
00647 " the trainer plugin for \"" << name << "\""
00648 " \" does not exist." << std::endl;
00649
00650 if (sources.find(id) != sources.end())
00651 throw cms::Exception("MVATrainer")
00652 << "Duplicate variable processor id "
00653 << (const char*)id << "."
00654 << std::endl;
00655
00656 fillInputVars(proc->getInputs(), xmlInput);
00657 fillOutputVars(proc->getOutputs(), proc.get(), xmlOutput);
00658
00659 edm::LogInfo("MVATrainer")
00660 << "Configuring " << (const char*)proc->getId()
00661 << " \"" << (const char*)proc->getName() << "\".";
00662 proc->configure(xmlConfig);
00663
00664 sources.insert(std::make_pair(id, proc.release()));
00665 processors.push_back(id);
00666 }
00667
00668 std::string MVATrainer::trainFileName(const TrainProcessor *proc,
00669 const std::string &ext,
00670 const std::string &arg) const
00671 {
00672 std::string arg_ = arg.size() > 0 ? ("_" + arg) : "";
00673 return stdStringPrintf(trainFileMask.c_str(),
00674 (const char*)proc->getName(),
00675 arg_.c_str(), ext.c_str());
00676 }
00677
00678 TrainerMonitoring::Module *MVATrainer::bookMonitor(const std::string &name)
00679 {
00680 if (!doMonitoring)
00681 return 0;
00682
00683 if (!monitoring.get()) {
00684 std::string fileName =
00685 stdStringPrintf(trainFileMask.c_str(),
00686 "monitoring", "", "root");
00687 monitoring.reset(new TrainerMonitoring(fileName));
00688 }
00689
00690 return monitoring->book(name);
00691 }
00692
00693 SourceVariable *MVATrainer::getVariable(AtomicId source, AtomicId name) const
00694 {
00695 std::map<AtomicId, Source*>::const_iterator pos = sources.find(source);
00696 if (pos == sources.end())
00697 return 0;
00698
00699 return pos->second->getOutput(name);
00700 }
00701
00702 SourceVariable *MVATrainer::createVariable(Source *source, AtomicId name,
00703 Variable::Flags flags)
00704 {
00705 SourceVariable *var = getVariable(source->getName(), name);
00706 if (var)
00707 return 0;
00708
00709 var = new SourceVariable(source, name, flags);
00710 variables.push_back(var);
00711 return var;
00712 }
00713
00714 void MVATrainer::fillInputVars(SourceVariableSet &vars,
00715 XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
00716 {
00717 std::vector<SourceVariable*> tmp;
00718 SourceVariable *target = 0;
00719 SourceVariable *weight = 0;
00720
00721 for(DOMNode *node = xml->getFirstChild(); node;
00722 node = node->getNextSibling()) {
00723 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00724 continue;
00725
00726 if (std::strcmp(XMLSimpleStr(node->getNodeName()), "var") != 0)
00727 throw cms::Exception("MVATrainer")
00728 << "Invalid input variable node." << std::endl;
00729
00730 DOMElement *elem = static_cast<DOMElement*>(node);
00731
00732 AtomicId source = XMLDocument::readAttribute<std::string>(
00733 elem, "source");
00734 AtomicId name = XMLDocument::readAttribute<std::string>(
00735 elem, "name");
00736
00737 SourceVariable *var = getVariable(source, name);
00738 if (!var)
00739 throw cms::Exception("MVATrainer")
00740 << "Input variable " << (const char*)source
00741 << ":" << (const char*)name
00742 << " not found." << std::endl;
00743
00744 if (XMLDocument::readAttribute<bool>(elem, "target", false)) {
00745 if (target)
00746 throw cms::Exception("MVATrainer")
00747 << "Target variable defined twice"
00748 << std::endl;
00749 target = var;
00750 }
00751 if (XMLDocument::readAttribute<bool>(elem, "weight", false)) {
00752 if (weight)
00753 throw cms::Exception("MVATrainer")
00754 << "Weight variable defined twice"
00755 << std::endl;
00756 weight = var;
00757 }
00758
00759 tmp.push_back(var);
00760 }
00761
00762 if (!weight) {
00763 weight = input->getOutput(kWeightId);
00764 assert(weight);
00765 tmp.insert(tmp.begin() +
00766 (target == input->getOutput(kTargetId)),
00767 1, weight);
00768 }
00769 if (!target) {
00770 target = input->getOutput(kTargetId);
00771 assert(target);
00772 tmp.insert(tmp.begin(), 1, target);
00773 }
00774
00775 unsigned int n = 0;
00776 for(std::vector<SourceVariable*>::const_iterator iter = variables.begin();
00777 iter != variables.end(); iter++) {
00778 std::vector<SourceVariable*>::const_iterator pos =
00779 std::find(tmp.begin(), tmp.end(), *iter);
00780 if (pos == tmp.end())
00781 continue;
00782
00783 SourceVariableSet::Magic magic;
00784 if (*iter == target)
00785 magic = SourceVariableSet::kTarget;
00786 else if (*iter == weight)
00787 magic = SourceVariableSet::kWeight;
00788 else
00789 magic = SourceVariableSet::kRegular;
00790
00791 if (vars.append(*iter, magic, pos - tmp.begin())) {
00792 AtomicId source = (*iter)->getSource()->getName();
00793 AtomicId name = (*iter)->getName();
00794 throw cms::Exception("MVATrainer")
00795 << "Input variable " << (const char*)source
00796 << ":" << (const char*)name
00797 << " defined twice." << std::endl;
00798 }
00799
00800 n++;
00801 }
00802
00803 assert(tmp.size() == n);
00804 }
00805
00806 void MVATrainer::fillOutputVars(SourceVariableSet &vars, Source *source,
00807 XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
00808 {
00809 for(DOMNode *node = xml->getFirstChild(); node;
00810 node = node->getNextSibling()) {
00811 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00812 continue;
00813
00814 if (std::strcmp(XMLSimpleStr(node->getNodeName()), "var") != 0)
00815 throw cms::Exception("MVATrainer")
00816 << "Invalid output variable node."
00817 << std::endl;
00818
00819 DOMElement *elem = static_cast<DOMElement*>(node);
00820
00821 AtomicId name = XMLDocument::readAttribute<std::string>(
00822 elem, "name");
00823 if (!name)
00824 throw cms::Exception("MVATrainer")
00825 << "Output variable tag missing name."
00826 << std::endl;
00827 if (isMagic(name))
00828 throw cms::Exception("MVATrainer")
00829 << "Cannot use magic variable names in output."
00830 << std::endl;
00831
00832 Variable::Flags flags = Variable::FLAG_NONE;
00833
00834 if (XMLDocument::readAttribute<bool>(elem, "optional", true))
00835 flags = (PhysicsTools::Variable::Flags)
00836 (flags | Variable::FLAG_OPTIONAL);
00837
00838 if (XMLDocument::readAttribute<bool>(elem, "multiple", true))
00839 flags = (PhysicsTools::Variable::Flags)
00840 (flags | Variable::FLAG_MULTIPLE);
00841
00842 SourceVariable *var = createVariable(source, name, flags);
00843 if (!var || vars.append(var))
00844 throw cms::Exception("MVATrainer")
00845 << "Output variable "
00846 << (const char*)source->getName()
00847 << ":" << (const char*)name
00848 << " defined twice." << std::endl;
00849 }
00850 }
00851
00852 void
00853 MVATrainer::connectProcessors(Calibration::MVAComputer *calib,
00854 const std::vector<CalibratedProcessor> &procs,
00855 bool withTarget) const
00856 {
00857 std::map<SourceVariable*, unsigned int> vars;
00858 unsigned int size = 0;
00859
00860 MVATrainerComputer *trainCalib =
00861 dynamic_cast<MVATrainerComputer*>(calib);
00862
00863 for(unsigned int i = 0;
00864 i < input->getOutputs().size(true); i++) {
00865 if (i < 2 && !withTarget)
00866 continue;
00867
00868 SourceVariable *var = variables[i];
00869 vars[var] = size++;
00870
00871 Calibration::Variable calibVar;
00872 calibVar.name = (const char*)var->getName();
00873 calib->inputSet.push_back(calibVar);
00874 if (trainCalib)
00875 trainCalib->addFlag(var->getFlags());
00876 }
00877
00878 for(std::vector<CalibratedProcessor>::const_iterator iter =
00879 procs.begin(); iter != procs.end(); iter++) {
00880 bool isInterceptor = dynamic_cast<BaseInterceptor*>(
00881 iter->calib) != 0;
00882
00883 BitSet inputSet(size);
00884
00885 unsigned int last = 0;
00886 std::vector<SourceVariable*> inoutVars;
00887 if (iter->processor)
00888 inoutVars = iter->processor->getInputs().get(
00889 isInterceptor);
00890 for(std::vector<SourceVariable*>::const_iterator iter2 =
00891 inoutVars.begin(); iter2 != inoutVars.end(); iter2++) {
00892 std::map<SourceVariable*,
00893 unsigned int>::const_iterator pos =
00894 vars.find(*iter2);
00895
00896 assert(pos != vars.end());
00897
00898 if (pos->second < last)
00899 throw cms::Exception("MVATrainer")
00900 << "Input variables not declared "
00901 "in order of appearance in \""
00902 << (const char*)iter->processor->getName()
00903 << "\"." << std::endl;
00904
00905 inputSet[last = pos->second] = true;
00906 }
00907
00908 assert(!isInterceptor || withTarget);
00909
00910 iter->calib->inputVars = Calibration::convert(inputSet);
00911
00912 calib->output = size;
00913
00914 if (isInterceptor) {
00915 size++;
00916 continue;
00917 }
00918
00919 calib->addProcessor(iter->calib);
00920
00921 inoutVars = iter->processor->getOutputs().get();
00922 for(std::vector<SourceVariable*>::const_iterator iter =
00923 inoutVars.begin(); iter != inoutVars.end(); iter++) {
00924
00925 vars[*iter] = size++;
00926 }
00927 }
00928
00929 if (output->getInputs().size() != 1)
00930 throw cms::Exception("MVATrainer")
00931 << "Exactly one output variable has to be specified."
00932 << std::endl;
00933
00934 SourceVariable *outVar = output->getInputs().get()[0];
00935 std::map<SourceVariable*, unsigned int>::const_iterator pos =
00936 vars.find(outVar);
00937 if (pos != vars.end())
00938 calib->output = pos->second;
00939 }
00940
00941 Calibration::MVAComputer *
00942 MVATrainer::makeTrainCalibration(const AtomicId *compute,
00943 const AtomicId *train) const
00944 {
00945 std::map<AtomicId, TrainInterceptor*> interceptors;
00946 std::vector<MVATrainerComputer::Interceptor> baseInterceptors;
00947 std::vector<CalibratedProcessor> processors;
00948
00949 BaseInterceptor *interceptor = new InitInterceptor;
00950 baseInterceptors.push_back(std::make_pair(0, interceptor));
00951 processors.push_back(CalibratedProcessor(0, interceptor));
00952
00953 for(const AtomicId *iter = train; *iter; iter++) {
00954 TrainProcessor *source;
00955 if (*iter == kOutputId)
00956 source = output;
00957 else {
00958 std::map<AtomicId, Source*>::const_iterator pos =
00959 sources.find(*iter);
00960 assert(pos != sources.end());
00961 source = dynamic_cast<TrainProcessor*>(pos->second);
00962 }
00963 assert(source);
00964
00965 interceptors[*iter] = new TrainInterceptor(source);
00966 }
00967
00968 auto_cleaner<Calibration::VarProcessor> autoClean;
00969
00970 std::set<AtomicId> done;
00971 for(const AtomicId *iter = compute; *iter; iter++) {
00972 if (done.erase(*iter))
00973 continue;
00974
00975 std::map<AtomicId, Source*>::const_iterator pos =
00976 sources.find(*iter);
00977 assert(pos != sources.end());
00978 TrainProcessor *source =
00979 dynamic_cast<TrainProcessor*>(pos->second);
00980 assert(source);
00981 assert(source->isTrained());
00982
00983 Calibration::VarProcessor *proc = source->getCalibration();
00984 if (!proc)
00985 continue;
00986
00987 autoClean.add(proc);
00988 processors.push_back(CalibratedProcessor(source, proc));
00989
00990 Calibration::ProcForeach *looper =
00991 dynamic_cast<Calibration::ProcForeach*>(proc);
00992 if (looper) {
00993 std::vector<AtomicId>::const_iterator pos2 =
00994 std::find(this->processors.begin(),
00995 this->processors.end(), *iter);
00996 assert(pos2 != this->processors.end());
00997 ++pos2;
00998 unsigned int n = 0;
00999 for(int i = 0; i < (int)looper->nProcs; ++i, ++pos2) {
01000 assert(pos2 != this->processors.end());
01001
01002 const AtomicId *iter2 = compute;
01003 while(*iter2) {
01004 if (*iter2 == *pos2)
01005 break;
01006 iter2++;
01007 }
01008
01009 if (*iter2) {
01010 n++;
01011 done.insert(*iter2);
01012 pos = sources.find(*iter2);
01013 assert(pos != sources.end());
01014 TrainProcessor *source =
01015 dynamic_cast<TrainProcessor*>(
01016 pos->second);
01017 assert(source);
01018 assert(source->isTrained());
01019
01020 proc = source->getCalibration();
01021 if (proc) {
01022 autoClean.add(proc);
01023 processors.push_back(
01024 CalibratedProcessor(
01025 source, proc));
01026 }
01027 }
01028
01029 std::map<AtomicId, TrainInterceptor*>::iterator
01030 pos3 = interceptors.find(*pos2);
01031 if (pos3 != interceptors.end()) {
01032 n++;
01033 baseInterceptors.push_back(
01034 std::make_pair(processors.size(),
01035 pos3->second));
01036 processors.push_back(
01037 CalibratedProcessor(
01038 pos3->second->getProcessor(),
01039 pos3->second));
01040 interceptors.erase(pos3);
01041 }
01042 }
01043
01044 looper->nProcs = n;
01045 if (!n) {
01046 baseInterceptors.pop_back();
01047 processors.pop_back();
01048 }
01049 }
01050 }
01051
01052 for(std::map<AtomicId, TrainInterceptor*>::const_iterator iter =
01053 interceptors.begin(); iter != interceptors.end(); ++iter) {
01054
01055 TrainProcessor *proc = iter->second->getProcessor();
01056 baseInterceptors.push_back(std::make_pair(processors.size(),
01057 iter->second));
01058 processors.push_back(CalibratedProcessor(proc, iter->second));
01059 }
01060
01061 std::auto_ptr<Calibration::MVAComputer> calib(
01062 new MVATrainerComputer(baseInterceptors, doAutoSave,
01063 randomSeed, crossValidation));
01064
01065 connectProcessors(calib.get(), processors, true);
01066
01067 return calib.release();
01068 }
01069
01070 void MVATrainer::doneTraining(Calibration::MVAComputer *trainCalibration) const
01071 {
01072 MVATrainerComputer *calib =
01073 dynamic_cast<MVATrainerComputer*>(trainCalibration);
01074
01075 if (!calib)
01076 throw cms::Exception("MVATrainer")
01077 << "Invalid training calibration passed to "
01078 "doneTraining()" << std::endl;
01079
01080 calib->done();
01081 }
01082
01083 std::vector<AtomicId> MVATrainer::findFinalProcessors() const
01084 {
01085 std::set<Source*> toCheck;
01086 toCheck.insert(output);
01087
01088 std::set<Source*> done;
01089 while(!toCheck.empty()) {
01090 Source *source = *toCheck.begin();
01091 toCheck.erase(toCheck.begin());
01092
01093 std::vector<SourceVariable*> inputs = source->inputs.get();
01094 for(std::vector<SourceVariable*>::const_iterator iter =
01095 inputs.begin(); iter != inputs.end(); ++iter) {
01096 source = (*iter)->getSource();
01097 if (done.insert(source).second)
01098 toCheck.insert(source);
01099 }
01100 }
01101
01102 std::vector<AtomicId> result;
01103 for(std::vector<AtomicId>::const_iterator iter = processors.begin();
01104 iter != processors.end(); ++iter) {
01105 std::map<AtomicId, Source*>::const_iterator pos =
01106 sources.find(*iter);
01107 if (pos != sources.end() && done.count(pos->second))
01108 result.push_back(*iter);
01109 }
01110
01111 return result;
01112 }
01113
01114 Calibration::MVAComputer *MVATrainer::getCalibration() const
01115 {
01116 std::vector<CalibratedProcessor> processors;
01117
01118 std::auto_ptr<Calibration::MVAComputer> calib(
01119 new Calibration::MVAComputer);
01120
01121 std::vector<AtomicId> used = findFinalProcessors();
01122 for(std::vector<AtomicId>::const_iterator iter = used.begin();
01123 iter != used.end(); iter++) {
01124 std::map<AtomicId, Source*>::const_iterator pos =
01125 sources.find(*iter);
01126 assert(pos != sources.end());
01127 TrainProcessor *source =
01128 dynamic_cast<TrainProcessor*>(pos->second);
01129 assert(source);
01130 if (!source->isTrained())
01131 return 0;
01132
01133 Calibration::VarProcessor *proc = source->getCalibration();
01134 if (!proc)
01135 continue;
01136
01137 Calibration::ProcForeach *foreach =
01138 dynamic_cast<Calibration::ProcForeach*>(proc);
01139 if (foreach) {
01140 std::vector<AtomicId>::const_iterator begin =
01141 std::find(this->processors.begin(),
01142 this->processors.end(), *iter);
01143 assert(this->processors.end() - begin >
01144 (int)(foreach->nProcs + 1));
01145 ++begin;
01146 std::vector<AtomicId>::const_iterator end =
01147 begin + foreach->nProcs;
01148 foreach->nProcs = 0;
01149 for(std::vector<AtomicId>::const_iterator iter2 =
01150 iter; iter2 != used.end(); ++iter2)
01151 if (std::find(begin, end, *iter2) != end)
01152 foreach->nProcs++;
01153 }
01154
01155 processors.push_back(CalibratedProcessor(source, proc));
01156 }
01157
01158 connectProcessors(calib.get(), processors, false);
01159
01160 return calib.release();
01161 }
01162
01163 void MVATrainer::findUntrainedComputers(std::vector<AtomicId> &compute,
01164 std::vector<AtomicId> &train) const
01165 {
01166 compute.clear();
01167 train.clear();
01168
01169 std::set<Source*> trainedSources;
01170 trainedSources.insert(input);
01171
01172 for(std::vector<AtomicId>::const_iterator iter =
01173 processors.begin(); iter != processors.end(); iter++) {
01174 std::map<AtomicId, Source*>::const_iterator pos =
01175 sources.find(*iter);
01176 assert(pos != sources.end());
01177 TrainProcessor *proc =
01178 dynamic_cast<TrainProcessor*>(pos->second);
01179 assert(proc);
01180
01181 bool trainedDeps = true;
01182 std::vector<SourceVariable*> inputVars =
01183 proc->getInputs().get();
01184 for(std::vector<SourceVariable*>::const_iterator iter2 =
01185 inputVars.begin(); iter2 != inputVars.end(); iter2++) {
01186 if (trainedSources.find((*iter2)->getSource())
01187 == trainedSources.end()) {
01188 trainedDeps = false;
01189 break;
01190 }
01191 }
01192
01193 if (!trainedDeps)
01194 continue;
01195
01196 if (proc->isTrained()) {
01197 trainedSources.insert(proc);
01198 compute.push_back(proc->getName());
01199 } else
01200 train.push_back(proc->getName());
01201 }
01202
01203 if (doMonitoring && !output->isTrained() &&
01204 trainedSources.find(output->getInputs().get()[0]->getSource())
01205 != trainedSources.end())
01206 train.push_back(kOutputId);
01207 }
01208
01209 Calibration::MVAComputer *MVATrainer::getTrainCalibration() const
01210 {
01211 std::vector<AtomicId> compute, train;
01212 findUntrainedComputers(compute, train);
01213
01214 if (train.empty())
01215 return 0;
01216
01217 compute.push_back(0);
01218 train.push_back(0);
01219
01220 return makeTrainCalibration(&compute.front(), &train.front());
01221 }
01222
01223 }