CMS 3D CMS Logo

/data/refman/pasoursint/CMSSW_6_1_2_SLHC4_patch1/src/PhysicsTools/MVATrainer/interface/MVATrainer.h

Go to the documentation of this file.
00001 #ifndef PhysicsTools_MVATrainer_MVATrainer_h
00002 #define PhysicsTools_MVATrainer_MVATrainer_h
00003 
00004 #include <memory>
00005 #include <string>
00006 #include <map>
00007 
00008 #include <xercesc/dom/DOM.hpp>
00009 
00010 #include "FWCore/Utilities/interface/Exception.h"
00011 
00012 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00013 #include "PhysicsTools/MVAComputer/interface/Calibration.h"
00014 #include "PhysicsTools/MVAComputer/interface/Variable.h"
00015 #include "PhysicsTools/MVAComputer/interface/MVAComputer.h"
00016 
00017 #include "PhysicsTools/MVATrainer/interface/XMLDocument.h"
00018 #include "PhysicsTools/MVATrainer/interface/SourceVariable.h"
00019 #include "PhysicsTools/MVATrainer/interface/SourceVariableSet.h"
00020 #include "PhysicsTools/MVATrainer/interface/TrainerMonitoring.h"
00021 
00022 namespace PhysicsTools {
00023 
00024 class Source;
00025 class TrainProcessor;
00026 
00027 class MVATrainer {
00028     public:
00029         MVATrainer(const std::string &fileName, bool useXSLT = false,
00030                    const char *styleSheet = 0);
00031         ~MVATrainer();
00032 
00033         inline void setAutoSave(bool autoSave) { doAutoSave = autoSave; }
00034         inline void setCleanup(bool cleanup) { doCleanup = cleanup; }
00035         inline void setMonitoring(bool monitoring) { doMonitoring = monitoring; }
00036         inline void setRandomSeed(UInt_t seed) { randomSeed = seed; }
00037         inline void setCrossValidation(double split) { crossValidation = split; }
00038 
00039         void loadState();
00040         void saveState();
00041 
00042         Calibration::MVAComputer *getTrainCalibration() const;
00043         void doneTraining(Calibration::MVAComputer *trainCalibration) const;
00044 
00045         Calibration::MVAComputer *getCalibration() const;
00046 
00047         // used by TrainProcessors
00048 
00049         std::string trainFileName(const TrainProcessor *proc,
00050                                   const std::string &ext,
00051                                   const std::string &arg = "") const;
00052 
00053         inline const std::string &getName() const { return name; }
00054 
00055         TrainerMonitoring::Module *bookMonitor(const std::string &name);
00056 
00057         // constants
00058 
00059         static const AtomicId kTargetId;
00060         static const AtomicId kWeightId;
00061 
00062     private:
00063         SourceVariable *getVariable(AtomicId source, AtomicId name) const;
00064 
00065         SourceVariable *createVariable(Source *source, AtomicId name,
00066                                        Variable::Flags flags);
00067 
00068         struct CalibratedProcessor {
00069                 CalibratedProcessor(TrainProcessor *processor,
00070                                     Calibration::VarProcessor *calib) :
00071                         processor(processor), calib(calib) {}
00072 
00073                 TrainProcessor                  *processor;
00074                 Calibration::VarProcessor       *calib;
00075         };
00076 
00077         void fillInputVars(SourceVariableSet &vars,
00078                            XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml);
00079 
00080         void fillOutputVars(SourceVariableSet &vars, Source *source,
00081                             XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml);
00082 
00083         void makeProcessor(XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *elem,
00084                            AtomicId id, const char *name);
00085 
00086         void connectProcessors(Calibration::MVAComputer *calib,
00087                                const std::vector<CalibratedProcessor> &procs,
00088                                bool withTarget) const;
00089 
00090         Calibration::MVAComputer *
00091         makeTrainCalibration(const AtomicId *compute,
00092                              const AtomicId *train) const;
00093 
00094         void
00095         findUntrainedComputers(std::vector<AtomicId> &compute,
00096                                std::vector<AtomicId> &train) const;
00097 
00098         std::vector<AtomicId> findFinalProcessors() const;
00099 
00100         std::map<AtomicId, Source*>             sources;
00101         std::vector<SourceVariable*>            variables;
00102         std::vector<AtomicId>                   processors;
00103         Source                                  *input;
00104         TrainProcessor                          *output;
00105 
00106         std::auto_ptr<TrainerMonitoring>        monitoring;
00107         std::auto_ptr<XMLDocument>              xml;
00108         std::string                             trainFileMask;
00109         std::string                             name;
00110         bool                                    doAutoSave;
00111         bool                                    doCleanup;
00112         bool                                    doMonitoring;
00113 
00114         UInt_t                                  randomSeed;
00115         double                                  crossValidation;
00116 };
00117 
00118 } // namespace PhysicsTools
00119 
00120 #endif // PhysicsTools_MVATrainer_MVATrainer_h