CMS 3D CMS Logo

TrackMVAClassifier.h
Go to the documentation of this file.
1 #ifndef RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
2 #define RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
3 
4 
8 
9 
15 
17 
19 
20 #include <vector>
21 #include <memory>
22 
// Abstract base for track MVA classifier producers. Derived classes supply
// initEvent() and computeMVA(); produce() is final and private here.
// NOTE(review): the `class TrackMVAClassifierBase : ...` header (original line 23)
// is missing from this rendering. The destructor is marked `override` and produce()
// is marked `final`, so the class derives from a base with a virtual destructor
// and a virtual produce(). Confirm which framework base that is.
24 public:
25  explicit TrackMVAClassifierBase( const edm::ParameterSet & cfg );
26  ~TrackMVAClassifierBase() override;
27 
28  using MVACollection = std::vector<float>;
29  using QualityMaskCollection = std::vector<unsigned char>;
30 
31  //Collection with pairs <MVAOutput, isReliable>
32  using MVAPairCollection = std::vector<std::pair<float, bool>>;
33 
34 protected:
35 
// Adds this base class's parameters to `desc`. Derived fillDescriptions() calls it.
36  static void fill( edm::ParameterSetDescription& desc);
37 
38 
// Per-event hook receiving the EventSetup; must be implemented by derived classes.
// NOTE(review): presumably invoked from produce() before computeMVA(). The produce()
// body is not shown; confirm.
39  virtual void initEvent(const edm::EventSetup& es) = 0;
40 
// Writes one (MVA output, isReliable) pair per track into `mvas` in the visible
// implementations.
// NOTE(review): original line 43 is missing here. The tooltip gives the full signature as
// (tracks, beamSpot, reco::VertexCollection const &vertices, mvas) const = 0.
41  virtual void computeMVA(reco::TrackCollection const & tracks,
42  reco::BeamSpot const & beamSpot,
44  MVAPairCollection & mvas) const = 0;
45 
46 private:
// final: derived classes cannot override produce(); they customize behavior
// through initEvent()/computeMVA().
47  void produce(edm::Event& evt, const edm::EventSetup& es ) final;
48 
// NOTE(review): original lines 49-54 are missing. The page tooltips list
// edm::EDGetTokenT members src_ ("source collection label"), beamspot_ and vertices_,
// presumably declared here. Confirm against the real header.
53 
55 
56  // MVA
57 
58  // qualitycuts (loose, tight, hp)
// NOTE(review): no in-class initializer is visible. Presumably this is filled in the
// constructor; confirm.
59  float qualityCuts[3];
60 
61 };
62 
// Evaluates `mva` on every track, for MVA types that take a per-event cache.
// One default-constructed EventCache is created per call, and that same object is
// passed to each mva(trk, beamSpot, vertices, cache) invocation.
// `mvas` is written through operator[] (not appended), so it must already hold at
// least tracks.size() elements.
// NOTE(review): mva(...) presumably returns a std::pair<float,bool> (MVA output,
// isReliable) assignable to an MVAPairCollection element. Confirm.
// NOTE(review): original lines 68 and 70-71 are missing from this rendering. The page
// tooltip gives the full signature as (MVA const &mva, reco::TrackCollection const &tracks,
// reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices,
// TrackMVAClassifierBase::MVAPairCollection &mvas).
64  template<typename EventCache>
65  struct ComputeMVA {
66  template <typename MVA>
67  void operator()(MVA const & mva,
69  reco::BeamSpot const & beamSpot,
72 
73  EventCache cache;
74 
75  size_t current = 0;
76  for (auto const & trk : tracks) {
77  mvas[current++] = mva(trk,beamSpot,vertices,cache);
78  }
79  }
80  };
81 
// Specialization for MVA types without a per-event cache (EventCache = void).
// For each track it calls mva(trk, beamSpot, vertices) and stores the result paired
// with isReliable = true.
// `mvas` is written through operator[] (not appended), so it must already hold at
// least tracks.size() elements.
// NOTE(review): original lines 86 and 88-89 are missing from this rendering. The page
// tooltip gives the full signature as (mva, tracks, beamSpot, vertices, mvas).
82  template <>
83  struct ComputeMVA<void> {
84  template <typename MVA>
85  void operator()(MVA const & mva,
87  reco::BeamSpot const & beamSpot,
90 
91  size_t current = 0;
92  for (auto const & trk : tracks) {
93  //BDT outputs are considered always reliable. Hence "true"
94  std::pair<float,bool> output (mva(trk,beamSpot,vertices), true);
95  mvas[current++]= output;
96  }
97  }
98  };
99 }
100 
// Concrete classifier that owns an `mva` object of type MVA and forwards the
// TrackMVAClassifierBase hooks to it. EventCache defaults to void (no per-event cache).
// NOTE(review): original lines 102, 104-105, 108-109, 111, 119, 127, 129, 132-133 and 136
// are missing from this rendering. The tooltips list TrackMVAClassifier(const edm::ParameterSet &cfg),
// static fillDescriptions(edm::ConfigurationDescriptions &), beginStream(edm::StreamID) final
// and computeMVA(tracks, beamSpot, vertices, mvas) const final.
101 template<typename MVA, typename EventCache=void>
103 public:
// The wrapped `mva` member is constructed from the nested "mva" ParameterSet.
106  mva(cfg.getParameter<edm::ParameterSet>("mva")){}
107 
// fillDescriptions: adds the base parameters via fill(desc), nests the MVA's own
// description under "mva", and registers the result under the label MVA::name().
110  fill(desc);
112  MVA::fillDescriptions(mvaDesc);
113  desc.add<edm::ParameterSetDescription>("mva",mvaDesc);
114  descriptions.add(MVA::name(), desc);
115  }
116 
117 
118 private:
// beginStream: forwards to the wrapped MVA.
120  mva.beginStream();
121  }
122 
// Forwards the per-event EventSetup to the wrapped MVA.
123  void initEvent(const edm::EventSetup& es) final {
124  mva.initEvent(es);
125  }
126 
// computeMVA override.
// NOTE(review): its body (original lines 132-133) is not shown. It presumably dispatches
// to ComputeMVA<EventCache>; confirm.
128  reco::BeamSpot const & beamSpot,
130  MVAPairCollection & mvas) const final {
131 
134  }
135 
// NOTE(review): the declaration of the `mva` member (presumably original line 136) is not shown.
137 };
138 
139 
140 
141 #endif // RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
142 
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
void computeMVA(reco::TrackCollection const &tracks, reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices, MVAPairCollection &mvas) const final
std::vector< std::pair< float, bool >> MVAPairCollection
TrackMVAClassifierBase(const edm::ParameterSet &cfg)
virtual void initEvent(const edm::EventSetup &es)=0
edm::EDGetTokenT< reco::TrackCollection > src_
source collection label
std::vector< Track > TrackCollection
collection of Tracks
Definition: TrackFwd.h:15
static void fill(edm::ParameterSetDescription &desc)
std::vector< Vertex > VertexCollection
collection of Vertex objects
Definition: VertexFwd.h:9
std::vector< float > MVACollection
void operator()(MVA const &mva, reco::TrackCollection const &tracks, reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices, TrackMVAClassifierBase::MVAPairCollection &mvas)
edm::EDGetTokenT< reco::BeamSpot > beamspot_
void produce(edm::Event &evt, const edm::EventSetup &es) final
void beginStream(edm::StreamID) final
edm::EDGetTokenT< reco::VertexCollection > vertices_
ParameterDescriptionBase * add(U const &iLabel, T const &value)
TrackMVAClassifier(const edm::ParameterSet &cfg)
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
void operator()(MVA const &mva, reco::TrackCollection const &tracks, reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices, TrackMVAClassifierBase::MVAPairCollection &mvas)
void initEvent(const edm::EventSetup &es) final
void add(std::string const &label, ParameterSetDescription const &psetDescription)
HLT enums.
def cache(function)
Definition: utilities.py:3
virtual void computeMVA(reco::TrackCollection const &tracks, reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices, MVAPairCollection &mvas) const =0
std::vector< unsigned char > QualityMaskCollection