#ifndef PhysicsTools_MVATrainer_MVATrainer_h
#define PhysicsTools_MVATrainer_MVATrainer_h

#include <map>
#include <memory>
#include <string>
#include <vector>

#include <xercesc/dom/DOM.hpp>

#include "FWCore/Utilities/interface/Exception.h"

#include "PhysicsTools/MVAComputer/interface/AtomicId.h"
#include "PhysicsTools/MVAComputer/interface/Calibration.h"
#include "PhysicsTools/MVAComputer/interface/Variable.h"
#include "PhysicsTools/MVAComputer/interface/MVAComputer.h"

#include "PhysicsTools/MVATrainer/interface/XMLDocument.h"
#include "PhysicsTools/MVATrainer/interface/SourceVariable.h"
#include "PhysicsTools/MVATrainer/interface/SourceVariableSet.h"
#include "PhysicsTools/MVATrainer/interface/TrainerMonitoring.h"

00022 namespace PhysicsTools {
00023
00024 class Source;
00025 class TrainProcessor;
00026
00027 class MVATrainer {
00028 public:
00029 MVATrainer(const std::string &fileName, bool useXSLT = false,
00030 const char *styleSheet = 0);
00031 ~MVATrainer();
00032
00033 inline void setAutoSave(bool autoSave) { doAutoSave = autoSave; }
00034 inline void setCleanup(bool cleanup) { doCleanup = cleanup; }
00035 inline void setMonitoring(bool monitoring) { doMonitoring = monitoring; }
00036 inline void setRandomSeed(UInt_t seed) { randomSeed = seed; }
00037 inline void setCrossValidation(double split) { crossValidation = split; }
00038
00039 void loadState();
00040 void saveState();
00041
00042 Calibration::MVAComputer *getTrainCalibration() const;
00043 void doneTraining(Calibration::MVAComputer *trainCalibration) const;
00044
00045 Calibration::MVAComputer *getCalibration() const;
00046
00047
00048
00049 std::string trainFileName(const TrainProcessor *proc,
00050 const std::string &ext,
00051 const std::string &arg = "") const;
00052
00053 inline const std::string &getName() const { return name; }
00054
00055 TrainerMonitoring::Module *bookMonitor(const std::string &name);
00056
00057
00058
00059 static const AtomicId kTargetId;
00060 static const AtomicId kWeightId;
00061
00062 private:
00063 SourceVariable *getVariable(AtomicId source, AtomicId name) const;
00064
00065 SourceVariable *createVariable(Source *source, AtomicId name,
00066 Variable::Flags flags);
00067
00068 struct CalibratedProcessor {
00069 CalibratedProcessor(TrainProcessor *processor,
00070 Calibration::VarProcessor *calib) :
00071 processor(processor), calib(calib) {}
00072
00073 TrainProcessor *processor;
00074 Calibration::VarProcessor *calib;
00075 };
00076
00077 void fillInputVars(SourceVariableSet &vars,
00078 XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml);
00079
00080 void fillOutputVars(SourceVariableSet &vars, Source *source,
00081 XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml);
00082
00083 void makeProcessor(XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *elem,
00084 AtomicId id, const char *name);
00085
00086 void connectProcessors(Calibration::MVAComputer *calib,
00087 const std::vector<CalibratedProcessor> &procs,
00088 bool withTarget) const;
00089
00090 Calibration::MVAComputer *
00091 makeTrainCalibration(const AtomicId *compute,
00092 const AtomicId *train) const;
00093
00094 void
00095 findUntrainedComputers(std::vector<AtomicId> &compute,
00096 std::vector<AtomicId> &train) const;
00097
00098 std::vector<AtomicId> findFinalProcessors() const;
00099
00100 std::map<AtomicId, Source*> sources;
00101 std::vector<SourceVariable*> variables;
00102 std::vector<AtomicId> processors;
00103 Source *input;
00104 TrainProcessor *output;
00105
00106 std::auto_ptr<TrainerMonitoring> monitoring;
00107 std::auto_ptr<XMLDocument> xml;
00108 std::string trainFileMask;
00109 std::string name;
00110 bool doAutoSave;
00111 bool doCleanup;
00112 bool doMonitoring;
00113
00114 UInt_t randomSeed;
00115 double crossValidation;
00116 };
00117
00118 }

#endif // PhysicsTools_MVATrainer_MVATrainer_h