CMS 3D CMS Logo

/data/refman/pasoursint/CMSSW_5_3_8_patch3/src/PhysicsTools/MVATrainer/interface/TrainProcessor.h

Go to the documentation of this file.
00001 #ifndef PhysicsTools_MVATrainer_TrainProcessor_h
00002 #define PhysicsTools_MVATrainer_TrainProcessor_h
00003 
00004 #include <vector>
00005 #include <string>
00006 
00007 #include <boost/version.hpp>
00008 #include <boost/filesystem.hpp>
00009 
00010 #include <xercesc/dom/DOM.hpp>
00011 
00012 #include "FWCore/PluginManager/interface/PluginFactory.h"
00013 
00014 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00015 #include "PhysicsTools/MVAComputer/interface/Variable.h"
00016 #include "PhysicsTools/MVAComputer/interface/Calibration.h"
00017 #include "PhysicsTools/MVAComputer/interface/ProcessRegistry.h"
00018 
00019 #include "PhysicsTools/MVATrainer/interface/Source.h"
00020 #include "PhysicsTools/MVATrainer/interface/TrainerMonitoring.h"
00021 
00022 class TH1F;
00023 
00024 namespace PhysicsTools {
00025 
00026 class MVATrainer;
00027 
00028 class TrainProcessor : public Source,
00029         public ProcessRegistry<TrainProcessor, AtomicId, MVATrainer>::Factory {
00030     public:
00031         template<typename Instance_t>
00032         struct Registry {
00033                 typedef typename ProcessRegistry<
00034                         TrainProcessor,
00035                         AtomicId,
00036                         MVATrainer
00037                 >::Registry<Instance_t, AtomicId> Type;
00038         };
00039 
00040         typedef TrainerMonitoring::Module Monitoring;
00041 
00042         TrainProcessor(const char *name,
00043                        const AtomicId *id,
00044                        MVATrainer *trainer);
00045         virtual ~TrainProcessor();
00046 
00047         virtual Variable::Flags getDefaultFlags() const
00048         { return Variable::FLAG_ALL; }
00049 
00050         virtual void
00051         configure(XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *config) {}
00052 
00053         virtual void
00054         passFlags(const std::vector<Variable::Flags> &flags) {}
00055 
00056         virtual Calibration::VarProcessor *getCalibration() const { return 0; }
00057 
00058         void doTrainBegin();
00059         void doTrainData(const std::vector<double> *values,
00060                          bool target, double weight, bool train, bool test);
00061         void doTrainEnd();
00062 
00063         virtual bool load() { return true; }
00064         virtual void save() {}
00065         virtual void cleanup() {}
00066 
00067         inline const char *getId() const { return name.c_str(); }
00068 
00069         struct Dummy {};
00070         typedef edmplugin::PluginFactory<Dummy*()> PluginFactory;
00071 
00072     protected:
00073         virtual void trainBegin() {}
00074         virtual void trainData(const std::vector<double> *values,
00075                                bool target, double weight) {}
00076         virtual void testData(const std::vector<double> *values,
00077                               bool target, double weight, bool trainedOn) {}
00078         virtual void trainEnd() { trained = true; }
00079 
00080         virtual void *requestObject(const std::string &name) const
00081         { return 0; }
00082 
00083         inline bool exists(const std::string &name)
00084         { return boost::filesystem::exists(name.c_str()); }
00085 
00086         std::string             name;
00087         MVATrainer              *trainer;
00088         Monitoring              *monitoring;
00089 
00090     private:
00091         struct SigBkg {
00092                 bool            sameBinning;
00093                 double          min;
00094                 double          max;
00095                 unsigned long   entries[2];
00096                 double          underflow[2];
00097                 double          overflow[2];
00098                 TH1F            *histo[2];
00099         };
00100                 
00101         std::vector<SigBkg>     monHistos;
00102         Monitoring              *monModule;
00103 };
00104 
00105 template<>
00106 TrainProcessor *ProcessRegistry<TrainProcessor, AtomicId,
00107                                 MVATrainer>::Factory::create(
00108                         const char*, const AtomicId*, MVATrainer*);
00109 
00110 } // namespace PhysicsTools
00111 
// Registers an EDM plugin named "TrainProcessor/<T>" with
// PhysicsTools::TrainProcessor::PluginFactory. The registered type is always
// the TrainProcessor::Dummy placeholder, not T itself. Presumably the name
// only lets the plugin manager find and load the library that defines T,
// whose own ProcessRegistry registration then takes effect — confirm.
#define MVA_TRAINER_DEFINE_PLUGIN(T)                                    \
        DEFINE_EDM_PLUGIN(                                              \
                ::PhysicsTools::TrainProcessor::PluginFactory,          \
                ::PhysicsTools::TrainProcessor::Dummy,                  \
                "TrainProcessor/" #T)
00116 
00117 #endif // PhysicsTools_MVATrainer_TrainProcessor_h