CMS 3D CMS Logo

/data/refman/pasoursint/CMSSW_4_1_8_patch12/src/PhysicsTools/MVATrainer/interface/MVATrainerLooperImpl.h

Go to the documentation of this file.
00001 #ifndef PhysicsTools_MVATrainer_MVATrainerLooperImpl_h
00002 #define PhysicsTools_MVATrainer_MVATrainerLooperImpl_h
00003 
00004 #include <string>
00005 #include <memory>
00006 
00007 #include <boost/shared_ptr.hpp>
00008 
00009 #include "FWCore/Utilities/interface/Exception.h"
00010 #include "FWCore/ParameterSet/interface/ParameterSet.h"
00011 #include "FWCore/Framework/interface/ESProducts.h"
00012 
00013 #include "PhysicsTools/MVATrainer/interface/MVATrainerLooper.h"
00014 #include "PhysicsTools/MVATrainer/interface/MVATrainerContainer.h"
00015 
00016 namespace PhysicsTools {
00017 
00018 template<class Record_t>
00019 class MVATrainerLooperImpl : public MVATrainerLooper {
00020     public:
00021         MVATrainerLooperImpl(const edm::ParameterSet &params) :
00022                 MVATrainerLooper(params)
00023         {
00024                 setWhatProduced(this, "trainer");
00025                 addTrainer(new Trainer(params));
00026         }
00027 
00028         virtual ~MVATrainerLooperImpl() {}
00029 
00030         boost::shared_ptr<Calibration::MVAComputer>
00031         produce(const Record_t &record)
00032         { return (*getTrainers().begin())->getCalibration(); }
00033 };
00034 
00035 template<class Record_t>
00036 class MVATrainerContainerLooperImpl : public MVATrainerLooper {
00037     public:
00038         enum { kTrainer, kTrained };
00039 
00040         MVATrainerContainerLooperImpl(const edm::ParameterSet &params) :
00041                 MVATrainerLooper(params)
00042         {
00043                 setWhatProduced(this, edm::es::label("trainer", kTrainer)
00044                                                     ("trained", kTrained));
00045 
00046                 std::vector<edm::ParameterSet> trainers =
00047                         params.getParameter<std::vector<edm::ParameterSet> >(
00048                                                                 "trainers");
00049 
00050                 for(std::vector<edm::ParameterSet>::const_iterator iter =
00051                         trainers.begin(); iter != trainers.end(); iter++)
00052 
00053                         addTrainer(new Trainer(*iter));
00054         }
00055 
00056         virtual ~MVATrainerContainerLooperImpl() {}
00057 
00058         edm::ESProducts<
00059                 edm::es::L<Calibration::MVAComputerContainer, kTrainer>,
00060                 edm::es::L<Calibration::MVAComputerContainer, kTrained> >
00061         produce(const Record_t &record)
00062         {
00063                 boost::shared_ptr<MVATrainerContainer> trainerCalib(
00064                                                 new MVATrainerContainer());
00065                 TrainContainer trainedCalib;
00066 
00067                 bool untrained = false;
00068                 for(TrainerContainer::const_iterator iter =
00069                                                         getTrainers().begin();
00070                     iter != getTrainers().end(); iter++) {
00071                         Trainer *trainer = dynamic_cast<Trainer*>(*iter);
00072                         TrainObject calib = trainer->getCalibration();
00073 
00074                         trainerCalib->addTrainer(trainer->calibrationRecord,
00075                                                  calib);
00076                         if (calib) {
00077                                 untrained = true;
00078                                 continue;
00079                         }
00080 
00081                         if (!trainedCalib)
00082                                 trainedCalib = TrainContainer(
00083                                         new Calibration::MVAComputerContainer);
00084 
00085                         trainedCalib->add(trainer->calibrationRecord) =
00086                                 *trainer->getTrainer()->getCalibration();
00087                 }
00088 
00089                 if (untrained)
00090                         trainedCalib = TrainContainer(
00091                                         new UntrainedMVAComputerContainer);
00092 
00093                 edm::es::L<Calibration::MVAComputerContainer, kTrainer>
00094                                                 trainedESLabel(trainerCalib);
00095 
00096                 return edm::es::products(trainedESLabel,
00097                                          edm::es::l<kTrained>(trainedCalib));
00098         }
00099 
00100     protected:
00101         class Trainer : public MVATrainerLooper::Trainer {
00102             public:
00103                 Trainer(const edm::ParameterSet &params) :
00104                         MVATrainerLooper::Trainer(params),
00105                         calibrationRecord(params.getParameter<std::string>(
00106                                                 "calibrationRecord")) {}
00107 
00108                 std::string calibrationRecord;
00109         };
00110 };
00111 
00112 } // namespace PhysicsTools
00113 
00114 #endif // PhysicsTools_MVATrainer_MVATrainerLooperImpl_h