CMS 3D CMS Logo

TrackMVAClassifier.h
Go to the documentation of this file.
1 #ifndef RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
2 #define RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
3 
8 
15 
17 
19 
20 #include <vector>
21 #include <memory>
22 
24 public:
26  ~TrackMVAClassifierBase() override;
27 
28  using MVACollection = std::vector<float>;
29  using QualityMaskCollection = std::vector<unsigned char>;
30 
31  //Collection with pairs <MVAOutput, isReliable>
32  using MVAPairCollection = std::vector<std::pair<float, bool>>;
33 
34 protected:
36 
37  virtual void initEvent(const edm::EventSetup& es) = 0;
38 
39  virtual void computeMVA(reco::TrackCollection const& tracks,
40  reco::BeamSpot const& beamSpot,
42  MVAPairCollection& mvas) const = 0;
43 
44 private:
45  void produce(edm::Event& evt, const edm::EventSetup& es) final;
46 
51 
53 
54  // MVA
55 
56  // qualitycuts (loose, tight, hp)
57  float qualityCuts[3];
58 };
59 
61  template <typename EventCache>
62  struct ComputeMVA {
63  template <typename MVA>
64  void operator()(MVA const& mva,
66  reco::BeamSpot const& beamSpot,
69  EventCache cache;
70 
71  size_t current = 0;
72  for (auto const& trk : tracks) {
73  mvas[current++] = mva(trk, beamSpot, vertices, cache);
74  }
75  }
76  };
77 
78  template <>
79  struct ComputeMVA<void> {
80  template <typename MVA>
81  void operator()(MVA const& mva,
83  reco::BeamSpot const& beamSpot,
86  size_t current = 0;
87  for (auto const& trk : tracks) {
88  //BDT outputs are considered always reliable. Hence "true"
89  std::pair<float, bool> output(mva(trk, beamSpot, vertices), true);
90  mvas[current++] = output;
91  }
92  }
93  };
94 } // namespace trackMVAClassifierImpl
95 
96 template <typename MVA, typename EventCache = void>
98 public:
100  : TrackMVAClassifierBase(cfg), mva(cfg.getParameter<edm::ParameterSet>("mva"), consumesCollector()) {}
101 
104  fill(desc);
106  MVA::fillDescriptions(mvaDesc);
107  desc.add<edm::ParameterSetDescription>("mva", mvaDesc);
108  descriptions.add(MVA::name(), desc);
109  }
110 
111 private:
112  void beginStream(edm::StreamID) final { mva.beginStream(); }
113 
114  void initEvent(const edm::EventSetup& es) final { mva.initEvent(es); }
115 
117  reco::BeamSpot const& beamSpot,
119  MVAPairCollection& mvas) const final {
121  computer(mva, tracks, beamSpot, vertices, mvas);
122  }
123 
124  MVA mva;
125 };
126 
127 #endif // RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
TrackMVAClassifierBase(const edm::ParameterSet &cfg)
std::vector< std::pair< float, bool > > MVAPairCollection
virtual void initEvent(const edm::EventSetup &es)=0
edm::EDGetTokenT< reco::TrackCollection > src_
source collection label
virtual void computeMVA(reco::TrackCollection const &tracks, reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices, MVAPairCollection &mvas) const =0
std::vector< Track > TrackCollection
collection of Tracks
Definition: TrackFwd.h:14
static void fill(edm::ParameterSetDescription &desc)
std::vector< Vertex > VertexCollection
collection of Vertex objects
Definition: VertexFwd.h:9
std::vector< float > MVACollection
TEMPL(T2) struct Divides void
Definition: Factorize.h:24
void operator()(MVA const &mva, reco::TrackCollection const &tracks, reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices, TrackMVAClassifierBase::MVAPairCollection &mvas)
void computeMVA(reco::TrackCollection const &tracks, reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices, MVAPairCollection &mvas) const final
edm::EDGetTokenT< reco::BeamSpot > beamspot_
void produce(edm::Event &evt, const edm::EventSetup &es) final
void beginStream(edm::StreamID) final
edm::EDGetTokenT< reco::VertexCollection > vertices_
TrackMVAClassifier(const edm::ParameterSet &cfg)
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
void operator()(MVA const &mva, reco::TrackCollection const &tracks, reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices, TrackMVAClassifierBase::MVAPairCollection &mvas)
void initEvent(const edm::EventSetup &es) final
void add(std::string const &label, ParameterSetDescription const &psetDescription)
HLT enums.
def cache(function)
Definition: utilities.py:3
std::vector< unsigned char > QualityMaskCollection