CMS 3D CMS Logo

TrackMVAClassifier.h
Go to the documentation of this file.
1 #ifndef RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
2 #define RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
3 
7 
13 
15 
17 
18 #include <vector>
19 #include <memory>
20 
22 public:
24  ~TrackMVAClassifierBase() override;
25 
26  using MVACollection = std::vector<float>;
27  using QualityMaskCollection = std::vector<unsigned char>;
28 
29  //Collection with pairs <MVAOutput, isReliable>
30  using MVAPairCollection = std::vector<std::pair<float, bool>>;
31 
32 protected:
33  static void fill(edm::ParameterSetDescription& desc);
34 
35  virtual void initEvent(const edm::EventSetup& es) = 0;
36 
37  virtual void computeMVA(reco::TrackCollection const& tracks,
38  reco::BeamSpot const& beamSpot,
40  MVAPairCollection& mvas) const = 0;
41 
42 private:
43  void produce(edm::Event& evt, const edm::EventSetup& es) final;
44 
49 
51 
52  // MVA
53 
54  // qualitycuts (loose, tight, hp)
55  float qualityCuts[3];
56 };
57 
// Primary template: evaluates an MVA whose call operator also takes a
// per-call EventCache object. A single EventCache is created at the start of
// each call and passed to every mva(...) invocation, presumably so the MVA can
// share work between tracks of the same event — TODO confirm with MVA impls.
59  template <typename EventCache>
60  struct ComputeMVA {
61  template <typename MVA>
// Full parameter list (per the declaration summary):
// (mva, tracks, beamSpot, vertices, mvas), where mvas is a
// TrackMVAClassifierBase::MVAPairCollection, i.e. std::vector<std::pair<float, bool>>.
// Some signature lines are elided in this listing.
62  void operator()(MVA const& mva,
64  reco::BeamSpot const& beamSpot,
// One cache object per call, shared by all tracks processed below.
67  EventCache cache;
68 
// Results are written by index, starting at 0, in track order.
// NOTE(review): no resize of mvas is visible here; this appears to rely on
// the caller having sized mvas to at least tracks.size() — confirm.
69  size_t current = 0;
70  for (auto const& trk : tracks) {
// The MVA's return value is assigned straight into a std::pair<float, bool>
// (<MVAOutput, isReliable>), so here the MVA itself supplies the reliability flag.
71  mvas[current++] = mva(trk, beamSpot, vertices, cache);
72  }
73  }
74  };
75 
// Explicit specialization for MVAs that need no event cache (EventCache = void).
// The MVA is called with (track, beamSpot, vertices) only, and its result
// becomes the first member of a std::pair<float, bool> whose reliability
// flag is always true.
76  template <>
77  struct ComputeMVA<void> {
78  template <typename MVA>
// Full parameter list (per the declaration summary):
// (mva, tracks, beamSpot, vertices, mvas), where mvas is a
// TrackMVAClassifierBase::MVAPairCollection. Some signature lines are elided here.
79  void operator()(MVA const& mva,
81  reco::BeamSpot const& beamSpot,
// Results are written by index, starting at 0, in track order.
// NOTE(review): no resize of mvas is visible here; this appears to rely on
// the caller having sized mvas to at least tracks.size() — confirm.
84  size_t current = 0;
85  for (auto const& trk : tracks) {
86  //BDT outputs are considered always reliable. Hence "true"
87  std::pair<float, bool> output(mva(trk, beamSpot, vertices), true);
88  mvas[current++] = output;
89  }
90  }
91  };
92 } // namespace trackMVAClassifierImpl
93 
94 template <typename MVA, typename EventCache = void>
96 public:
98  : TrackMVAClassifierBase(cfg), mva(cfg.getParameter<edm::ParameterSet>("mva")) {}
99 
102  fill(desc);
104  MVA::fillDescriptions(mvaDesc);
105  desc.add<edm::ParameterSetDescription>("mva", mvaDesc);
106  descriptions.add(MVA::name(), desc);
107  }
108 
109 private:
110  void beginStream(edm::StreamID) final { mva.beginStream(); }
111 
112  void initEvent(const edm::EventSetup& es) final { mva.initEvent(es); }
113 
115  reco::BeamSpot const& beamSpot,
117  MVAPairCollection& mvas) const final {
119  computer(mva, tracks, beamSpot, vertices, mvas);
120  }
121 
122  MVA mva;
123 };
124 
125 #endif // RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
void computeMVA(reco::TrackCollection const &tracks, reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices, MVAPairCollection &mvas) const final
std::vector< std::pair< float, bool >> MVAPairCollection
TrackMVAClassifierBase(const edm::ParameterSet &cfg)
virtual void initEvent(const edm::EventSetup &es)=0
edm::EDGetTokenT< reco::TrackCollection > src_
source collection label
std::vector< Track > TrackCollection
collection of Tracks
Definition: TrackFwd.h:14
static void fill(edm::ParameterSetDescription &desc)
std::vector< Vertex > VertexCollection
collection of Vertex objects
Definition: VertexFwd.h:9
std::vector< float > MVACollection
void operator()(MVA const &mva, reco::TrackCollection const &tracks, reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices, TrackMVAClassifierBase::MVAPairCollection &mvas)
edm::EDGetTokenT< reco::BeamSpot > beamspot_
void produce(edm::Event &evt, const edm::EventSetup &es) final
void beginStream(edm::StreamID) final
edm::EDGetTokenT< reco::VertexCollection > vertices_
ParameterDescriptionBase * add(U const &iLabel, T const &value)
TrackMVAClassifier(const edm::ParameterSet &cfg)
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
void operator()(MVA const &mva, reco::TrackCollection const &tracks, reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices, TrackMVAClassifierBase::MVAPairCollection &mvas)
void initEvent(const edm::EventSetup &es) final
void add(std::string const &label, ParameterSetDescription const &psetDescription)
HLT enums.
def cache(function)
Definition: utilities.py:3
virtual void computeMVA(reco::TrackCollection const &tracks, reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices, MVAPairCollection &mvas) const =0
std::vector< unsigned char > QualityMaskCollection