CMS 3D CMS Logo

TrainProcessor.h
Go to the documentation of this file.
1 #ifndef PhysicsTools_MVATrainer_TrainProcessor_h
2 #define PhysicsTools_MVATrainer_TrainProcessor_h
3 
4 #include <vector>
5 #include <string>
6 
7 #include <boost/version.hpp>
8 #include <boost/filesystem.hpp>
9 
10 #include <xercesc/dom/DOM.hpp>
11 
13 
18 
21 
22 class TH1F;
23 
24 namespace PhysicsTools {
25 
26 class MVATrainer;
27 
28 class TrainProcessor : public Source,
29  public ProcessRegistry<TrainProcessor, AtomicId, MVATrainer>::Factory {
30  public:
31  template<typename Instance_t>
32  struct Registry {
33  typedef typename ProcessRegistry<
35  AtomicId,
38  };
39 
41 
42  TrainProcessor(const char *name,
43  const AtomicId *id,
45  ~TrainProcessor() override;
46 
48  { return Variable::FLAG_ALL; }
49 
50  virtual void
52 
53  virtual void
54  passFlags(const std::vector<Variable::Flags> &flags) {}
55 
56  virtual Calibration::VarProcessor *getCalibration() const { return nullptr; }
57 
/// Public training driver entry points; their definitions live out of line
/// (not visible in this header).
/// NOTE(review): presumably these forward to the protected trainBegin /
/// trainData / testData / trainEnd hooks depending on `train` / `test` —
/// confirm in the implementation file.
void doTrainBegin();
void doTrainData(const std::vector<double> *values, bool target,
                 double weight, bool train, bool test);
void doTrainEnd();
62 
63  virtual bool load() { return true; }
64  virtual void save() {}
65  virtual void cleanup() {}
66 
67  inline const char *getId() const { return name.c_str(); }
68 
/// Empty tag type passed as the product type to DEFINE_EDM_PLUGIN by
/// MVA_TRAINER_DEFINE_PLUGIN; it carries no data.
struct Dummy { };
71 
72  protected:
73  virtual void trainBegin() {}
74  virtual void trainData(const std::vector<double> *values,
75  bool target, double weight) {}
76  virtual void testData(const std::vector<double> *values,
77  bool target, double weight, bool trainedOn) {}
78  virtual void trainEnd() { trained = true; }
79 
80  virtual void *requestObject(const std::string &name) const
81  { return nullptr; }
82 
83  inline bool exists(const std::string &name)
84  { return boost::filesystem::exists(name.c_str()); }
85 
88  Monitoring *monitoring;
89 
90  private:
91  struct SigBkg {
93  double min;
94  double max;
95  unsigned long entries[2];
96  double underflow[2];
97  double overflow[2];
98  TH1F *histo[2];
99  };
100 
101  std::vector<SigBkg> monHistos;
102  Monitoring *monModule;
103 };
104 
105 template<>
108  const char*, const AtomicId*, MVATrainer*);
109 
110 } // namespace PhysicsTools
111 
// Expands to a DEFINE_EDM_PLUGIN registration of TrainProcessor::Dummy with
// TrainProcessor::PluginFactory under the name "TrainProcessor/" followed by
// the stringized argument. T is only stringized here, never used as a type.
// NOTE(review): the concrete processor is presumably bound to that name via
// TrainProcessor::Registry — confirm.
#define MVA_TRAINER_DEFINE_PLUGIN(T)                   \
  DEFINE_EDM_PLUGIN(                                   \
      ::PhysicsTools::TrainProcessor::PluginFactory,   \
      ::PhysicsTools::TrainProcessor::Dummy,           \
      "TrainProcessor/" #T)
116 
117 #endif // PhysicsTools_MVATrainer_TrainProcessor_h
const char * getId() const
def create(alignables, pedeDump, additionalData, outputFile, config)
std::vector< Variable::Flags > flags
Definition: MVATrainer.cc:135
Definition: weight.py:1
TrainProcessor(const char *name, const AtomicId *id, MVATrainer *trainer)
Definition: config.py:1
#define XERCES_CPP_NAMESPACE_QUALIFIER
Definition: LHERunInfo.h:16
virtual Calibration::VarProcessor * getCalibration() const
Cheap generic unique keyword identifier class.
Definition: AtomicId.h:31
virtual void * requestObject(const std::string &name) const
virtual void passFlags(const std::vector< Variable::Flags > &flags)
TrainerMonitoring::Module Monitoring
virtual void testData(const std::vector< double > *values, bool target, double weight, bool trainedOn)
virtual void trainData(const std::vector< double > *values, bool target, double weight)
bool exists(const std::string &name)
ProcessRegistry< TrainProcessor, AtomicId, MVATrainer >::Registry< Instance_t, AtomicId > Type
virtual Variable::Flags getDefaultFlags() const
std::vector< SigBkg > monHistos
edmplugin::PluginFactory< Dummy *()> PluginFactory
Generic registry template for polymorphic processor implementations.
void doTrainData(const std::vector< double > *values, bool target, double weight, bool train, bool test)
virtual void configure(XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *config)