CMS 3D CMS Logo

MVATrainer.cc

Go to the documentation of this file.
00001 #include <assert.h>
00002 #include <functional>
00003 #include <ext/functional>
00004 #include <algorithm>
00005 #include <iostream>
00006 #include <cstdarg>
00007 #include <cstring>
00008 #include <cstdio>
00009 #include <string>
00010 #include <vector>
00011 #include <memory>
00012 #include <map>
00013 #include <set>
00014 
00015 #include <xercesc/dom/DOM.hpp>
00016 
00017 #include <TRandom.h>
00018 
00019 #include "FWCore/Utilities/interface/Exception.h"
00020 #include "FWCore/MessageLogger/interface/MessageLogger.h"
00021 
00022 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00023 #include "PhysicsTools/MVAComputer/interface/BitSet.h"
00024 #include "PhysicsTools/MVAComputer/interface/Calibration.h"
00025 #include "PhysicsTools/MVAComputer/interface/Variable.h"
00026 
00027 #include "PhysicsTools/MVATrainer/interface/Interceptor.h"
00028 #include "PhysicsTools/MVATrainer/interface/XMLDocument.h"
00029 #include "PhysicsTools/MVATrainer/interface/XMLSimpleStr.h"
00030 #include "PhysicsTools/MVATrainer/interface/XMLUniStr.h"
00031 #include "PhysicsTools/MVATrainer/interface/Source.h"
00032 #include "PhysicsTools/MVATrainer/interface/SourceVariable.h"
00033 #include "PhysicsTools/MVATrainer/interface/TrainProcessor.h"
00034 #include "PhysicsTools/MVATrainer/interface/TrainerMonitoring.h"
00035 #include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
00036 
00037 XERCES_CPP_NAMESPACE_USE
00038 
00039 namespace PhysicsTools {
00040 
00041 namespace { // anonymous
00042         class MVATrainerComputer;
00043 
00044         class BaseInterceptor : public Calibration::Interceptor {
00045             public:
00046                 BaseInterceptor() : calib(0) {}
00047                 virtual ~BaseInterceptor() {}
00048 
00049                 inline void setCalibration(MVATrainerComputer *calib)
00050                 { this->calib = calib; }
00051 
00052                 virtual std::vector<Variable::Flags>
00053                 configure(const MVAComputer *computer, unsigned int n,
00054                           const std::vector<Variable::Flags> &flags) = 0;
00055 
00056                 virtual double
00057                 intercept(const std::vector<double> *values) const = 0;
00058 
00059                 virtual void init() {}
00060                 virtual void finish(bool save) {}
00061 
00062             protected:
00063                 MVATrainerComputer      *calib;
00064         };
00065 
00066         class InitInterceptor : public BaseInterceptor {
00067             public:
00068                 InitInterceptor() {}
00069                 virtual ~InitInterceptor() {}
00070 
00071                 virtual std::vector<Variable::Flags>
00072                 configure(const MVAComputer *computer, unsigned int n,
00073                           const std::vector<Variable::Flags> &flags);
00074 
00075                 virtual double
00076                 intercept(const std::vector<double> *values) const;
00077         };
00078 
00079         class TrainInterceptor : public BaseInterceptor {
00080             public:
00081                 TrainInterceptor(TrainProcessor *proc) : proc(proc) {}
00082                 virtual ~TrainInterceptor() {}
00083 
00084                 inline TrainProcessor *getProcessor() const { return proc; }
00085 
00086                 virtual std::vector<Variable::Flags>
00087                 configure(const MVAComputer *computer, unsigned int n,
00088                           const std::vector<Variable::Flags> &flags);
00089 
00090                 virtual double
00091                 intercept(const std::vector<double> *values) const;
00092 
00093                 virtual void init();
00094                 virtual void finish(bool save);
00095 
00096             private:
00097                 unsigned int                                    targetIdx;
00098                 unsigned int                                    weightIdx;
00099                 mutable std::vector<std::vector<double> >       tmp;
00100                 TrainProcessor                                  *const proc;
00101         };
00102 
00103         class MVATrainerComputer : public TrainMVAComputerCalibration {
00104             public:
00105                 typedef std::pair<unsigned int, BaseInterceptor*> Interceptor;
00106 
00107                 MVATrainerComputer(const std::vector<Interceptor>
00108                                                         &interceptors,
00109                                    bool autoSave, UInt_t seed, double split);
00110 
00111                 virtual ~MVATrainerComputer();
00112 
00113                 virtual std::vector<Calibration::VarProcessor*>
00114                                                         getProcessors() const;
00115                 virtual void initFlags(std::vector<Variable::Flags>
00116                                                         &flags) const;
00117 
00118                 void configured(BaseInterceptor *interceptor) const;
00119                 void next();
00120                 void done();
00121 
00122                 inline void addFlag(Variable::Flags flag)
00123                 { flags.push_back(flag); }
00124 
00125                 inline bool useForTraining() const { return splitResult; }
00126                 inline bool useForTesting() const
00127                 { return split <= 0.0 || !splitResult; }
00128 
00129                 inline bool isConfigured() const
00130                 { return nConfigured == interceptors.size(); }
00131 
00132             private:
00133                 std::vector<Interceptor>        interceptors;
00134                 std::vector<Variable::Flags>    flags;
00135                 mutable unsigned int            nConfigured;
00136                 bool                            doAutoSave;
00137                 TRandom                         random;
00138                 double                          split;
00139                 bool                            splitResult;
00140         };
00141 
00142         // useful litte helpers
00143 
00144         template<typename T>
00145         struct deleter : public std::unary_function<T*, void> {
00146                 inline void operator() (T *ptr) const { delete ptr; }
00147         };
00148 
00149         template<typename T>
00150         struct auto_cleaner {
00151                 inline ~auto_cleaner()
00152                 { std::for_each(clean.begin(), clean.end(), deleter<T>()); }
00153 
00154                 inline void add(T *ptr) { clean.push_back(ptr); }
00155                 std::vector<T*> clean;
00156         };
00157 } // anonymous namespace
00158 
// Formats a printf-style string into a std::string.
//
// vsnprintf() consumes the va_list it is handed.  Every formatting attempt
// therefore works on its own va_copy() of `va`.  Reusing `va` itself after
// a truncated attempt is undefined behaviour.  The caller keeps ownership
// of `va` and must va_end() it.
//
// The scratch buffer starts at max(128, strlen(format) + 1) bytes.  It then
// grows to the exact size reported by vsnprintf(), or doubles when the C
// library only signals failure (pre-C99 vsnprintf returns -1 on truncation).
static std::string stdStringVPrintf(const char *format, std::va_list va)
{
	std::vector<char> buffer(std::max<std::size_t>(128,
	                                               std::strlen(format) + 1));
	for(;;) {
		std::va_list args;
		va_copy(args, va);
		int n = std::vsnprintf(&buffer[0], buffer.size(), format, args);
		va_end(args);

		if (n >= 0 && static_cast<std::size_t>(n) < buffer.size())
			return std::string(&buffer[0]);

		buffer.resize(n >= 0 ? static_cast<std::size_t>(n) + 1
		                     : buffer.size() * 2);
	}
}

// printf-style convenience wrapper around stdStringVPrintf().
static std::string stdStringPrintf(const char *format, ...)
{
	std::va_list va;
	va_start(va, format);
	std::string result = stdStringVPrintf(format, va);
	va_end(va);
	return result;
}
00190 
00191 // implementation for InitInterceptor
00192 
00193 std::vector<Variable::Flags>
00194 InitInterceptor::configure(const MVAComputer *computer, unsigned int n,
00195                            const std::vector<Variable::Flags> &flags)
00196 {
00197         calib->configured(this);
00198         return std::vector<Variable::Flags>(n, Variable::FLAG_NONE);
00199 }
00200 
00201 double
00202 InitInterceptor::intercept(const std::vector<double> *values) const
00203 {
00204         calib->next();
00205         return 0.0;
00206 }
00207 
00208 // implementation for TrainInterceptor
00209 
00210 std::vector<Variable::Flags>
00211 TrainInterceptor::configure(const MVAComputer *computer, unsigned int n,
00212                             const std::vector<Variable::Flags> &flags)
00213 {
00214         const SourceVariableSet &inputSet = 
00215                 const_cast<const TrainProcessor*>(proc)->getInputs();
00216         SourceVariable *target = inputSet.find(SourceVariableSet::kTarget);
00217         SourceVariable *weight = inputSet.find(SourceVariableSet::kWeight);
00218 
00219         std::vector<SourceVariable*> inputs = inputSet.get(true);
00220 
00221         std::vector<SourceVariable*>::const_iterator pos;
00222         pos = std::find(inputs.begin(), inputs.end(), target);
00223         assert(pos != inputs.end());
00224         targetIdx = pos - inputs.begin();
00225         pos = std::find(inputs.begin(), inputs.end(), weight);
00226         assert(pos != inputs.end());
00227         weightIdx = pos - inputs.begin();
00228 
00229         calib->configured(this);
00230 
00231         std::vector<Variable::Flags> result = flags;
00232         if (targetIdx < weightIdx) {
00233                 result.erase(result.begin() + weightIdx);
00234                 result.erase(result.begin() + targetIdx);
00235         } else {
00236                 result.erase(result.begin() + targetIdx);
00237                 result.erase(result.begin() + weightIdx);
00238         }
00239 
00240         proc->passFlags(result);
00241 
00242         result.clear();
00243         result.resize(n, proc->getDefaultFlags());
00244         result[targetIdx] = Variable::FLAG_NONE;
00245         result[weightIdx] = Variable::FLAG_OPTIONAL;
00246 
00247         if (targetIdx >= 2 || weightIdx >= 2)
00248                 tmp.resize(n - 2);
00249 
00250         return result;
00251 }
00252 
00253 void TrainInterceptor::init()
00254 {
00255         edm::LogInfo("MVATrainer")
00256                 << "TrainProcessor \"" << (const char*)proc->getName()
00257                 << "\" training iteration starting...";
00258 
00259         proc->doTrainBegin();
00260 }
00261 
00262 double
00263 TrainInterceptor::intercept(const std::vector<double> *values) const
00264 {
00265         if (values[targetIdx].size() != 1) {
00266                 if (values[targetIdx].size() == 0)
00267                         throw cms::Exception("MVATrainer")
00268                                 << "Trainer input lacks target variable."
00269                                 << std::endl;
00270                 else
00271                         throw cms::Exception("MVATrainer")
00272                                 << "Multiple targets supplied in input."
00273                                 << std::endl;
00274         }
00275         double target = values[targetIdx].front();
00276 
00277         double weight = 1.0;
00278         if (values[weightIdx].size() > 1)
00279                 throw cms::Exception("MVATrainer")
00280                         << "Multiple weights supplied in input."
00281                         << std::endl;
00282         else if (values[weightIdx].size() == 1)
00283                 weight = values[weightIdx].front();
00284 
00285         if (tmp.empty())
00286                 proc->doTrainData(values + 2, target > 0.5, weight,
00287                                   calib->useForTraining(),
00288                                   calib->useForTesting());
00289         else {
00290                 std::vector<std::vector<double> >::iterator pos = tmp.begin();
00291                 for(unsigned int i = 0; pos != tmp.end(); i++)
00292                         if (i != targetIdx && i != weightIdx)
00293                                 *pos++ = values[i];
00294 
00295                 proc->doTrainData(&tmp.front(), target > 0.5, weight,
00296                                   calib->useForTraining(),
00297                                   calib->useForTesting());
00298         }
00299 
00300         return target;
00301 }
00302 
00303 void TrainInterceptor::finish(bool save)
00304 {
00305         proc->doTrainEnd();
00306 
00307         edm::LogInfo("MVATrainer")
00308                 << "... processor \"" << (const char*)proc->getName()
00309                 << "\" training iteration done.";
00310 
00311         if (proc->isTrained()) {
00312                 edm::LogInfo("MVATrainer")
00313                         << "* Completed training of \""
00314                         << (const char*)proc->getName() << "\".";
00315 
00316                 if (save)
00317                         proc->save();
00318         }
00319 }
00320 
00321 // implementation for MVATrainerComputer
00322 
00323 MVATrainerComputer::MVATrainerComputer(const std::vector<Interceptor>
00324                                                 &interceptors, bool autoSave,
00325                                        UInt_t seed, double split) :
00326         interceptors(interceptors), nConfigured(0), doAutoSave(autoSave),
00327         random(seed), split(split)
00328 {
00329         for(std::vector<Interceptor>::const_iterator iter =
00330                 interceptors.begin(); iter != interceptors.end(); ++iter)
00331                 iter->second->setCalibration(this);
00332 }
00333 
00334 MVATrainerComputer::~MVATrainerComputer()
00335 {
00336         done();
00337         
00338         for(std::vector<Interceptor>::const_iterator iter =
00339                 interceptors.begin(); iter != interceptors.end(); ++iter)
00340                 delete iter->second;
00341 }
00342 
00343 std::vector<Calibration::VarProcessor*>
00344 MVATrainerComputer::getProcessors() const
00345 {
00346         std::vector<Calibration::VarProcessor*> processors =
00347                         Calibration::MVAComputer::getProcessors();
00348 
00349         for(std::vector<Interceptor>::const_iterator iter =
00350                 interceptors.begin(); iter != interceptors.end(); ++iter)
00351 
00352                 processors.insert(processors.begin() + iter->first,
00353                                   1, iter->second);
00354 
00355         return processors;
00356 }
00357 
00358 void MVATrainerComputer::initFlags(std::vector<Variable::Flags> &flags) const
00359 {
00360         assert(flags.size() == this->flags.size());
00361         flags = this->flags;
00362 }
00363 
00364 void MVATrainerComputer::configured(BaseInterceptor *interceptor) const
00365 {
00366         nConfigured++;
00367         if (isConfigured())
00368                 for(std::vector<Interceptor>::const_iterator iter =
00369                                                 interceptors.begin();
00370                     iter != interceptors.end(); ++iter)
00371                         iter->second->init();
00372 }
00373 
00374 void MVATrainerComputer::next()
00375 {
00376         splitResult = random.Uniform(1.0) >= split;
00377 }
00378 
00379 void MVATrainerComputer::done()
00380 {
00381         if (isConfigured()) {
00382                 for(std::vector<Interceptor>::const_iterator iter =
00383                                                 interceptors.begin();
00384                     iter != interceptors.end(); ++iter)
00385                         iter->second->finish(doAutoSave);
00386                 nConfigured = 0;
00387         }
00388 }
00389 
00390 // implementation for MVATrainer
00391 
00392 const AtomicId MVATrainer::kTargetId("__TARGET__");
00393 const AtomicId MVATrainer::kWeightId("__WEIGHT__");
00394 
00395 static const AtomicId kOutputId("__OUTPUT__");
00396 
00397 static bool isMagic(AtomicId id)
00398 {
00399         return id == MVATrainer::kTargetId ||
00400                id == MVATrainer::kWeightId ||
00401                id == kOutputId;
00402 }
00403 
00404 MVATrainer::MVATrainer(const std::string &fileName) :
00405         input(0), output(0), name("MVATrainer"),
00406         doAutoSave(true), doCleanup(false), doMonitoring(false),
00407         randomSeed(65539), crossValidation(0.0)
00408 {
00409         xml = std::auto_ptr<XMLDocument>(new XMLDocument(fileName));
00410 
00411         DOMNode *node = xml->getRootNode();
00412 
00413         if (std::strcmp(XMLSimpleStr(node->getNodeName()), "MVATrainer") != 0)
00414                 throw cms::Exception("MVATrainer")
00415                         << "Invalid XML root node." << std::endl;
00416 
00417         enum State {
00418                 STATE_GENERAL,
00419                 STATE_FIRST,
00420                 STATE_MIDDLE,
00421                 STATE_LAST
00422         } state = STATE_GENERAL;
00423 
00424         for(node = node->getFirstChild();
00425             node; node = node->getNextSibling()) {
00426                 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00427                         continue;
00428 
00429                 std::string name = XMLSimpleStr(node->getNodeName());
00430                 DOMElement *elem = static_cast<DOMElement*>(node);
00431 
00432                 switch(state) {
00433                     case STATE_GENERAL: {
00434                         if (name != "general")
00435                                 throw cms::Exception("MVATrainer")
00436                                         << "Expected general config as first "
00437                                            "tag." << std::endl;
00438 
00439                         for(DOMNode *subNode = elem->getFirstChild();
00440                             subNode; subNode = subNode->getNextSibling()) {
00441                                 if (subNode->getNodeType() !=
00442                                     DOMNode::ELEMENT_NODE)
00443                                         continue;
00444 
00445                                 if (std::strcmp(XMLSimpleStr(
00446                                         subNode->getNodeName()), "option") != 0)
00447                                         throw cms::Exception("MVATrainer")
00448                                                 << "Expected option tag."
00449                                                 << std::endl;
00450 
00451                                 elem = static_cast<DOMElement*>(subNode);
00452                                 name = XMLDocument::readAttribute<std::string>(
00453                                                                 elem, "name");
00454                                 std::string content = XMLSimpleStr(
00455                                                 elem->getTextContent());
00456 
00457                                 if (name == "id")
00458                                         this->name = content;
00459                                 else if (name == "trainfiles")
00460                                         trainFileMask = content;
00461                                 else
00462                                         throw cms::Exception("MVATrainer")
00463                                                 << "Unknown option \""
00464                                                 << name << "\"." << std::endl;
00465                         }
00466 
00467                         state = STATE_FIRST;
00468                     }   break;
00469                     case STATE_FIRST: {
00470                         if (name != "input")
00471                                 throw cms::Exception("MVATrainer")
00472                                         << "Expected input config as second "
00473                                            "tag." << std::endl;
00474 
00475                         AtomicId id = XMLDocument::readAttribute<std::string>(
00476                                                                 elem, "id");
00477                         input = new Source(id, true);
00478                         input->getOutputs().append(
00479                                 createVariable(input, kTargetId,
00480                                                Variable::FLAG_NONE),
00481                                 SourceVariableSet::kTarget);
00482                         input->getOutputs().append(
00483                                 createVariable(input, kWeightId,
00484                                                Variable::FLAG_OPTIONAL),
00485                                 SourceVariableSet::kWeight);
00486                         sources.insert(std::make_pair(id, input));
00487                         fillOutputVars(input->getOutputs(), input, elem);
00488 
00489                         state = STATE_MIDDLE;
00490                     }   break;
00491                     case STATE_MIDDLE: {
00492                         if (name == "output") {
00493                                 AtomicId zero;
00494                                 output = new TrainProcessor("output",
00495                                                             &zero, this);
00496                                 fillInputVars(output->getInputs(), elem);
00497                                 state = STATE_LAST;
00498                                 continue;
00499                         } else if (name != "processor")
00500                                 throw cms::Exception("MVATrainer")
00501                                         << "Unexpected tag after input "
00502                                            "config." << std::endl;
00503 
00504                         AtomicId id = XMLDocument::readAttribute<std::string>(
00505                                                                 elem, "id");
00506                         std::string name =
00507                                 XMLDocument::readAttribute<std::string>(
00508                                         elem, "name");
00509 
00510                         makeProcessor(elem, id, name.c_str());
00511                     }   break;
00512                     case STATE_LAST:
00513                         throw cms::Exception("MVATrainer")
00514                                 << "Unexpected tag found after output."
00515                                 << std::endl;
00516                         break;
00517                 }
00518         }
00519 
00520         if (state == STATE_FIRST)
00521                 throw cms::Exception("MVATrainer")
00522                         << "Expected input variable config." << std::endl;
00523         else if (state == STATE_MIDDLE)
00524                 throw cms::Exception("MVATrainer")
00525                         << "Expected output variable config." << std::endl;
00526 
00527         if (trainFileMask.empty())
00528                 trainFileMask = this->name + "_%s%s.%s";
00529 }
00530 
00531 MVATrainer::~MVATrainer()
00532 {
00533         if (monitoring.get())
00534                 monitoring->write();
00535 
00536         for(std::map<AtomicId, Source*>::const_iterator iter = sources.begin();
00537             iter != sources.end(); iter++) {
00538                 TrainProcessor *proc =
00539                                 dynamic_cast<TrainProcessor*>(iter->second);
00540 
00541                 if (proc && doCleanup)
00542                         proc->cleanup();
00543 
00544                 delete iter->second;
00545         }
00546         delete output;
00547         std::for_each(variables.begin(), variables.end(),
00548                       deleter<SourceVariable>());
00549 }
00550 
00551 void MVATrainer::loadState()
00552 {
00553         for(std::vector<AtomicId>::const_iterator iter =
00554                                                 this->processors.begin();
00555             iter != this->processors.end(); iter++) {
00556                 std::map<AtomicId, Source*>::const_iterator pos =
00557                                                         sources.find(*iter);
00558                 assert(pos != sources.end());
00559                 TrainProcessor *source =
00560                                 dynamic_cast<TrainProcessor*>(pos->second);
00561                 assert(source);
00562 
00563                 if (source->load())
00564                         edm::LogInfo("MVATrainer")
00565                                 << source->getId() << " configuration for \""
00566                                 << (const char*)source->getName()
00567                                 << "\" loaded from file.";
00568         }
00569 }
00570 
00571 void MVATrainer::saveState()
00572 {
00573         doCleanup = false;
00574 
00575         for(std::vector<AtomicId>::const_iterator iter =
00576                                                 this->processors.begin();
00577             iter != this->processors.end(); iter++) {
00578                 std::map<AtomicId, Source*>::const_iterator pos =
00579                                                         sources.find(*iter);
00580                 assert(pos != sources.end());
00581                 TrainProcessor *source =
00582                                 dynamic_cast<TrainProcessor*>(pos->second);
00583                 assert(source);
00584 
00585                 if (source->isTrained())
00586                         source->save();
00587         }
00588 }
00589 
00590 void MVATrainer::makeProcessor(DOMElement *elem, AtomicId id, const char *name)
00591 {
00592         DOMElement *xmlInput = 0;
00593         DOMElement *xmlConfig = 0;
00594         DOMElement *xmlOutput = 0;
00595         DOMElement *xmlData = 0;
00596 
00597         static struct NameExpect {
00598                 const char      *tag;
00599                 bool            mandatory;
00600                 DOMElement      **elem;
00601         } const expect[] = {
00602                 { "input",      true,   &xmlInput },
00603                 { "config",     true,   &xmlConfig },
00604                 { "output",     true,   &xmlOutput },
00605                 { "data",       false,  &xmlData },
00606                 { 0, }
00607         };
00608 
00609         const NameExpect *cur = expect;
00610         for(DOMNode *node = elem->getFirstChild();
00611             node; node = node->getNextSibling()) {
00612                 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00613                         continue;
00614 
00615                 std::string tag = XMLSimpleStr(node->getNodeName());
00616                 DOMElement *elem = static_cast<DOMElement*>(node);
00617 
00618                 if (!cur->tag)
00619                         throw cms::Exception("MVATrainer")
00620                                 << "Superfluous tag " << tag
00621                                 << "encountered in processor." << std::endl;
00622                 else if (tag != cur->tag && cur->mandatory)
00623                         throw cms::Exception("MVATrainer")
00624                                 << "Expected tag " << cur->tag << ", got "
00625                                 << tag << " instead in processor."
00626                                 << std::endl;
00627                 else if (tag != cur->tag) {
00628                         cur++;
00629                         continue;
00630                 }
00631                 *(cur++)->elem = elem;
00632         }
00633 
00634         while(cur->tag && !cur->mandatory)
00635                 cur++;
00636         if (cur->tag)
00637                 throw cms::Exception("MVATrainer")
00638                         << "Unexpected end of processor configuration, "
00639                         << "expected tag " << cur->tag << "." << std::endl;
00640 
00641         std::auto_ptr<TrainProcessor> proc(
00642                                 TrainProcessor::create(name, &id, this));
00643         if (!proc.get())
00644                 throw cms::Exception("MVATrainer")
00645                         << "Variable processor trainer " << name
00646                         << " could not be instantiated. Most likely because"
00647                            " the trainer plugin for \"" << name << "\""
00648                            " \" does not exist." << std::endl;
00649 
00650         if (sources.find(id) != sources.end())
00651                 throw cms::Exception("MVATrainer")
00652                         << "Duplicate variable processor id "
00653                         << (const char*)id << "."
00654                         << std::endl;
00655 
00656         fillInputVars(proc->getInputs(), xmlInput);
00657         fillOutputVars(proc->getOutputs(), proc.get(), xmlOutput);
00658 
00659         edm::LogInfo("MVATrainer")
00660                 << "Configuring " << (const char*)proc->getId()
00661                 << " \"" << (const char*)proc->getName() << "\".";
00662         proc->configure(xmlConfig);
00663 
00664         sources.insert(std::make_pair(id, proc.release()));
00665         processors.push_back(id);
00666 }
00667 
00668 std::string MVATrainer::trainFileName(const TrainProcessor *proc,
00669                                       const std::string &ext,
00670                                       const std::string &arg) const
00671 {
00672         std::string arg_ = arg.size() > 0 ? ("_" + arg) : "";
00673         return stdStringPrintf(trainFileMask.c_str(),
00674                                (const char*)proc->getName(),
00675                                arg_.c_str(), ext.c_str());
00676 }
00677 
00678 TrainerMonitoring::Module *MVATrainer::bookMonitor(const std::string &name)
00679 {
00680         if (!doMonitoring)
00681                 return 0;
00682 
00683         if (!monitoring.get()) {
00684                 std::string fileName = 
00685                         stdStringPrintf(trainFileMask.c_str(),
00686                                         "monitoring", "", "root");
00687                 monitoring.reset(new TrainerMonitoring(fileName));
00688         }
00689 
00690         return monitoring->book(name);
00691 }
00692 
00693 SourceVariable *MVATrainer::getVariable(AtomicId source, AtomicId name) const
00694 {
00695         std::map<AtomicId, Source*>::const_iterator pos = sources.find(source);
00696         if (pos == sources.end())
00697                 return 0;
00698 
00699         return pos->second->getOutput(name);
00700 }
00701 
00702 SourceVariable *MVATrainer::createVariable(Source *source, AtomicId name,
00703                                            Variable::Flags flags)
00704 {
00705         SourceVariable *var = getVariable(source->getName(), name);
00706         if (var)
00707                 return 0;
00708 
00709         var = new SourceVariable(source, name, flags);
00710         variables.push_back(var);
00711         return var;
00712 }
00713 
00714 void MVATrainer::fillInputVars(SourceVariableSet &vars,
00715                                XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
00716 {
00717         std::vector<SourceVariable*> tmp;
00718         SourceVariable *target = 0;
00719         SourceVariable *weight = 0;
00720 
00721         for(DOMNode *node = xml->getFirstChild(); node;
00722             node = node->getNextSibling()) {
00723                 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00724                         continue;
00725 
00726                 if (std::strcmp(XMLSimpleStr(node->getNodeName()), "var") != 0)
00727                         throw cms::Exception("MVATrainer")
00728                                 << "Invalid input variable node." << std::endl;
00729 
00730                 DOMElement *elem = static_cast<DOMElement*>(node);
00731 
00732                 AtomicId source = XMLDocument::readAttribute<std::string>(
00733                                                         elem, "source");
00734                 AtomicId name = XMLDocument::readAttribute<std::string>(
00735                                                         elem, "name");
00736 
00737                 SourceVariable *var = getVariable(source, name);
00738                 if (!var)
00739                         throw cms::Exception("MVATrainer")
00740                                 << "Input variable " << (const char*)source
00741                                 << ":" << (const char*)name
00742                                 << " not found." << std::endl;
00743 
00744                 if (XMLDocument::readAttribute<bool>(elem, "target", false)) {
00745                         if (target)
00746                                 throw cms::Exception("MVATrainer")
00747                                         << "Target variable defined twice"
00748                                         << std::endl;
00749                         target = var;
00750                 }
00751                 if (XMLDocument::readAttribute<bool>(elem, "weight", false)) {
00752                         if (weight)
00753                                 throw cms::Exception("MVATrainer")
00754                                         << "Weight variable defined twice"
00755                                         << std::endl;
00756                         weight = var;
00757                 }
00758 
00759                 tmp.push_back(var);
00760         }
00761 
00762         if (!weight) {
00763                 weight = input->getOutput(kWeightId);
00764                 assert(weight);
00765                 tmp.insert(tmp.begin() +
00766                                 (target == input->getOutput(kTargetId)),
00767                            1, weight);
00768         }
00769         if (!target) {
00770                 target = input->getOutput(kTargetId);
00771                 assert(target);
00772                 tmp.insert(tmp.begin(), 1, target);
00773         }
00774 
00775         unsigned int n = 0;
00776         for(std::vector<SourceVariable*>::const_iterator iter = variables.begin();
00777             iter != variables.end(); iter++) {
00778                 std::vector<SourceVariable*>::const_iterator pos =
00779                         std::find(tmp.begin(), tmp.end(), *iter);
00780                 if (pos == tmp.end())
00781                         continue;
00782 
00783                 SourceVariableSet::Magic magic;
00784                 if (*iter == target)
00785                         magic = SourceVariableSet::kTarget;
00786                 else if (*iter == weight)
00787                         magic = SourceVariableSet::kWeight;
00788                 else
00789                         magic = SourceVariableSet::kRegular;
00790 
00791                 if (vars.append(*iter, magic, pos - tmp.begin())) {
00792                         AtomicId source = (*iter)->getSource()->getName();
00793                         AtomicId name = (*iter)->getName();
00794                         throw cms::Exception("MVATrainer")
00795                                 << "Input variable " << (const char*)source
00796                                 << ":" << (const char*)name
00797                                 << " defined twice." << std::endl;
00798                 }
00799 
00800                 n++;
00801         }
00802 
00803         assert(tmp.size() == n);
00804 }
00805 
00806 void MVATrainer::fillOutputVars(SourceVariableSet &vars, Source *source,
00807                                 XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
00808 {
00809         for(DOMNode *node = xml->getFirstChild(); node;
00810             node = node->getNextSibling()) {
00811                 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00812                         continue;
00813 
00814                 if (std::strcmp(XMLSimpleStr(node->getNodeName()), "var") != 0)
00815                         throw cms::Exception("MVATrainer")
00816                                 << "Invalid output variable node."
00817                                 << std::endl;
00818 
00819                 DOMElement *elem = static_cast<DOMElement*>(node);
00820 
00821                 AtomicId name = XMLDocument::readAttribute<std::string>(
00822                                                         elem, "name");
00823                 if (!name)
00824                         throw cms::Exception("MVATrainer")
00825                                 << "Output variable tag missing name."
00826                                 << std::endl;
00827                 if (isMagic(name))
00828                         throw cms::Exception("MVATrainer")
00829                                 << "Cannot use magic variable names in output."
00830                                 << std::endl;
00831 
00832                 Variable::Flags flags = Variable::FLAG_NONE;
00833 
00834                 if (XMLDocument::readAttribute<bool>(elem, "optional", true))
00835                         flags = (PhysicsTools::Variable::Flags)
00836                                 (flags | Variable::FLAG_OPTIONAL);
00837 
00838                 if (XMLDocument::readAttribute<bool>(elem, "multiple", true))
00839                         flags = (PhysicsTools::Variable::Flags)
00840                                 (flags | Variable::FLAG_MULTIPLE);
00841 
00842                 SourceVariable *var = createVariable(source, name, flags);
00843                 if (!var || vars.append(var))
00844                         throw cms::Exception("MVATrainer")
00845                                 << "Output variable "
00846                                 << (const char*)source->getName()
00847                                 << ":" << (const char*)name
00848                                 << " defined twice." << std::endl;
00849         }
00850 }
00851 
// Wires the variables of `calib`: assigns every variable a slot index,
// fills calib->inputSet with the input source's variables, sets each
// processor's input bit set and adds non-interceptor processors to `calib`.
//
// withTarget - when false, the first two input variables are skipped
//              (presumably the magic target/weight variables — TODO confirm
//              against how `variables` is populated).
//
// Throws cms::Exception if a processor's inputs are not in ascending slot
// order, or if the output node does not have exactly one input variable.
00852 void
00853 MVATrainer::connectProcessors(Calibration::MVAComputer *calib,
00854                               const std::vector<CalibratedProcessor> &procs,
00855                               bool withTarget) const
00856 {
              // variable -> slot index in the computer's variable list
00857         std::map<SourceVariable*, unsigned int> vars;
00858         unsigned int size = 0;
00859 
              // non-null only when building a training computer
00860         MVATrainerComputer *trainCalib =
00861                         dynamic_cast<MVATrainerComputer*>(calib);
00862 
              // Assumes the input source's variables are the leading entries
              // of `variables` — NOTE(review): ordering not visible here.
00863         for(unsigned int i = 0;
00864             i < input->getOutputs().size(true); i++) {
00865                 if (i < 2 && !withTarget)
00866                         continue;
00867 
00868                 SourceVariable *var = variables[i];
00869                 vars[var] = size++;
00870 
00871                 Calibration::Variable calibVar;
00872                 calibVar.name = (const char*)var->getName();
00873                 calib->inputSet.push_back(calibVar);
00874                 if (trainCalib)
00875                         trainCalib->addFlag(var->getFlags());
00876         }
00877 
00878         for(std::vector<CalibratedProcessor>::const_iterator iter =
00879                                 procs.begin(); iter != procs.end(); iter++) {
00880                 bool isInterceptor = dynamic_cast<BaseInterceptor*>(
00881                                                         iter->calib) != 0;
00882 
00883                 BitSet inputSet(size);
00884 
                      // Mark every input slot; inputs must arrive in
                      // non-decreasing slot order.
00885                 unsigned int last = 0;
00886                 std::vector<SourceVariable*> inoutVars;
00887                 if (iter->processor)
00888                         inoutVars = iter->processor->getInputs().get(
00889                                                                 isInterceptor);
00890                 for(std::vector<SourceVariable*>::const_iterator iter2 =
00891                         inoutVars.begin(); iter2 != inoutVars.end(); iter2++) {
00892                         std::map<SourceVariable*,
00893                                  unsigned int>::const_iterator pos =
00894                                                         vars.find(*iter2);
00895 
00896                         assert(pos != vars.end());
00897 
00898                         if (pos->second < last)
00899                                 throw cms::Exception("MVATrainer")
00900                                         << "Input variables not declared "
00901                                            "in order of appearance in \""
00902                                         << (const char*)iter->processor->getName()
00903                                         << "\"." << std::endl;
00904 
00905                         inputSet[last = pos->second] = true;
00906                 }
00907 
                      // interceptors only appear in training computers
00908                 assert(!isInterceptor || withTarget);
00909 
00910                 iter->calib->inputVars = Calibration::convert(inputSet);
00911 
                      // Provisional output slot; overridden at the end if the
                      // output variable is found.
00912                 calib->output = size;
00913 
                      // An interceptor is not added via addProcessor() but
                      // still reserves one slot.
00914                 if (isInterceptor) {
00915                         size++;
00916                         continue;
00917                 }
00918 
00919                 calib->addProcessor(iter->calib);
00920 
                      // Outputs get consecutive slots. NOTE(review): the loop
                      // variable below shadows the outer `iter`.
00921                 inoutVars = iter->processor->getOutputs().get();
00922                 for(std::vector<SourceVariable*>::const_iterator iter =
00923                         inoutVars.begin(); iter != inoutVars.end(); iter++) {
00924 
00925                         vars[*iter] = size++;
00926                 }
00927         }
00928 
00929         if (output->getInputs().size() != 1)
00930                 throw cms::Exception("MVATrainer")
00931                         << "Exactly one output variable has to be specified."
00932                         << std::endl;
00933 
00934         SourceVariable *outVar = output->getInputs().get()[0];
00935         std::map<SourceVariable*, unsigned int>::const_iterator pos =
00936                                                         vars.find(outVar);
00937         if (pos != vars.end())
00938                 calib->output = pos->second;
00939 }
00940 
// Builds the MVAComputer used for one training pass.
//
// `compute` and `train` are arrays of processor IDs terminated by a null
// AtomicId. Processors in `compute` run with their existing calibration.
// Every ID in `train` (kOutputId refers to the `output` processor) gets a
// TrainInterceptor. An InitInterceptor is always placed first.
//
// Returns a newly allocated MVATrainerComputer, released to the caller.
00941 Calibration::MVAComputer *
00942 MVATrainer::makeTrainCalibration(const AtomicId *compute,
00943                                  const AtomicId *train) const
00944 {
00945         std::map<AtomicId, TrainInterceptor*> interceptors;
00946         std::vector<MVATrainerComputer::Interceptor> baseInterceptors;
00947         std::vector<CalibratedProcessor> processors;
00948 
              // the InitInterceptor occupies processor slot 0
00949         BaseInterceptor *interceptor = new InitInterceptor;
00950         baseInterceptors.push_back(std::make_pair(0, interceptor));
00951         processors.push_back(CalibratedProcessor(0, interceptor));
00952 
              // one TrainInterceptor per processor to train, keyed by its ID
00953         for(const AtomicId *iter = train; *iter; iter++) {
00954                 TrainProcessor *source;
00955                 if (*iter == kOutputId)
00956                         source = output;
00957                 else {
00958                         std::map<AtomicId, Source*>::const_iterator pos =
00959                                                         sources.find(*iter);
00960                         assert(pos != sources.end());
00961                         source = dynamic_cast<TrainProcessor*>(pos->second);
00962                 }
00963                 assert(source);
00964 
00965                 interceptors[*iter] = new TrainInterceptor(source);
00966         }
00967 
              // NOTE(review): presumably releases the calibration objects added
              // below when this function returns — auto_cleaner is defined
              // outside this chunk; confirm.
00968         auto_cleaner<Calibration::VarProcessor> autoClean;
00969 
              // IDs already emitted inside an enclosing ForEach loop body;
              // the outer loop skips (and forgets) them once.
00970         std::set<AtomicId> done;
00971         for(const AtomicId *iter = compute; *iter; iter++) {
00972                 if (done.erase(*iter))
00973                         continue;
00974 
00975                 std::map<AtomicId, Source*>::const_iterator pos =
00976                                                         sources.find(*iter);
00977                 assert(pos != sources.end());
00978                 TrainProcessor *source =
00979                                 dynamic_cast<TrainProcessor*>(pos->second);
00980                 assert(source);
00981                 assert(source->isTrained());
00982 
00983                 Calibration::VarProcessor *proc = source->getCalibration();
00984                 if (!proc)
00985                         continue;
00986 
00987                 autoClean.add(proc);
00988                 processors.push_back(CalibratedProcessor(source, proc));
00989 
                      // A ForEach loop body is the nProcs IDs directly after the
                      // ForEach in this->processors. Emit the body members that
                      // are computed or trained right after it, then rewrite
                      // nProcs to the number emitted (n).
00990                 Calibration::ProcForeach *looper =
00991                                 dynamic_cast<Calibration::ProcForeach*>(proc);
00992                 if (looper) {
00993                         std::vector<AtomicId>::const_iterator pos2 =
00994                                 std::find(this->processors.begin(),
00995                                           this->processors.end(), *iter);
00996                         assert(pos2 != this->processors.end());
00997                         ++pos2;
00998                         unsigned int n = 0;
00999                         for(int i = 0; i < (int)looper->nProcs; ++i, ++pos2) {
01000                                 assert(pos2 != this->processors.end());
01001 
                                      // is this body member also in `compute`?
01002                                 const AtomicId *iter2 = compute;
01003                                 while(*iter2) {
01004                                         if (*iter2 == *pos2)
01005                                                 break;
01006                                         iter2++;
01007                                 }
01008 
                                      // NOTE(review): n is incremented even when
                                      // getCalibration() returns 0 and nothing is
                                      // pushed, so nProcs may count a processor
                                      // that was not emitted — confirm.
01009                                 if (*iter2) {
01010                                         n++;
01011                                         done.insert(*iter2);
01012                                         pos = sources.find(*iter2);
01013                                         assert(pos != sources.end());
01014                                         TrainProcessor *source =
01015                                                 dynamic_cast<TrainProcessor*>(
01016                                                                 pos->second);
01017                                         assert(source);
01018                                         assert(source->isTrained());
01019 
01020                                         proc = source->getCalibration();
01021                                         if (proc) {
01022                                                 autoClean.add(proc);
01023                                                 processors.push_back(
01024                                                         CalibratedProcessor(
01025                                                                 source, proc));
01026                                         }
01027                                 }
01028 
                                      // a pending interceptor for this member is
                                      // placed inside the loop body as well
01029                                 std::map<AtomicId, TrainInterceptor*>::iterator
01030                                                 pos3 = interceptors.find(*pos2);
01031                                 if (pos3 != interceptors.end()) {
01032                                         n++;
01033                                         baseInterceptors.push_back(
01034                                                 std::make_pair(processors.size(),
01035                                                                pos3->second));
01036                                         processors.push_back(
01037                                                 CalibratedProcessor(
01038                                                         pos3->second->getProcessor(),
01039                                                         pos3->second));
01040                                         interceptors.erase(pos3);
01041                                 }
01042                         }
01043 
                              // NOTE(review): with n == 0 nothing was pushed onto
                              // baseInterceptors in this iteration, so pop_back()
                              // drops an entry added earlier (possibly the
                              // InitInterceptor) — confirm this is intended.
                              // processors.pop_back() removes the looper itself.
01044                         looper->nProcs = n;
01045                         if (!n) {
01046                                 baseInterceptors.pop_back();
01047                                 processors.pop_back();
01048                         }
01049                 }
01050         }
01051 
              // interceptors not placed inside a ForEach body go at the end
01052         for(std::map<AtomicId, TrainInterceptor*>::const_iterator iter =
01053                 interceptors.begin(); iter != interceptors.end(); ++iter) {
01054 
01055                 TrainProcessor *proc = iter->second->getProcessor();
01056                 baseInterceptors.push_back(std::make_pair(processors.size(),
01057                                                           iter->second));
01058                 processors.push_back(CalibratedProcessor(proc, iter->second));
01059         }
01060 
01061         std::auto_ptr<Calibration::MVAComputer> calib(
01062                 new MVATrainerComputer(baseInterceptors, doAutoSave,
01063                                        randomSeed, crossValidation));
01064 
              // training computers keep the target/weight inputs
01065         connectProcessors(calib.get(), processors, true);
01066 
01067         return calib.release();
01068 }
01069 
01070 void MVATrainer::doneTraining(Calibration::MVAComputer *trainCalibration) const
01071 {
01072         MVATrainerComputer *calib =
01073                         dynamic_cast<MVATrainerComputer*>(trainCalibration);
01074 
01075         if (!calib)
01076                 throw cms::Exception("MVATrainer")
01077                         << "Invalid training calibration passed to "
01078                            "doneTraining()" << std::endl;
01079 
01080         calib->done();
01081 }
01082 
01083 std::vector<AtomicId> MVATrainer::findFinalProcessors() const
01084 {
01085         std::set<Source*> toCheck;
01086         toCheck.insert(output);
01087 
01088         std::set<Source*> done;
01089         while(!toCheck.empty()) {
01090                 Source *source = *toCheck.begin();
01091                 toCheck.erase(toCheck.begin());
01092 
01093                 std::vector<SourceVariable*> inputs = source->inputs.get();
01094                 for(std::vector<SourceVariable*>::const_iterator iter =
01095                                 inputs.begin(); iter != inputs.end(); ++iter) {
01096                         source = (*iter)->getSource();
01097                         if (done.insert(source).second)
01098                                 toCheck.insert(source);
01099                 }
01100         }
01101 
01102         std::vector<AtomicId> result;
01103         for(std::vector<AtomicId>::const_iterator iter = processors.begin();
01104             iter != processors.end(); ++iter) {
01105                 std::map<AtomicId, Source*>::const_iterator pos =
01106                                                         sources.find(*iter);
01107                 if (pos != sources.end() && done.count(pos->second))
01108                         result.push_back(*iter);
01109         }
01110 
01111         return result;
01112 }
01113 
01114 Calibration::MVAComputer *MVATrainer::getCalibration() const
01115 {
01116         std::vector<CalibratedProcessor> processors;
01117 
01118         std::auto_ptr<Calibration::MVAComputer> calib(
01119                                                 new Calibration::MVAComputer);
01120 
01121         std::vector<AtomicId> used = findFinalProcessors();
01122         for(std::vector<AtomicId>::const_iterator iter = used.begin();
01123             iter != used.end(); iter++) {
01124                 std::map<AtomicId, Source*>::const_iterator pos =
01125                                                         sources.find(*iter);
01126                 assert(pos != sources.end());
01127                 TrainProcessor *source =
01128                                 dynamic_cast<TrainProcessor*>(pos->second);
01129                 assert(source);
01130                 if (!source->isTrained())
01131                         return 0;
01132 
01133                 Calibration::VarProcessor *proc = source->getCalibration();
01134                 if (!proc)
01135                         continue;
01136 
01137                 Calibration::ProcForeach *foreach =
01138                                 dynamic_cast<Calibration::ProcForeach*>(proc);
01139                 if (foreach) {
01140                         std::vector<AtomicId>::const_iterator begin =
01141                                 std::find(this->processors.begin(),
01142                                           this->processors.end(), *iter);
01143                         assert(this->processors.end() - begin >
01144                                (int)(foreach->nProcs + 1));
01145                         ++begin;
01146                         std::vector<AtomicId>::const_iterator end =
01147                                                 begin + foreach->nProcs;
01148                         foreach->nProcs = 0;
01149                         for(std::vector<AtomicId>::const_iterator iter2 =
01150                                         iter; iter2 != used.end(); ++iter2)
01151                                 if (std::find(begin, end, *iter2) != end)
01152                                         foreach->nProcs++;
01153                 }
01154 
01155                 processors.push_back(CalibratedProcessor(source, proc));
01156         }
01157 
01158         connectProcessors(calib.get(), processors, false);
01159 
01160         return calib.release();
01161 }
01162 
01163 void MVATrainer::findUntrainedComputers(std::vector<AtomicId> &compute,
01164                                         std::vector<AtomicId> &train) const
01165 {
01166         compute.clear();
01167         train.clear();
01168 
01169         std::set<Source*> trainedSources;
01170         trainedSources.insert(input);
01171 
01172         for(std::vector<AtomicId>::const_iterator iter =
01173                 processors.begin(); iter != processors.end(); iter++) {
01174                 std::map<AtomicId, Source*>::const_iterator pos =
01175                                                         sources.find(*iter);
01176                 assert(pos != sources.end());
01177                 TrainProcessor *proc =
01178                                 dynamic_cast<TrainProcessor*>(pos->second);
01179                 assert(proc);
01180 
01181                 bool trainedDeps = true;
01182                 std::vector<SourceVariable*> inputVars =
01183                                         proc->getInputs().get();
01184                 for(std::vector<SourceVariable*>::const_iterator iter2 =
01185                         inputVars.begin(); iter2 != inputVars.end(); iter2++) {
01186                         if (trainedSources.find((*iter2)->getSource())
01187                             == trainedSources.end()) {
01188                                 trainedDeps = false;
01189                                 break;
01190                         }
01191                 }
01192 
01193                 if (!trainedDeps)
01194                         continue;
01195 
01196                 if (proc->isTrained()) {
01197                         trainedSources.insert(proc);
01198                         compute.push_back(proc->getName());
01199                 } else
01200                         train.push_back(proc->getName());
01201         }
01202 
01203         if (doMonitoring && !output->isTrained() &&
01204             trainedSources.find(output->getInputs().get()[0]->getSource())
01205                                                 != trainedSources.end())
01206                 train.push_back(kOutputId);
01207 }
01208 
01209 Calibration::MVAComputer *MVATrainer::getTrainCalibration() const
01210 {
01211         std::vector<AtomicId> compute, train;
01212         findUntrainedComputers(compute, train);
01213 
01214         if (train.empty())
01215                 return 0;
01216 
01217         compute.push_back(0);
01218         train.push_back(0);
01219 
01220         return makeTrainCalibration(&compute.front(), &train.front());
01221 }
01222 
01223 } // namespace PhysicsTools

Generated on Tue Jun 9 17:41:32 2009 for CMSSW by doxygen 1.5.4