CMS 3D CMS Logo

/data/refman/pasoursint/CMSSW_4_4_5_patch3/src/PhysicsTools/MVATrainer/src/MVATrainer.cc

Go to the documentation of this file.
00001 #include <assert.h>
00002 #include <functional>
00003 #include <ext/functional>
00004 #include <algorithm>
00005 #include <iostream>
00006 #include <cstdarg>
00007 #include <cstring>
00008 #include <cstdio>
00009 #include <string>
00010 #include <vector>
00011 #include <memory>
00012 #include <map>
00013 #include <set>
00014 
00015 #include <xercesc/dom/DOM.hpp>
00016 
00017 #include <TRandom.h>
00018 
00019 #include "FWCore/Utilities/interface/Exception.h"
00020 #include "FWCore/ParameterSet/interface/FileInPath.h"
00021 #include "FWCore/MessageLogger/interface/MessageLogger.h"
00022 
00023 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00024 #include "PhysicsTools/MVAComputer/interface/BitSet.h"
00025 #include "PhysicsTools/MVAComputer/interface/Calibration.h"
00026 #include "PhysicsTools/MVAComputer/interface/Variable.h"
00027 
00028 #include "PhysicsTools/MVATrainer/interface/Interceptor.h"
00029 #include "PhysicsTools/MVATrainer/interface/XMLDocument.h"
00030 #include "PhysicsTools/MVATrainer/interface/XMLSimpleStr.h"
00031 #include "PhysicsTools/MVATrainer/interface/XMLUniStr.h"
00032 #include "PhysicsTools/MVATrainer/interface/Source.h"
00033 #include "PhysicsTools/MVATrainer/interface/SourceVariable.h"
00034 #include "PhysicsTools/MVATrainer/interface/TrainProcessor.h"
00035 #include "PhysicsTools/MVATrainer/interface/TrainerMonitoring.h"
00036 #include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
00037 
00038 XERCES_CPP_NAMESPACE_USE
00039 
00040 namespace PhysicsTools {
00041 
00042 namespace { // anonymous
00043         class MVATrainerComputer;
00044 
00045         class BaseInterceptor : public Calibration::Interceptor {
00046             public:
00047                 BaseInterceptor() : calib(0) {}
00048                 virtual ~BaseInterceptor() {}
00049 
00050                 inline void setCalibration(MVATrainerComputer *calib)
00051                 { this->calib = calib; }
00052 
00053                 virtual std::vector<Variable::Flags>
00054                 configure(const MVAComputer *computer, unsigned int n,
00055                           const std::vector<Variable::Flags> &flags) = 0;
00056 
00057                 virtual double
00058                 intercept(const std::vector<double> *values) const = 0;
00059 
00060                 virtual void init() {}
00061                 virtual void finish(bool save) {}
00062 
00063             protected:
00064                 MVATrainerComputer      *calib;
00065         };
00066 
00067         class InitInterceptor : public BaseInterceptor {
00068             public:
00069                 InitInterceptor() {}
00070                 virtual ~InitInterceptor() {}
00071 
00072                 virtual std::vector<Variable::Flags>
00073                 configure(const MVAComputer *computer, unsigned int n,
00074                           const std::vector<Variable::Flags> &flags);
00075 
00076                 virtual double
00077                 intercept(const std::vector<double> *values) const;
00078         };
00079 
00080         class TrainInterceptor : public BaseInterceptor {
00081             public:
00082                 TrainInterceptor(TrainProcessor *proc) : proc(proc) {}
00083                 virtual ~TrainInterceptor() {}
00084 
00085                 inline TrainProcessor *getProcessor() const { return proc; }
00086 
00087                 virtual std::vector<Variable::Flags>
00088                 configure(const MVAComputer *computer, unsigned int n,
00089                           const std::vector<Variable::Flags> &flags);
00090 
00091                 virtual double
00092                 intercept(const std::vector<double> *values) const;
00093 
00094                 virtual void init();
00095                 virtual void finish(bool save);
00096 
00097             private:
00098                 unsigned int                                    targetIdx;
00099                 unsigned int                                    weightIdx;
00100                 mutable std::vector<std::vector<double> >       tmp;
00101                 TrainProcessor                                  *const proc;
00102         };
00103 
00104         class MVATrainerComputer : public TrainMVAComputerCalibration {
00105             public:
00106                 typedef std::pair<unsigned int, BaseInterceptor*> Interceptor;
00107 
00108                 MVATrainerComputer(const std::vector<Interceptor>
00109                                                         &interceptors,
00110                                    bool autoSave, UInt_t seed, double split);
00111 
00112                 virtual ~MVATrainerComputer();
00113 
00114                 virtual std::vector<Calibration::VarProcessor*>
00115                                                         getProcessors() const;
00116                 virtual void initFlags(std::vector<Variable::Flags>
00117                                                         &flags) const;
00118 
00119                 void configured(BaseInterceptor *interceptor) const;
00120                 void next();
00121                 void done();
00122 
00123                 inline void addFlag(Variable::Flags flag)
00124                 { flags.push_back(flag); }
00125 
00126                 inline bool useForTraining() const { return splitResult; }
00127                 inline bool useForTesting() const
00128                 { return split <= 0.0 || !splitResult; }
00129 
00130                 inline bool isConfigured() const
00131                 { return nConfigured == interceptors.size(); }
00132 
00133             private:
00134                 std::vector<Interceptor>        interceptors;
00135                 std::vector<Variable::Flags>    flags;
00136                 mutable unsigned int            nConfigured;
00137                 bool                            doAutoSave;
00138                 TRandom                         random;
00139                 double                          split;
00140                 bool                            splitResult;
00141         };
00142 
00143         // useful litte helpers
00144 
00145         template<typename T>
00146         struct deleter : public std::unary_function<T*, void> {
00147                 inline void operator() (T *ptr) const { delete ptr; }
00148         };
00149 
00150         template<typename T>
00151         struct auto_cleaner {
00152                 inline ~auto_cleaner()
00153                 { std::for_each(clean.begin(), clean.end(), deleter<T>()); }
00154 
00155                 inline void add(T *ptr) { clean.push_back(ptr); }
00156                 std::vector<T*> clean;
00157         };
00158 } // anonymous namespace
00159 
// vsnprintf-style formatting into a std::string.
//
// Each attempt formats from a fresh va_copy of `va`. A va_list that has been
// passed to vsnprintf is indeterminate and must not be reused, and the retry
// path needs the arguments again. The caller still owns `va` (va_start/va_end).
//
// The buffer starts at max(128, strlen(format) + 1) bytes, so it is never
// empty. It grows to exactly the size vsnprintf reports, or doubles if an
// implementation returns a negative value on truncation (pre-C99 behaviour).
static std::string stdStringVPrintf(const char *format, std::va_list va)
{
        std::vector<char> buffer(
                std::max<std::size_t>(128, std::strlen(format) + 1));

        for(;;) {
                std::va_list args;
                va_copy(args, va);
                int n = std::vsnprintf(&buffer[0], buffer.size(), format, args);
                va_end(args);

                if (n >= 0 && static_cast<std::size_t>(n) < buffer.size())
                        return std::string(&buffer[0],
                                           static_cast<std::size_t>(n));

                if (n >= 0)
                        buffer.resize(static_cast<std::size_t>(n) + 1);
                else
                        buffer.resize(buffer.size() * 2);
        }
}
00182 
00183 static std::string stdStringPrintf(const char *format, ...)
00184 {
00185         std::va_list va;
00186         va_start(va, format);
00187         std::string result = stdStringVPrintf(format, va);
00188         va_end(va);
00189         return result;
00190 }
00191 
00192 // implementation for InitInterceptor
00193 
00194 std::vector<Variable::Flags>
00195 InitInterceptor::configure(const MVAComputer *computer, unsigned int n,
00196                            const std::vector<Variable::Flags> &flags)
00197 {
00198         calib->configured(this);
00199         return std::vector<Variable::Flags>(n, Variable::FLAG_ALL);
00200 }
00201 
00202 double
00203 InitInterceptor::intercept(const std::vector<double> *values) const
00204 {
00205         calib->next();
00206         return 0.0;
00207 }
00208 
00209 // implementation for TrainInterceptor
00210 
00211 std::vector<Variable::Flags>
00212 TrainInterceptor::configure(const MVAComputer *computer, unsigned int n,
00213                             const std::vector<Variable::Flags> &flags)
00214 {
00215         const SourceVariableSet &inputSet = 
00216                 const_cast<const TrainProcessor*>(proc)->getInputs();
00217         SourceVariable *target = inputSet.find(SourceVariableSet::kTarget);
00218         SourceVariable *weight = inputSet.find(SourceVariableSet::kWeight);
00219 
00220         std::vector<SourceVariable*> inputs = inputSet.get(true);
00221 
00222         std::vector<SourceVariable*>::const_iterator pos;
00223         pos = std::find(inputs.begin(), inputs.end(), target);
00224         assert(pos != inputs.end());
00225         targetIdx = pos - inputs.begin();
00226         pos = std::find(inputs.begin(), inputs.end(), weight);
00227         assert(pos != inputs.end());
00228         weightIdx = pos - inputs.begin();
00229 
00230         calib->configured(this);
00231 
00232         std::vector<Variable::Flags> result = flags;
00233         if (targetIdx < weightIdx) {
00234                 result.erase(result.begin() + weightIdx);
00235                 result.erase(result.begin() + targetIdx);
00236         } else {
00237                 result.erase(result.begin() + targetIdx);
00238                 result.erase(result.begin() + weightIdx);
00239         }
00240 
00241         proc->passFlags(result);
00242 
00243         result.clear();
00244         result.resize(n, proc->getDefaultFlags());
00245         result[targetIdx] = Variable::FLAG_NONE;
00246         result[weightIdx] = Variable::FLAG_OPTIONAL;
00247 
00248         if (targetIdx >= 2 || weightIdx >= 2)
00249                 tmp.resize(n - 2);
00250 
00251         return result;
00252 }
00253 
00254 void TrainInterceptor::init()
00255 {
00256         edm::LogInfo("MVATrainer")
00257                 << "TrainProcessor \"" << (const char*)proc->getName()
00258                 << "\" training iteration starting...";
00259 
00260         proc->doTrainBegin();
00261 }
00262 
00263 double
00264 TrainInterceptor::intercept(const std::vector<double> *values) const
00265 {
00266         if (values[targetIdx].size() != 1) {
00267                 if (values[targetIdx].size() == 0)
00268                         throw cms::Exception("MVATrainer")
00269                                 << "Trainer input lacks target variable."
00270                                 << std::endl;
00271                 else
00272                         throw cms::Exception("MVATrainer")
00273                                 << "Multiple targets supplied in input."
00274                                 << std::endl;
00275         }
00276         double target = values[targetIdx].front();
00277 
00278         double weight = 1.0;
00279         if (values[weightIdx].size() > 1)
00280                 throw cms::Exception("MVATrainer")
00281                         << "Multiple weights supplied in input."
00282                         << std::endl;
00283         else if (values[weightIdx].size() == 1)
00284                 weight = values[weightIdx].front();
00285 
00286         if (tmp.empty())
00287                 proc->doTrainData(values + 2, target > 0.5, weight,
00288                                   calib->useForTraining(),
00289                                   calib->useForTesting());
00290         else {
00291                 std::vector<std::vector<double> >::iterator pos = tmp.begin();
00292                 for(unsigned int i = 0; pos != tmp.end(); i++)
00293                         if (i != targetIdx && i != weightIdx)
00294                                 *pos++ = values[i];
00295 
00296                 proc->doTrainData(&tmp.front(), target > 0.5, weight,
00297                                   calib->useForTraining(),
00298                                   calib->useForTesting());
00299         }
00300 
00301         return target;
00302 }
00303 
00304 void TrainInterceptor::finish(bool save)
00305 {
00306         proc->doTrainEnd();
00307 
00308         edm::LogInfo("MVATrainer")
00309                 << "... processor \"" << (const char*)proc->getName()
00310                 << "\" training iteration done.";
00311 
00312         if (proc->isTrained()) {
00313                 edm::LogInfo("MVATrainer")
00314                         << "* Completed training of \""
00315                         << (const char*)proc->getName() << "\".";
00316 
00317                 if (save)
00318                         proc->save();
00319         }
00320 }
00321 
00322 // implementation for MVATrainerComputer
00323 
00324 MVATrainerComputer::MVATrainerComputer(const std::vector<Interceptor>
00325                                                 &interceptors, bool autoSave,
00326                                        UInt_t seed, double split) :
00327         interceptors(interceptors), nConfigured(0), doAutoSave(autoSave),
00328         random(seed), split(split)
00329 {
00330         for(std::vector<Interceptor>::const_iterator iter =
00331                 interceptors.begin(); iter != interceptors.end(); ++iter)
00332                 iter->second->setCalibration(this);
00333 }
00334 
00335 MVATrainerComputer::~MVATrainerComputer()
00336 {
00337         done();
00338         
00339         for(std::vector<Interceptor>::const_iterator iter =
00340                 interceptors.begin(); iter != interceptors.end(); ++iter)
00341                 delete iter->second;
00342 }
00343 
00344 std::vector<Calibration::VarProcessor*>
00345 MVATrainerComputer::getProcessors() const
00346 {
00347         std::vector<Calibration::VarProcessor*> processors =
00348                         Calibration::MVAComputer::getProcessors();
00349 
00350         for(std::vector<Interceptor>::const_iterator iter =
00351                 interceptors.begin(); iter != interceptors.end(); ++iter)
00352 
00353                 processors.insert(processors.begin() + iter->first,
00354                                   1, iter->second);
00355 
00356         return processors;
00357 }
00358 
00359 void MVATrainerComputer::initFlags(std::vector<Variable::Flags> &flags) const
00360 {
00361         assert(flags.size() == this->flags.size());
00362         flags = this->flags;
00363 }
00364 
00365 void MVATrainerComputer::configured(BaseInterceptor *interceptor) const
00366 {
00367         nConfigured++;
00368         if (isConfigured())
00369                 for(std::vector<Interceptor>::const_iterator iter =
00370                                                 interceptors.begin();
00371                     iter != interceptors.end(); ++iter)
00372                         iter->second->init();
00373 }
00374 
00375 void MVATrainerComputer::next()
00376 {
00377         splitResult = random.Uniform(1.0) >= split;
00378 }
00379 
00380 void MVATrainerComputer::done()
00381 {
00382         if (isConfigured()) {
00383                 for(std::vector<Interceptor>::const_iterator iter =
00384                                                 interceptors.begin();
00385                     iter != interceptors.end(); ++iter)
00386                         iter->second->finish(doAutoSave);
00387                 nConfigured = 0;
00388         }
00389 }
00390 
00391 // implementation for MVATrainer
00392 
00393 const AtomicId MVATrainer::kTargetId("__TARGET__");
00394 const AtomicId MVATrainer::kWeightId("__WEIGHT__");
00395 
00396 static const AtomicId kOutputId("__OUTPUT__");
00397 
00398 static bool isMagic(AtomicId id)
00399 {
00400         return id == MVATrainer::kTargetId ||
00401                id == MVATrainer::kWeightId ||
00402                id == kOutputId;
00403 }
00404 
// Quotes `in` for a POSIX shell: the text is wrapped in single quotes, and
// each embedded single quote becomes '\'' (close quote, escaped quote, reopen).
static std::string escape(const std::string &in)
{
        std::string quoted;
        quoted.reserve(in.size() + 2);
        quoted.push_back('\'');

        for(std::string::size_type i = 0; i < in.size(); ++i) {
                if (in[i] == '\'')
                        quoted.append("'\\''");
                else
                        quoted.push_back(in[i]);
        }

        quoted.push_back('\'');
        return quoted;
}
00421 
00422 MVATrainer::MVATrainer(const std::string &fileName, bool useXSLT,
00423         const char *styleSheet) :
00424         input(0), output(0), name("MVATrainer"),
00425         doAutoSave(true), doCleanup(false),
00426         doMonitoring(false), randomSeed(65539), crossValidation(0.0)
00427 {
00428         if (useXSLT) {
00429                 std::string sheet;
00430                 if (!styleSheet)
00431                         sheet = edm::FileInPath(
00432                                 "PhysicsTools/MVATrainer/data/MVATrainer.xsl")
00433                                 .fullPath();
00434                 else
00435                         sheet = styleSheet;
00436 
00437                 std::string preproc = "xsltproc --xinclude " + escape(sheet) +
00438                                       " " + escape(fileName);
00439                 xml.reset(new XMLDocument(fileName, preproc));
00440         } else
00441                 xml.reset(new XMLDocument(fileName));
00442 
00443         DOMNode *node = xml->getRootNode();
00444 
00445         if (std::strcmp(XMLSimpleStr(node->getNodeName()), "MVATrainer") != 0)
00446                 throw cms::Exception("MVATrainer")
00447                         << "Invalid XML root node." << std::endl;
00448 
00449         enum State {
00450                 STATE_GENERAL,
00451                 STATE_FIRST,
00452                 STATE_MIDDLE,
00453                 STATE_LAST
00454         } state = STATE_GENERAL;
00455 
00456         for(node = node->getFirstChild();
00457             node; node = node->getNextSibling()) {
00458                 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00459                         continue;
00460 
00461                 std::string name = XMLSimpleStr(node->getNodeName());
00462                 DOMElement *elem = static_cast<DOMElement*>(node);
00463 
00464                 switch(state) {
00465                     case STATE_GENERAL: {
00466                         if (name != "general")
00467                                 throw cms::Exception("MVATrainer")
00468                                         << "Expected general config as first "
00469                                            "tag." << std::endl;
00470 
00471                         for(DOMNode *subNode = elem->getFirstChild();
00472                             subNode; subNode = subNode->getNextSibling()) {
00473                                 if (subNode->getNodeType() !=
00474                                     DOMNode::ELEMENT_NODE)
00475                                         continue;
00476 
00477                                 if (std::strcmp(XMLSimpleStr(
00478                                         subNode->getNodeName()), "option") != 0)
00479                                         throw cms::Exception("MVATrainer")
00480                                                 << "Expected option tag."
00481                                                 << std::endl;
00482 
00483                                 elem = static_cast<DOMElement*>(subNode);
00484                                 name = XMLDocument::readAttribute<std::string>(
00485                                                                 elem, "name");
00486                                 std::string content = XMLSimpleStr(
00487                                                 elem->getTextContent());
00488 
00489                                 if (name == "id")
00490                                         this->name = content;
00491                                 else if (name == "trainfiles")
00492                                         trainFileMask = content;
00493                                 else
00494                                         throw cms::Exception("MVATrainer")
00495                                                 << "Unknown option \""
00496                                                 << name << "\"." << std::endl;
00497                         }
00498 
00499                         state = STATE_FIRST;
00500                     }   break;
00501                     case STATE_FIRST: {
00502                         if (name != "input")
00503                                 throw cms::Exception("MVATrainer")
00504                                         << "Expected input config as second "
00505                                            "tag." << std::endl;
00506 
00507                         AtomicId id = XMLDocument::readAttribute<std::string>(
00508                                                                 elem, "id");
00509                         input = new Source(id, true);
00510                         input->getOutputs().append(
00511                                 createVariable(input, kTargetId,
00512                                                Variable::FLAG_NONE),
00513                                 SourceVariableSet::kTarget);
00514                         input->getOutputs().append(
00515                                 createVariable(input, kWeightId,
00516                                                Variable::FLAG_OPTIONAL),
00517                                 SourceVariableSet::kWeight);
00518                         sources.insert(std::make_pair(id, input));
00519                         fillOutputVars(input->getOutputs(), input, elem);
00520 
00521                         state = STATE_MIDDLE;
00522                     }   break;
00523                     case STATE_MIDDLE: {
00524                         if (name == "output") {
00525                                 AtomicId zero;
00526                                 output = new TrainProcessor("output",
00527                                                             &zero, this);
00528                                 fillInputVars(output->getInputs(), elem);
00529                                 state = STATE_LAST;
00530                                 continue;
00531                         } else if (name != "processor")
00532                                 throw cms::Exception("MVATrainer")
00533                                         << "Unexpected tag after input "
00534                                            "config." << std::endl;
00535 
00536                         AtomicId id = XMLDocument::readAttribute<std::string>(
00537                                                                 elem, "id");
00538                         std::string name =
00539                                 XMLDocument::readAttribute<std::string>(
00540                                         elem, "name");
00541 
00542                         makeProcessor(elem, id, name.c_str());
00543                     }   break;
00544                     case STATE_LAST:
00545                         throw cms::Exception("MVATrainer")
00546                                 << "Unexpected tag found after output."
00547                                 << std::endl;
00548                         break;
00549                 }
00550         }
00551 
00552         if (state == STATE_FIRST)
00553                 throw cms::Exception("MVATrainer")
00554                         << "Expected input variable config." << std::endl;
00555         else if (state == STATE_MIDDLE)
00556                 throw cms::Exception("MVATrainer")
00557                         << "Expected output variable config." << std::endl;
00558 
00559         if (trainFileMask.empty())
00560                 trainFileMask = this->name + "_%s%s.%s";
00561 }
00562 
00563 MVATrainer::~MVATrainer()
00564 {
00565         if (monitoring.get())
00566                 monitoring->write();
00567 
00568         for(std::map<AtomicId, Source*>::const_iterator iter = sources.begin();
00569             iter != sources.end(); iter++) {
00570                 TrainProcessor *proc =
00571                                 dynamic_cast<TrainProcessor*>(iter->second);
00572 
00573                 if (proc && doCleanup)
00574                         proc->cleanup();
00575 
00576                 delete iter->second;
00577         }
00578         delete output;
00579         std::for_each(variables.begin(), variables.end(),
00580                       deleter<SourceVariable>());
00581 }
00582 
00583 void MVATrainer::loadState()
00584 {
00585         for(std::vector<AtomicId>::const_iterator iter =
00586                                                 this->processors.begin();
00587             iter != this->processors.end(); iter++) {
00588                 std::map<AtomicId, Source*>::const_iterator pos =
00589                                                         sources.find(*iter);
00590                 assert(pos != sources.end());
00591                 TrainProcessor *source =
00592                                 dynamic_cast<TrainProcessor*>(pos->second);
00593                 assert(source);
00594 
00595                 if (source->load())
00596                         edm::LogInfo("MVATrainer")
00597                                 << source->getId() << " configuration for \""
00598                                 << (const char*)source->getName()
00599                                 << "\" loaded from file.";
00600         }
00601 }
00602 
00603 void MVATrainer::saveState()
00604 {
00605         doCleanup = false;
00606 
00607         for(std::vector<AtomicId>::const_iterator iter =
00608                                                 this->processors.begin();
00609             iter != this->processors.end(); iter++) {
00610                 std::map<AtomicId, Source*>::const_iterator pos =
00611                                                         sources.find(*iter);
00612                 assert(pos != sources.end());
00613                 TrainProcessor *source =
00614                                 dynamic_cast<TrainProcessor*>(pos->second);
00615                 assert(source);
00616 
00617                 if (source->isTrained())
00618                         source->save();
00619         }
00620 }
00621 
00622 void MVATrainer::makeProcessor(DOMElement *elem, AtomicId id, const char *name)
00623 {
00624         DOMElement *xmlInput = 0;
00625         DOMElement *xmlConfig = 0;
00626         DOMElement *xmlOutput = 0;
00627         DOMElement *xmlData = 0;
00628 
00629         static struct NameExpect {
00630                 const char      *tag;
00631                 bool            mandatory;
00632                 DOMElement      **elem;
00633         } const expect[] = {
00634                 { "input",      true,   &xmlInput },
00635                 { "config",     true,   &xmlConfig },
00636                 { "output",     true,   &xmlOutput },
00637                 { "data",       false,  &xmlData },
00638                 { 0, }
00639         };
00640 
00641         const NameExpect *cur = expect;
00642         for(DOMNode *node = elem->getFirstChild();
00643             node; node = node->getNextSibling()) {
00644                 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00645                         continue;
00646 
00647                 std::string tag = XMLSimpleStr(node->getNodeName());
00648                 DOMElement *elem = static_cast<DOMElement*>(node);
00649 
00650                 if (!cur->tag)
00651                         throw cms::Exception("MVATrainer")
00652                                 << "Superfluous tag " << tag
00653                                 << "encountered in processor." << std::endl;
00654                 else if (tag != cur->tag && cur->mandatory)
00655                         throw cms::Exception("MVATrainer")
00656                                 << "Expected tag " << cur->tag << ", got "
00657                                 << tag << " instead in processor."
00658                                 << std::endl;
00659                 else if (tag != cur->tag) {
00660                         cur++;
00661                         continue;
00662                 }
00663                 *(cur++)->elem = elem;
00664         }
00665 
00666         while(cur->tag && !cur->mandatory)
00667                 cur++;
00668         if (cur->tag)
00669                 throw cms::Exception("MVATrainer")
00670                         << "Unexpected end of processor configuration, "
00671                         << "expected tag " << cur->tag << "." << std::endl;
00672 
00673         std::auto_ptr<TrainProcessor> proc(
00674                                 TrainProcessor::create(name, &id, this));
00675         if (!proc.get())
00676                 throw cms::Exception("MVATrainer")
00677                         << "Variable processor trainer " << name
00678                         << " could not be instantiated. Most likely because"
00679                            " the trainer plugin for \"" << name << "\""
00680                            " does not exist." << std::endl;
00681 
00682         if (sources.find(id) != sources.end())
00683                 throw cms::Exception("MVATrainer")
00684                         << "Duplicate variable processor id "
00685                         << (const char*)id << "."
00686                         << std::endl;
00687 
00688         fillInputVars(proc->getInputs(), xmlInput);
00689         fillOutputVars(proc->getOutputs(), proc.get(), xmlOutput);
00690 
00691         edm::LogInfo("MVATrainer")
00692                 << "Configuring " << (const char*)proc->getId()
00693                 << " \"" << (const char*)proc->getName() << "\".";
00694         proc->configure(xmlConfig);
00695 
00696         sources.insert(std::make_pair(id, proc.release()));
00697         processors.push_back(id);
00698 }
00699 
00700 std::string MVATrainer::trainFileName(const TrainProcessor *proc,
00701                                       const std::string &ext,
00702                                       const std::string &arg) const
00703 {
00704         std::string arg_ = arg.size() > 0 ? ("_" + arg) : "";
00705         return stdStringPrintf(trainFileMask.c_str(),
00706                                (const char*)proc->getName(),
00707                                arg_.c_str(), ext.c_str());
00708 }
00709 
00710 TrainerMonitoring::Module *MVATrainer::bookMonitor(const std::string &name)
00711 {
00712         if (!doMonitoring)
00713                 return 0;
00714 
00715         if (!monitoring.get()) {
00716                 std::string fileName = 
00717                         stdStringPrintf(trainFileMask.c_str(),
00718                                         "monitoring", "", "root");
00719                 monitoring.reset(new TrainerMonitoring(fileName));
00720         }
00721 
00722         return monitoring->book(name);
00723 }
00724 
00725 SourceVariable *MVATrainer::getVariable(AtomicId source, AtomicId name) const
00726 {
00727         std::map<AtomicId, Source*>::const_iterator pos = sources.find(source);
00728         if (pos == sources.end())
00729                 return 0;
00730 
00731         return pos->second->getOutput(name);
00732 }
00733 
00734 SourceVariable *MVATrainer::createVariable(Source *source, AtomicId name,
00735                                            Variable::Flags flags)
00736 {
00737         SourceVariable *var = getVariable(source->getName(), name);
00738         if (var)
00739                 return 0;
00740 
00741         var = new SourceVariable(source, name, flags);
00742         variables.push_back(var);
00743         return var;
00744 }
00745 
00746 void MVATrainer::fillInputVars(SourceVariableSet &vars,
00747                                XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
00748 {
00749         std::vector<SourceVariable*> tmp;
00750         SourceVariable *target = 0;
00751         SourceVariable *weight = 0;
00752 
00753         for(DOMNode *node = xml->getFirstChild(); node;
00754             node = node->getNextSibling()) {
00755                 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00756                         continue;
00757 
00758                 if (std::strcmp(XMLSimpleStr(node->getNodeName()), "var") != 0)
00759                         throw cms::Exception("MVATrainer")
00760                                 << "Invalid input variable node." << std::endl;
00761 
00762                 DOMElement *elem = static_cast<DOMElement*>(node);
00763 
00764                 AtomicId source = XMLDocument::readAttribute<std::string>(
00765                                                         elem, "source");
00766                 AtomicId name = XMLDocument::readAttribute<std::string>(
00767                                                         elem, "name");
00768 
00769                 SourceVariable *var = getVariable(source, name);
00770                 if (!var)
00771                         throw cms::Exception("MVATrainer")
00772                                 << "Input variable " << (const char*)source
00773                                 << ":" << (const char*)name
00774                                 << " not found." << std::endl;
00775 
00776                 if (XMLDocument::readAttribute<bool>(elem, "target", false)) {
00777                         if (target)
00778                                 throw cms::Exception("MVATrainer")
00779                                         << "Target variable defined twice"
00780                                         << std::endl;
00781                         target = var;
00782                 }
00783                 if (XMLDocument::readAttribute<bool>(elem, "weight", false)) {
00784                         if (weight)
00785                                 throw cms::Exception("MVATrainer")
00786                                         << "Weight variable defined twice"
00787                                         << std::endl;
00788                         weight = var;
00789                 }
00790 
00791                 tmp.push_back(var);
00792         }
00793 
00794         if (!weight) {
00795                 weight = input->getOutput(kWeightId);
00796                 assert(weight);
00797                 tmp.insert(tmp.begin() +
00798                                 (target == input->getOutput(kTargetId)),
00799                            1, weight);
00800         }
00801         if (!target) {
00802                 target = input->getOutput(kTargetId);
00803                 assert(target);
00804                 tmp.insert(tmp.begin(), 1, target);
00805         }
00806 
00807         unsigned int n = 0;
00808         for(std::vector<SourceVariable*>::const_iterator iter = variables.begin();
00809             iter != variables.end(); iter++) {
00810                 std::vector<SourceVariable*>::const_iterator pos =
00811                         std::find(tmp.begin(), tmp.end(), *iter);
00812                 if (pos == tmp.end())
00813                         continue;
00814 
00815                 SourceVariableSet::Magic magic;
00816                 if (*iter == target)
00817                         magic = SourceVariableSet::kTarget;
00818                 else if (*iter == weight)
00819                         magic = SourceVariableSet::kWeight;
00820                 else
00821                         magic = SourceVariableSet::kRegular;
00822 
00823                 if (vars.append(*iter, magic, pos - tmp.begin())) {
00824                         AtomicId source = (*iter)->getSource()->getName();
00825                         AtomicId name = (*iter)->getName();
00826                         throw cms::Exception("MVATrainer")
00827                                 << "Input variable " << (const char*)source
00828                                 << ":" << (const char*)name
00829                                 << " defined twice." << std::endl;
00830                 }
00831 
00832                 n++;
00833         }
00834 
00835         assert(tmp.size() == n);
00836 }
00837 
00838 void MVATrainer::fillOutputVars(SourceVariableSet &vars, Source *source,
00839                                 XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
00840 {
00841         for(DOMNode *node = xml->getFirstChild(); node;
00842             node = node->getNextSibling()) {
00843                 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00844                         continue;
00845 
00846                 if (std::strcmp(XMLSimpleStr(node->getNodeName()), "var") != 0)
00847                         throw cms::Exception("MVATrainer")
00848                                 << "Invalid output variable node."
00849                                 << std::endl;
00850 
00851                 DOMElement *elem = static_cast<DOMElement*>(node);
00852 
00853                 AtomicId name = XMLDocument::readAttribute<std::string>(
00854                                                         elem, "name");
00855                 if (!name)
00856                         throw cms::Exception("MVATrainer")
00857                                 << "Output variable tag missing name."
00858                                 << std::endl;
00859                 if (isMagic(name))
00860                         throw cms::Exception("MVATrainer")
00861                                 << "Cannot use magic variable names in output."
00862                                 << std::endl;
00863 
00864                 Variable::Flags flags = Variable::FLAG_NONE;
00865 
00866                 if (XMLDocument::readAttribute<bool>(elem, "optional", true))
00867                         flags = (PhysicsTools::Variable::Flags)
00868                                 (flags | Variable::FLAG_OPTIONAL);
00869 
00870                 if (XMLDocument::readAttribute<bool>(elem, "multiple", true))
00871                         flags = (PhysicsTools::Variable::Flags)
00872                                 (flags | Variable::FLAG_MULTIPLE);
00873 
00874                 SourceVariable *var = createVariable(source, name, flags);
00875                 if (!var || vars.append(var))
00876                         throw cms::Exception("MVATrainer")
00877                                 << "Output variable "
00878                                 << (const char*)source->getName()
00879                                 << ":" << (const char*)name
00880                                 << " defined twice." << std::endl;
00881         }
00882 }
00883 
00884 void
00885 MVATrainer::connectProcessors(Calibration::MVAComputer *calib,
00886                               const std::vector<CalibratedProcessor> &procs,
00887                               bool withTarget) const
00888 {
00889         std::map<SourceVariable*, unsigned int> vars;
00890         unsigned int size = 0;
00891 
00892         MVATrainerComputer *trainCalib =
00893                         dynamic_cast<MVATrainerComputer*>(calib);
00894 
00895         for(unsigned int i = 0;
00896             i < input->getOutputs().size(true); i++) {
00897                 if (i < 2 && !withTarget)
00898                         continue;
00899 
00900                 SourceVariable *var = variables[i];
00901                 vars[var] = size++;
00902 
00903                 Calibration::Variable calibVar;
00904                 calibVar.name = (const char*)var->getName();
00905                 calib->inputSet.push_back(calibVar);
00906                 if (trainCalib)
00907                         trainCalib->addFlag(var->getFlags());
00908         }
00909 
00910         for(std::vector<CalibratedProcessor>::const_iterator iter =
00911                                 procs.begin(); iter != procs.end(); iter++) {
00912                 bool isInterceptor = dynamic_cast<BaseInterceptor*>(
00913                                                         iter->calib) != 0;
00914 
00915                 BitSet inputSet(size);
00916 
00917                 unsigned int last = 0;
00918                 std::vector<SourceVariable*> inoutVars;
00919                 if (iter->processor)
00920                         inoutVars = iter->processor->getInputs().get(
00921                                                                 isInterceptor);
00922                 for(std::vector<SourceVariable*>::const_iterator iter2 =
00923                         inoutVars.begin(); iter2 != inoutVars.end(); iter2++) {
00924                         std::map<SourceVariable*,
00925                                  unsigned int>::const_iterator pos =
00926                                                         vars.find(*iter2);
00927 
00928                         assert(pos != vars.end());
00929 
00930                         if (pos->second < last)
00931                                 throw cms::Exception("MVATrainer")
00932                                         << "Input variables not declared "
00933                                            "in order of appearance in \""
00934                                         << (const char*)iter->processor->getName()
00935                                         << "\"." << std::endl;
00936 
00937                         inputSet[last = pos->second] = true;
00938                 }
00939 
00940                 assert(!isInterceptor || withTarget);
00941 
00942                 iter->calib->inputVars = Calibration::convert(inputSet);
00943 
00944                 calib->output = size;
00945 
00946                 if (isInterceptor) {
00947                         size++;
00948                         continue;
00949                 }
00950 
00951                 calib->addProcessor(iter->calib);
00952 
00953                 inoutVars = iter->processor->getOutputs().get();
00954                 for(std::vector<SourceVariable*>::const_iterator iter =
00955                         inoutVars.begin(); iter != inoutVars.end(); iter++) {
00956 
00957                         vars[*iter] = size++;
00958                 }
00959         }
00960 
00961         if (output->getInputs().size() != 1)
00962                 throw cms::Exception("MVATrainer")
00963                         << "Exactly one output variable has to be specified."
00964                         << std::endl;
00965 
00966         SourceVariable *outVar = output->getInputs().get()[0];
00967         std::map<SourceVariable*, unsigned int>::const_iterator pos =
00968                                                         vars.find(outVar);
00969         if (pos != vars.end())
00970                 calib->output = pos->second;
00971 }
00972 
00973 Calibration::MVAComputer *
00974 MVATrainer::makeTrainCalibration(const AtomicId *compute,
00975                                  const AtomicId *train) const
00976 {
00977         std::map<AtomicId, TrainInterceptor*> interceptors;
00978         std::vector<MVATrainerComputer::Interceptor> baseInterceptors;
00979         std::vector<CalibratedProcessor> processors;
00980 
00981         BaseInterceptor *interceptor = new InitInterceptor;
00982         baseInterceptors.push_back(std::make_pair(0, interceptor));
00983         processors.push_back(CalibratedProcessor(0, interceptor));
00984 
00985         for(const AtomicId *iter = train; *iter; iter++) {
00986                 TrainProcessor *source;
00987                 if (*iter == kOutputId)
00988                         source = output;
00989                 else {
00990                         std::map<AtomicId, Source*>::const_iterator pos =
00991                                                         sources.find(*iter);
00992                         assert(pos != sources.end());
00993                         source = dynamic_cast<TrainProcessor*>(pos->second);
00994                 }
00995                 assert(source);
00996 
00997                 interceptors[*iter] = new TrainInterceptor(source);
00998         }
00999 
01000         auto_cleaner<Calibration::VarProcessor> autoClean;
01001 
01002         std::set<AtomicId> done;
01003         for(const AtomicId *iter = compute; *iter; iter++) {
01004                 if (done.erase(*iter))
01005                         continue;
01006 
01007                 std::map<AtomicId, Source*>::const_iterator pos =
01008                                                         sources.find(*iter);
01009                 assert(pos != sources.end());
01010                 TrainProcessor *source =
01011                                 dynamic_cast<TrainProcessor*>(pos->second);
01012                 assert(source);
01013                 assert(source->isTrained());
01014 
01015                 Calibration::VarProcessor *proc = source->getCalibration();
01016                 if (!proc)
01017                         continue;
01018 
01019                 autoClean.add(proc);
01020                 processors.push_back(CalibratedProcessor(source, proc));
01021 
01022                 Calibration::ProcForeach *looper =
01023                                 dynamic_cast<Calibration::ProcForeach*>(proc);
01024                 if (looper) {
01025                         std::vector<AtomicId>::const_iterator pos2 =
01026                                 std::find(this->processors.begin(),
01027                                           this->processors.end(), *iter);
01028                         assert(pos2 != this->processors.end());
01029                         ++pos2;
01030                         unsigned int n = 0;
01031                         for(int i = 0; i < (int)looper->nProcs; ++i, ++pos2) {
01032                                 assert(pos2 != this->processors.end());
01033 
01034                                 const AtomicId *iter2 = compute;
01035                                 while(*iter2) {
01036                                         if (*iter2 == *pos2)
01037                                                 break;
01038                                         iter2++;
01039                                 }
01040 
01041                                 if (*iter2) {
01042                                         n++;
01043                                         done.insert(*iter2);
01044                                         pos = sources.find(*iter2);
01045                                         assert(pos != sources.end());
01046                                         TrainProcessor *source =
01047                                                 dynamic_cast<TrainProcessor*>(
01048                                                                 pos->second);
01049                                         assert(source);
01050                                         assert(source->isTrained());
01051 
01052                                         proc = source->getCalibration();
01053                                         if (proc) {
01054                                                 autoClean.add(proc);
01055                                                 processors.push_back(
01056                                                         CalibratedProcessor(
01057                                                                 source, proc));
01058                                         }
01059                                 }
01060 
01061                                 std::map<AtomicId, TrainInterceptor*>::iterator
01062                                                 pos3 = interceptors.find(*pos2);
01063                                 if (pos3 != interceptors.end()) {
01064                                         n++;
01065                                         baseInterceptors.push_back(
01066                                                 std::make_pair(processors.size(),
01067                                                                pos3->second));
01068                                         processors.push_back(
01069                                                 CalibratedProcessor(
01070                                                         pos3->second->getProcessor(),
01071                                                         pos3->second));
01072                                         interceptors.erase(pos3);
01073                                 }
01074                         }
01075 
01076                         looper->nProcs = n;
01077                         if (!n) {
01078                                 baseInterceptors.pop_back();
01079                                 processors.pop_back();
01080                         }
01081                 }
01082         }
01083 
01084         for(std::map<AtomicId, TrainInterceptor*>::const_iterator iter =
01085                 interceptors.begin(); iter != interceptors.end(); ++iter) {
01086 
01087                 TrainProcessor *proc = iter->second->getProcessor();
01088                 baseInterceptors.push_back(std::make_pair(processors.size(),
01089                                                           iter->second));
01090                 processors.push_back(CalibratedProcessor(proc, iter->second));
01091         }
01092 
01093         std::auto_ptr<Calibration::MVAComputer> calib(
01094                 new MVATrainerComputer(baseInterceptors, doAutoSave,
01095                                        randomSeed, crossValidation));
01096 
01097         connectProcessors(calib.get(), processors, true);
01098 
01099         return calib.release();
01100 }
01101 
01102 void MVATrainer::doneTraining(Calibration::MVAComputer *trainCalibration) const
01103 {
01104         MVATrainerComputer *calib =
01105                         dynamic_cast<MVATrainerComputer*>(trainCalibration);
01106 
01107         if (!calib)
01108                 throw cms::Exception("MVATrainer")
01109                         << "Invalid training calibration passed to "
01110                            "doneTraining()" << std::endl;
01111 
01112         calib->done();
01113 }
01114 
01115 std::vector<AtomicId> MVATrainer::findFinalProcessors() const
01116 {
01117         std::set<Source*> toCheck;
01118         toCheck.insert(output);
01119 
01120         std::set<Source*> done;
01121         while(!toCheck.empty()) {
01122                 Source *source = *toCheck.begin();
01123                 toCheck.erase(toCheck.begin());
01124 
01125                 std::vector<SourceVariable*> inputs = source->inputs.get();
01126                 for(std::vector<SourceVariable*>::const_iterator iter =
01127                                 inputs.begin(); iter != inputs.end(); ++iter) {
01128                         source = (*iter)->getSource();
01129                         if (done.insert(source).second)
01130                                 toCheck.insert(source);
01131                 }
01132         }
01133 
01134         std::vector<AtomicId> result;
01135         for(std::vector<AtomicId>::const_iterator iter = processors.begin();
01136             iter != processors.end(); ++iter) {
01137                 std::map<AtomicId, Source*>::const_iterator pos =
01138                                                         sources.find(*iter);
01139                 if (pos != sources.end() && done.count(pos->second))
01140                         result.push_back(*iter);
01141         }
01142 
01143         return result;
01144 }
01145 
01146 Calibration::MVAComputer *MVATrainer::getCalibration() const
01147 {
01148         std::vector<CalibratedProcessor> processors;
01149 
01150         std::auto_ptr<Calibration::MVAComputer> calib(
01151                                                 new Calibration::MVAComputer);
01152 
01153         std::vector<AtomicId> used = findFinalProcessors();
01154         for(std::vector<AtomicId>::const_iterator iter = used.begin();
01155             iter != used.end(); iter++) {
01156                 std::map<AtomicId, Source*>::const_iterator pos =
01157                                                         sources.find(*iter);
01158                 assert(pos != sources.end());
01159                 TrainProcessor *source =
01160                                 dynamic_cast<TrainProcessor*>(pos->second);
01161                 assert(source);
01162                 if (!source->isTrained())
01163                         return 0;
01164 
01165                 Calibration::VarProcessor *proc = source->getCalibration();
01166                 if (!proc)
01167                         continue;
01168 
01169                 Calibration::ProcForeach *foreach =
01170                                 dynamic_cast<Calibration::ProcForeach*>(proc);
01171                 if (foreach) {
01172                         std::vector<AtomicId>::const_iterator begin =
01173                                 std::find(this->processors.begin(),
01174                                           this->processors.end(), *iter);
01175                         assert(this->processors.end() - begin >
01176                                (int)(foreach->nProcs + 1));
01177                         ++begin;
01178                         std::vector<AtomicId>::const_iterator end =
01179                                                 begin + foreach->nProcs;
01180                         foreach->nProcs = 0;
01181                         for(std::vector<AtomicId>::const_iterator iter2 =
01182                                         iter; iter2 != used.end(); ++iter2)
01183                                 if (std::find(begin, end, *iter2) != end)
01184                                         foreach->nProcs++;
01185                 }
01186 
01187                 processors.push_back(CalibratedProcessor(source, proc));
01188         }
01189 
01190         connectProcessors(calib.get(), processors, false);
01191 
01192         return calib.release();
01193 }
01194 
01195 void MVATrainer::findUntrainedComputers(std::vector<AtomicId> &compute,
01196                                         std::vector<AtomicId> &train) const
01197 {
01198         compute.clear();
01199         train.clear();
01200 
01201         std::set<Source*> trainedSources;
01202         trainedSources.insert(input);
01203 
01204         for(std::vector<AtomicId>::const_iterator iter =
01205                 processors.begin(); iter != processors.end(); iter++) {
01206                 std::map<AtomicId, Source*>::const_iterator pos =
01207                                                         sources.find(*iter);
01208                 assert(pos != sources.end());
01209                 TrainProcessor *proc =
01210                                 dynamic_cast<TrainProcessor*>(pos->second);
01211                 assert(proc);
01212 
01213                 bool trainedDeps = true;
01214                 std::vector<SourceVariable*> inputVars =
01215                                         proc->getInputs().get();
01216                 for(std::vector<SourceVariable*>::const_iterator iter2 =
01217                         inputVars.begin(); iter2 != inputVars.end(); iter2++) {
01218                         if (trainedSources.find((*iter2)->getSource())
01219                             == trainedSources.end()) {
01220                                 trainedDeps = false;
01221                                 break;
01222                         }
01223                 }
01224 
01225                 if (!trainedDeps)
01226                         continue;
01227 
01228                 if (proc->isTrained()) {
01229                         trainedSources.insert(proc);
01230                         compute.push_back(proc->getName());
01231                 } else
01232                         train.push_back(proc->getName());
01233         }
01234 
01235         if (doMonitoring && !output->isTrained() &&
01236             trainedSources.find(output->getInputs().get()[0]->getSource())
01237                                                 != trainedSources.end())
01238                 train.push_back(kOutputId);
01239 }
01240 
01241 Calibration::MVAComputer *MVATrainer::getTrainCalibration() const
01242 {
01243         std::vector<AtomicId> compute, train;
01244         findUntrainedComputers(compute, train);
01245 
01246         if (train.empty())
01247                 return 0;
01248 
01249         compute.push_back(0);
01250         train.push_back(0);
01251 
01252         return makeTrainCalibration(&compute.front(), &train.front());
01253 }
01254 
01255 } // namespace PhysicsTools