CMS 3D CMS Logo

MVATrainer.h
Go to the documentation of this file.
1 #ifndef PhysicsTools_MVATrainer_MVATrainer_h
2 #define PhysicsTools_MVATrainer_MVATrainer_h
3 
4 #include <memory>
5 #include <string>
6 #include <map>
7 
8 #include <xercesc/dom/DOM.hpp>
9 
11 
16 
21 
// NOTE(review): Doxygen-rendered listing. The leading numbers are the
// original header line numbers; gaps in that numbering mark lines that were
// lost in extraction (e.g. 35, 42, 45, 49, 62-64, 67-68, 72-83, 90-91,
// 103-104, 108-110, 112, 115). This text is therefore not compilable as-is;
// consult the real MVATrainer.h before editing declarations.
22 namespace PhysicsTools {
23 
// Forward declarations only; their definitions are not part of this header.
24 class Source;
25 class TrainProcessor;
26 
// MVATrainer: trainer object constructed from a file name, an optional XSLT
// flag and an optional style sheet (presumably applied when useXSLT is true
// -- TODO confirm in MVATrainer.cc).
27 class MVATrainer {
28  public:
29  MVATrainer(const std::string &fileName, bool useXSLT = false,
30  const char *styleSheet = nullptr);
31  ~MVATrainer();
32 
// Inline setters: each only stores its argument in the matching member.
// NOTE(review): the declarations of doAutoSave and crossValidation are among
// the elided lines; header line 35 (setMonitoring(bool), per the index below)
// is also missing from this listing.
33  inline void setAutoSave(bool autoSave) { doAutoSave = autoSave; }
34  inline void setCleanup(bool cleanup) { doCleanup = cleanup; }
36  inline void setRandomSeed(UInt_t seed) { randomSeed = seed; }
37  inline void setCrossValidation(double split) { crossValidation = split; }
38 
// Persist / restore trainer state; implemented in MVATrainer.cc.
39  void loadState();
40  void saveState();
41 
// Takes the calibration previously handed out for training (see the
// getTrainCalibration() entry in the index below); defined at MVATrainer.cc:1100.
43  void doneTraining(Calibration::MVAComputer *trainCalibration) const;
44 
46 
47  // used by TrainProcessors
48 
// Tail of trainFileName(); full signature per the index below:
//   std::string trainFileName(const TrainProcessor *proc,
//                             const std::string &ext,
//                             const std::string &arg = "") const;
50  const std::string &ext,
51  const std::string &arg = "") const;
52 
// Returns the member `name` by const reference (its declaration is elided).
53  inline const std::string &getName() const { return name; }
54 
56 
57  // constants
58 
// Well-known AtomicId keys; their values are defined in MVATrainer.cc.
59  static const AtomicId kTargetId;
60  static const AtomicId kWeightId;
61 
// Fragment of the nested CalibratedProcessor type: only the end of its
// constructor survives. Per the index it is declared as
//   CalibratedProcessor(TrainProcessor *processor,
//                       Calibration::VarProcessor *calib)
// with a `Calibration::VarProcessor *calib` member at header line 68.
65  processor(processor), calib(calib) {}
66 
69  };
70 
71  private:
73 
76 
79 
82 
// Tail of makeProcessor(DOMElement *elem, AtomicId id, const char *name);
// defined at MVATrainer.cc:620.
84  AtomicId id, const char *name);
85 
// Tail of connectProcessors(Calibration::MVAComputer *calib, ...);
// defined at MVATrainer.cc:883.
87  const std::vector<CalibratedProcessor> &procs,
88  bool withTarget) const;
89 
// Tail of makeTrainCalibration(const AtomicId *compute,
// const AtomicId *train) const; defined at MVATrainer.cc:972.
92  const AtomicId *train) const;
93 
// Takes both vectors by non-const reference, so presumably fills them as
// outputs -- TODO confirm in MVATrainer.cc:1193.
94  void
95  findUntrainedComputers(std::vector<AtomicId> &compute,
96  std::vector<AtomicId> &train) const;
97 
98  std::vector<AtomicId> findFinalProcessors() const;
99 
// NOTE(review): sources/variables hold raw pointers; ownership and cleanup
// are presumably handled in ~MVATrainer() -- confirm in MVATrainer.cc.
100  std::map<AtomicId, Source*> sources;
101  std::vector<SourceVariable*> variables;
102  std::vector<AtomicId> processors;
// NOTE(review): header line 104 (`TrainProcessor *output`, per the index)
// and line 108 (`std::string trainFileMask`) are elided here.
105 
106  std::unique_ptr<TrainerMonitoring> monitoring;
107  std::unique_ptr<XMLDocument> xml;
111  bool doCleanup;
113 
114  UInt_t randomSeed;
116 };
117 
118 } // namespace PhysicsTools
119 
120 #endif // PhysicsTools_MVATrainer_MVATrainer_h
const std::string & getName() const
Definition: MVATrainer.h:53
TrainProcessor *const proc
Definition: MVATrainer.cc:101
static void cleanup(const Factory::MakerMap::value_type &v)
Definition: Factory.cc:12
std::vector< Variable::Flags > flags
Definition: MVATrainer.cc:135
SourceVariable * getVariable(AtomicId source, AtomicId name) const
Definition: MVATrainer.cc:723
void setMonitoring(bool monitoring)
Definition: MVATrainer.h:35
#define XERCES_CPP_NAMESPACE_QUALIFIER
Definition: LHERunInfo.h:16
A arg
Definition: Factorize.h:38
static const AtomicId kTargetId
Definition: MVATrainer.h:59
void fillOutputVars(SourceVariableSet &vars, Source *source, XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
Definition: MVATrainer.cc:836
Cheap generic unique keyword identifier class.
Definition: AtomicId.h:31
static const AtomicId kWeightId
Definition: MVATrainer.h:60
Calibration::MVAComputer * makeTrainCalibration(const AtomicId *compute, const AtomicId *train) const
Definition: MVATrainer.cc:972
CalibratedProcessor(TrainProcessor *processor, Calibration::VarProcessor *calib)
Definition: MVATrainer.h:63
void setAutoSave(bool autoSave)
Definition: MVATrainer.h:33
std::vector< AtomicId > processors
Definition: MVATrainer.h:102
std::vector< AtomicId > findFinalProcessors() const
Definition: MVATrainer.cc:1113
std::string trainFileName(const TrainProcessor *proc, const std::string &ext, const std::string &arg="") const
Definition: MVATrainer.cc:698
void connectProcessors(Calibration::MVAComputer *calib, const std::vector< CalibratedProcessor > &procs, bool withTarget) const
Definition: MVATrainer.cc:883
def elem(elemtype, innerHTML='', html_class='', kwargs)
Definition: HTMLExport.py:19
std::map< AtomicId, Source * > sources
Definition: MVATrainer.h:100
void setCleanup(bool cleanup)
Definition: MVATrainer.h:34
void doneTraining(Calibration::MVAComputer *trainCalibration) const
Definition: MVATrainer.cc:1100
Calibration::VarProcessor * calib
Definition: MVATrainer.h:68
std::unique_ptr< TrainerMonitoring > monitoring
Definition: MVATrainer.h:106
TrainerMonitoring::Module * bookMonitor(const std::string &name)
Definition: MVATrainer.cc:708
std::unique_ptr< XMLDocument > xml
Definition: MVATrainer.h:107
Calibration::MVAComputer * getCalibration() const
Definition: MVATrainer.cc:1144
def compute(min, max)
std::vector< SourceVariable * > variables
Definition: MVATrainer.h:101
void setCrossValidation(double split)
Definition: MVATrainer.h:37
void fillInputVars(SourceVariableSet &vars, XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
Definition: MVATrainer.cc:744
MVATrainer(const std::string &fileName, bool useXSLT=false, const char *styleSheet=nullptr)
Definition: MVATrainer.cc:420
void setRandomSeed(UInt_t seed)
Definition: MVATrainer.h:36
void makeProcessor(XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *elem, AtomicId id, const char *name)
Definition: MVATrainer.cc:620
static std::vector< std::string > split(const std::string line, char delim)
Definition: MLP.cc:18
std::string trainFileMask
Definition: MVATrainer.h:108
void findUntrainedComputers(std::vector< AtomicId > &compute, std::vector< AtomicId > &train) const
Definition: MVATrainer.cc:1193
TrainProcessor * output
Definition: MVATrainer.h:104
vars
Definition: DeepTauId.cc:77
Calibration::MVAComputer * getTrainCalibration() const
Definition: MVATrainer.cc:1239
Definition: memstream.h:15
static std::string const source
Definition: EdmProvDump.cc:47
SourceVariable * createVariable(Source *source, AtomicId name, Variable::Flags flags)
Definition: MVATrainer.cc:732