Go to the documentation of this file.00001 #ifndef PhysicsTools_MVATrainer_MVATrainerLooperImpl_h
00002 #define PhysicsTools_MVATrainer_MVATrainerLooperImpl_h
00003
00004 #include <string>
00005 #include <memory>
00006
00007 #include <boost/shared_ptr.hpp>
00008
00009 #include "FWCore/Utilities/interface/Exception.h"
00010 #include "FWCore/ParameterSet/interface/ParameterSet.h"
00011 #include "FWCore/Framework/interface/ESProducts.h"
00012
00013 #include "PhysicsTools/MVATrainer/interface/MVATrainerLooper.h"
00014 #include "PhysicsTools/MVATrainer/interface/MVATrainerContainer.h"
00015
00016 namespace PhysicsTools {
00017
00018 template<class Record_t>
00019 class MVATrainerLooperImpl : public MVATrainerLooper {
00020 public:
00021 MVATrainerLooperImpl(const edm::ParameterSet ¶ms) :
00022 MVATrainerLooper(params)
00023 {
00024 setWhatProduced(this, "trainer");
00025 addTrainer(new Trainer(params));
00026 }
00027
00028 virtual ~MVATrainerLooperImpl() {}
00029
00030 boost::shared_ptr<Calibration::MVAComputer>
00031 produce(const Record_t &record)
00032 { return (*getTrainers().begin())->getCalibration(); }
00033 };
00034
00035 template<class Record_t>
00036 class MVATrainerContainerLooperImpl : public MVATrainerLooper {
00037 public:
00038 enum { kTrainer, kTrained };
00039
00040 MVATrainerContainerLooperImpl(const edm::ParameterSet ¶ms) :
00041 MVATrainerLooper(params)
00042 {
00043 setWhatProduced(this, edm::es::label("trainer", kTrainer)
00044 ("trained", kTrained));
00045
00046 std::vector<edm::ParameterSet> trainers =
00047 params.getParameter<std::vector<edm::ParameterSet> >(
00048 "trainers");
00049
00050 for(std::vector<edm::ParameterSet>::const_iterator iter =
00051 trainers.begin(); iter != trainers.end(); iter++)
00052
00053 addTrainer(new Trainer(*iter));
00054 }
00055
00056 virtual ~MVATrainerContainerLooperImpl() {}
00057
00058 edm::ESProducts<
00059 edm::es::L<Calibration::MVAComputerContainer, kTrainer>,
00060 edm::es::L<Calibration::MVAComputerContainer, kTrained> >
00061 produce(const Record_t &record)
00062 {
00063 boost::shared_ptr<MVATrainerContainer> trainerCalib(
00064 new MVATrainerContainer());
00065 TrainContainer trainedCalib;
00066
00067 bool untrained = false;
00068 for(TrainerContainer::const_iterator iter =
00069 getTrainers().begin();
00070 iter != getTrainers().end(); iter++) {
00071 Trainer *trainer = dynamic_cast<Trainer*>(*iter);
00072 TrainObject calib = trainer->getCalibration();
00073
00074 trainerCalib->addTrainer(trainer->calibrationRecord,
00075 calib);
00076 if (calib) {
00077 untrained = true;
00078 continue;
00079 }
00080
00081 if (!trainedCalib)
00082 trainedCalib = TrainContainer(
00083 new Calibration::MVAComputerContainer);
00084
00085 trainedCalib->add(trainer->calibrationRecord) =
00086 *trainer->getTrainer()->getCalibration();
00087 }
00088
00089 if (untrained)
00090 trainedCalib = TrainContainer(
00091 new UntrainedMVAComputerContainer);
00092
00093 edm::es::L<Calibration::MVAComputerContainer, kTrainer>
00094 trainedESLabel(trainerCalib);
00095
00096 return edm::es::products(trainedESLabel,
00097 edm::es::l<kTrained>(trainedCalib));
00098 }
00099
00100 protected:
00101 class Trainer : public MVATrainerLooper::Trainer {
00102 public:
00103 Trainer(const edm::ParameterSet ¶ms) :
00104 MVATrainerLooper::Trainer(params),
00105 calibrationRecord(params.getParameter<std::string>(
00106 "calibrationRecord")) {}
00107
00108 std::string calibrationRecord;
00109 };
00110 };
00111
00112 }
00113
00114 #endif // PhysicsTools_MVATrainer_MVATrainerLooperImpl_h