CMS 3D CMS Logo

TrackMVAClassifier.h
Go to the documentation of this file.
1 #ifndef RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
2 #define RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
3 
7 
14 
16 
18 
19 #include <vector>
20 #include <memory>
21 
23 public:
25  ~TrackMVAClassifierBase() override;
26 
27  using MVACollection = std::vector<float>;
28  using QualityMaskCollection = std::vector<unsigned char>;
29 
30  //Collection with pairs <MVAOutput, isReliable>
31  using MVAPairCollection = std::vector<std::pair<float, bool>>;
32 
33 protected:
35 
36  virtual void initEvent(const edm::EventSetup& es) = 0;
37 
38  virtual void computeMVA(reco::TrackCollection const& tracks,
39  reco::BeamSpot const& beamSpot,
41  MVAPairCollection& mvas) const = 0;
42 
43 private:
44  void produce(edm::Event& evt, const edm::EventSetup& es) final;
45 
50 
52 
53  // MVA
54 
55  // qualitycuts (loose, tight, hp)
56  float qualityCuts[3];
57 };
58 
60  template <typename EventCache>
61  struct ComputeMVA {
62  template <typename MVA>
63  void operator()(MVA const& mva,
65  reco::BeamSpot const& beamSpot,
68  EventCache cache;
69 
70  size_t current = 0;
71  for (auto const& trk : tracks) {
72  mvas[current++] = mva(trk, beamSpot, vertices, cache);
73  }
74  }
75  };
76 
77  template <>
78  struct ComputeMVA<void> {
79  template <typename MVA>
80  void operator()(MVA const& mva,
82  reco::BeamSpot const& beamSpot,
85  size_t current = 0;
86  for (auto const& trk : tracks) {
87  //BDT outputs are considered always reliable. Hence "true"
88  std::pair<float, bool> output(mva(trk, beamSpot, vertices), true);
89  mvas[current++] = output;
90  }
91  }
92  };
93 } // namespace trackMVAClassifierImpl
94 
95 template <typename MVA, typename EventCache = void>
97 public:
99  : TrackMVAClassifierBase(cfg), mva(cfg.getParameter<edm::ParameterSet>("mva"), consumesCollector()) {}
100 
103  fill(desc);
105  MVA::fillDescriptions(mvaDesc);
106  desc.add<edm::ParameterSetDescription>("mva", mvaDesc);
107  descriptions.add(MVA::name(), desc);
108  }
109 
110 private:
111  void beginStream(edm::StreamID) final { mva.beginStream(); }
112 
113  void initEvent(const edm::EventSetup& es) final { mva.initEvent(es); }
114 
116  reco::BeamSpot const& beamSpot,
118  MVAPairCollection& mvas) const final {
120  computer(mva, tracks, beamSpot, vertices, mvas);
121  }
122 
123  MVA mva;
124 };
125 
126 #endif // RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
TrackMVAClassifierBase(const edm::ParameterSet &cfg)
std::vector< std::pair< float, bool > > MVAPairCollection
virtual void initEvent(const edm::EventSetup &es)=0
edm::EDGetTokenT< reco::TrackCollection > src_
source collection label
virtual void computeMVA(reco::TrackCollection const &tracks, reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices, MVAPairCollection &mvas) const =0
std::vector< Track > TrackCollection
collection of Tracks
Definition: TrackFwd.h:14
static void fill(edm::ParameterSetDescription &desc)
std::vector< Vertex > VertexCollection
collection of Vertex objects
Definition: VertexFwd.h:9
std::vector< float > MVACollection
TEMPL(T2) struct Divides void
Definition: Factorize.h:24
void operator()(MVA const &mva, reco::TrackCollection const &tracks, reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices, TrackMVAClassifierBase::MVAPairCollection &mvas)
void computeMVA(reco::TrackCollection const &tracks, reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices, MVAPairCollection &mvas) const final
edm::EDGetTokenT< reco::BeamSpot > beamspot_
void produce(edm::Event &evt, const edm::EventSetup &es) final
void beginStream(edm::StreamID) final
edm::EDGetTokenT< reco::VertexCollection > vertices_
TrackMVAClassifier(const edm::ParameterSet &cfg)
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
void operator()(MVA const &mva, reco::TrackCollection const &tracks, reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices, TrackMVAClassifierBase::MVAPairCollection &mvas)
void initEvent(const edm::EventSetup &es) final
void add(std::string const &label, ParameterSetDescription const &psetDescription)
HLT enums.
def cache(function)
Definition: utilities.py:3
std::vector< unsigned char > QualityMaskCollection