CMS 3D CMS Logo

MVATrainerLooperImpl.h
Go to the documentation of this file.
1 #ifndef PhysicsTools_MVATrainer_MVATrainerLooperImpl_h
2 #define PhysicsTools_MVATrainer_MVATrainerLooperImpl_h
3 
4 #include <string>
5 #include <memory>
6 
10 
13 
14 namespace PhysicsTools {
15 
// Looper implementation that drives exactly one trainer and serves that
// trainer's calibration object for records of type Record_t.
//
// NOTE(review): the class-head line and the constructor's declarator line
// (17 and 19 in the embedded numbering) are absent from this listing; the
// class name is taken from the destructor below, and the base class from the
// constructor's mem-initializer.
16 template<class Record_t>
18  public:
// Constructor body: `params` is forwarded unchanged to the MVATrainerLooper
// base and is also used to configure the single Trainer created below.
20  MVATrainerLooper(params)
21  {
// Register this object's produce() under the label "trainer".
22  setWhatProduced(this, "trainer");
// Hand a freshly allocated Trainer to the base class as a raw pointer.
// NOTE(review): presumably the base looper takes ownership -- confirm.
23  addTrainer(new Trainer(params));
24  }
25 
// Nothing to release here; the destructor only overrides the base one.
26  ~MVATrainerLooperImpl() override {}
27 
// Returns the calibration of the first registered trainer.
// `record` is not used in the body. begin() is dereferenced without an
// emptiness check; the constructor above registers one trainer.
28  std::shared_ptr<Calibration::MVAComputer>
29  produce(const Record_t &record)
30  { return (*getTrainers().begin())->getCalibration(); }
31 };
32 
// Looper implementation that drives several trainers at once and publishes
// two labelled products per record of type Record_t: one container keyed by
// kTrainer and one keyed by kTrained.
//
// NOTE(review): several lines are absent from this listing (embedded numbers
// 34, 38, 54, 56-58, 66, 89, 99, 106): the class head, the constructor
// declarator, produce()'s return type, the `for` header of the main loop, the
// argument of one TrainContainer(...) call, the nested class head and the
// calibrationRecord member declaration. Comments below only describe what is
// visible.
33 template<class Record_t>
35  public:
// Indices used to tell the two products of produce() apart.
36  enum { kTrainer, kTrained };
37 
// Constructor body: `params` is forwarded to the MVATrainerLooper base.
39  MVATrainerLooper(params)
40  {
// Register produce() with two labels: "trainer" bound to kTrainer and
// "trained" bound to kTrained.
41  setWhatProduced(this, edm::es::label("trainer", kTrainer)
42  ("trained", kTrained));
43 
// One parameter set per trainer is read from the "trainers" parameter.
44  std::vector<edm::ParameterSet> trainers =
45  params.getParameter<std::vector<edm::ParameterSet> >(
46  "trainers");
47 
// Create one nested Trainer per parameter set and register it with the base.
// NOTE(review): raw `new`; presumably the base looper takes ownership -- confirm.
48  for(std::vector<edm::ParameterSet>::const_iterator iter =
49  trainers.begin(); iter != trainers.end(); iter++)
50 
51  addTrainer(new Trainer(*iter));
52  }
53 
55 
// Builds both products from the current state of every registered trainer.
// `record` is not used in the visible body.
59  produce(const Record_t &record)
60  {
// Container that receives every trainer's calibration, keyed by its record name.
61  std::shared_ptr<MVATrainerContainer> trainerCalib(
62  new MVATrainerContainer());
// Starts out null; only allocated once a trainer reaches the copy step below.
63  TrainContainer trainedCalib;
64 
// Set as soon as any trainer still returns a non-null calibration object.
65  bool untrained = false;
// NOTE(review): the first line of the loop header (the `for(` and the iterator
// declaration) is missing from this listing; the loop walks getTrainers().
67  getTrainers().begin();
68  iter != getTrainers().end(); iter++) {
// NOTE(review): the dynamic_cast result is dereferenced without a null check;
// this relies on every registered trainer being the nested Trainer type -- confirm.
69  Trainer *trainer = dynamic_cast<Trainer*>(*iter);
70  TrainObject calib = trainer->getCalibration();
71 
// The calibration (possibly null) is always recorded in the kTrainer container.
72  trainerCalib->addTrainer(trainer->calibrationRecord,
73  calib);
// A non-null calibration marks the whole set as untrained and skips the copy
// into the trained container.
// NOTE(review): presumably the looper only hands out a calibration object
// while training is still in progress -- confirm against MVATrainerLooper.
74  if (calib) {
75  untrained = true;
76  continue;
77  }
78 
// Allocate the trained container lazily, on the first trainer that gets here.
79  if (!trainedCalib)
80  trainedCalib = std::make_shared<PhysicsTools::Calibration::MVAComputerContainer>(
81  );
82 
// Copy the MVATrainer's calibration into a new entry named after the record.
// NOTE(review): MVATrainer::getCalibration() returns a raw pointer that is
// dereferenced and copied here; if it is a fresh allocation it is never
// freed -- confirm ownership.
83  trainedCalib->add(trainer->calibrationRecord) =
84  *trainer->getTrainer()->getCalibration();
85  }
86 
// If any trainer was flagged above, the trained container is replaced as a whole.
// NOTE(review): the constructor argument (embedded line 89) is missing from
// this listing, so what it is replaced with cannot be stated here.
87  if (untrained)
88  trainedCalib = TrainContainer(
90 
// Wrap trainerCalib with the kTrainer label.
// NOTE(review): despite its name, `trainedESLabel` carries the kTrainer product.
91  edm::es::L<Calibration::MVAComputerContainer, kTrainer>
92  trainedESLabel(trainerCalib);
93 
// Return both products; trainedCalib is wrapped with the kTrained label.
94  return edm::es::products(trainedESLabel,
95  edm::es::l<kTrained>(trainedCalib));
96  }
97 
98  protected:
// Nested trainer type that additionally remembers the name of the calibration
// record it belongs to (its class-head line is missing from this listing).
100  public:
// Forwards `params` to MVATrainerLooper::Trainer and reads the
// "calibrationRecord" string parameter into calibrationRecord.
101  Trainer(const edm::ParameterSet &params) :
102  MVATrainerLooper::Trainer(params),
103  calibrationRecord(params.getParameter<std::string>(
104  "calibrationRecord")) {}
105 
107  };
108 };
109 
110 } // namespace PhysicsTools
111 
112 #endif // PhysicsTools_MVATrainer_MVATrainerLooperImpl_h
T getParameter(std::string const &) const
auto setWhatProduced(T *iThis, const es::Label &iLabel={})
Definition: ESProducer.h:116
Label label(const std::string &iString, int iIndex)
Definition: es_Label.h:97
void addTrainer(Trainer *trainer)
JetCorrectorParameters::Record record
Definition: classes.h:7
ESProducts< std::remove_reference_t< TArgs >... > products(TArgs &&...args)
Definition: ESProducts.h:129
edm::ESProducts< edm::es::L< Calibration::MVAComputerContainer, kTrainer >, edm::es::L< Calibration::MVAComputerContainer, kTrained > > produce(const Record_t &record)
std::shared_ptr< Calibration::MVAComputer > produce(const Record_t &record)
const TrainerContainer & getTrainers() const
std::shared_ptr< Calibration::MVAComputerContainer > TrainContainer
const TrainObject getCalibration() const
Calibration::MVAComputer * getCalibration() const
Definition: MVATrainer.cc:1144
std::vector< Trainer * >::const_iterator const_iterator
#define begin
Definition: vmac.h:32
std::shared_ptr< Calibration::MVAComputer > TrainObject
const MVATrainer * getTrainer() const
MVATrainerContainerLooperImpl(const edm::ParameterSet &params)
MVATrainerLooperImpl(const edm::ParameterSet &params)