Go to the documentation of this file.00001 #ifndef PhysicsTools_MVATrainer_TrainProcessor_h
00002 #define PhysicsTools_MVATrainer_TrainProcessor_h
00003
00004 #include <vector>
00005 #include <string>
00006
00007 #include <boost/version.hpp>
00008 #include <boost/filesystem.hpp>
00009
00010 #include <xercesc/dom/DOM.hpp>
00011
00012 #include "FWCore/PluginManager/interface/PluginFactory.h"
00013
00014 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00015 #include "PhysicsTools/MVAComputer/interface/Variable.h"
00016 #include "PhysicsTools/MVAComputer/interface/Calibration.h"
00017 #include "PhysicsTools/MVAComputer/interface/ProcessRegistry.h"
00018
00019 #include "PhysicsTools/MVATrainer/interface/Source.h"
00020 #include "PhysicsTools/MVATrainer/interface/TrainerMonitoring.h"
00021
00022 class TH1F;
00023
00024 namespace PhysicsTools {
00025
00026 class MVATrainer;
00027
00028 class TrainProcessor : public Source,
00029 public ProcessRegistry<TrainProcessor, AtomicId, MVATrainer>::Factory {
00030 public:
00031 template<typename Instance_t>
00032 struct Registry {
00033 typedef typename ProcessRegistry<
00034 TrainProcessor,
00035 AtomicId,
00036 MVATrainer
00037 >::Registry<Instance_t, AtomicId> Type;
00038 };
00039
00040 typedef TrainerMonitoring::Module Monitoring;
00041
00042 TrainProcessor(const char *name,
00043 const AtomicId *id,
00044 MVATrainer *trainer);
00045 virtual ~TrainProcessor();
00046
00047 virtual Variable::Flags getDefaultFlags() const
00048 { return Variable::FLAG_ALL; }
00049
00050 virtual void
00051 configure(XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *config) {}
00052
00053 virtual void
00054 passFlags(const std::vector<Variable::Flags> &flags) {}
00055
00056 virtual Calibration::VarProcessor *getCalibration() const { return 0; }
00057
00058 void doTrainBegin();
00059 void doTrainData(const std::vector<double> *values,
00060 bool target, double weight, bool train, bool test);
00061 void doTrainEnd();
00062
00063 virtual bool load() { return true; }
00064 virtual void save() {}
00065 virtual void cleanup() {}
00066
00067 inline const char *getId() const { return name.c_str(); }
00068
00069 struct Dummy {};
00070 typedef edmplugin::PluginFactory<Dummy*()> PluginFactory;
00071
00072 protected:
00073 virtual void trainBegin() {}
00074 virtual void trainData(const std::vector<double> *values,
00075 bool target, double weight) {}
00076 virtual void testData(const std::vector<double> *values,
00077 bool target, double weight, bool trainedOn) {}
00078 virtual void trainEnd() { trained = true; }
00079
00080 virtual void *requestObject(const std::string &name) const
00081 { return 0; }
00082
00083 inline bool exists(const std::string &name)
00084 { return boost::filesystem::exists(name.c_str()); }
00085
00086 std::string name;
00087 MVATrainer *trainer;
00088 Monitoring *monitoring;
00089
00090 private:
00091 struct SigBkg {
00092 bool sameBinning;
00093 double min;
00094 double max;
00095 unsigned long entries[2];
00096 double underflow[2];
00097 double overflow[2];
00098 TH1F *histo[2];
00099 };
00100
00101 std::vector<SigBkg> monHistos;
00102 Monitoring *monModule;
00103 };
00104
00105 template<>
00106 TrainProcessor *ProcessRegistry<TrainProcessor, AtomicId,
00107 MVATrainer>::Factory::create(
00108 const char*, const AtomicId*, MVATrainer*);
00109
00110 }
00111
/*
 * Registers a plugin entry named "TrainProcessor/<ProcessorClass>" with
 * TrainProcessor::PluginFactory.
 *
 * The object registered with the factory is TrainProcessor::Dummy, not
 * ProcessorClass; the argument only contributes the (unexpanded,
 * stringized) plugin name. NOTE(review): presumably the entry exists so the
 * library containing ProcessorClass can be found and loaded by name, after
 * which the class registers itself via TrainProcessor::Registry -- confirm.
 *
 * Example: MVA_TRAINER_DEFINE_PLUGIN(MyProcessor) uses the plugin name
 * "TrainProcessor/MyProcessor".
 */
#define MVA_TRAINER_DEFINE_PLUGIN(ProcessorClass)                      \
        DEFINE_EDM_PLUGIN(                                             \
                ::PhysicsTools::TrainProcessor::PluginFactory,         \
                ::PhysicsTools::TrainProcessor::Dummy,                 \
                "TrainProcessor/" #ProcessorClass                      \
        )
00116
00117 #endif // PhysicsTools_MVATrainer_TrainProcessor_h