CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
TrackMVAClassifier.h
Go to the documentation of this file.
1 #ifndef RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
2 #define RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
3 
4 
8 
9 
15 
17 
19 
20 #include <vector>
21 #include <memory>
22 
24 public:
25  explicit TrackMVAClassifierBase( const edm::ParameterSet & cfg );
27 protected:
28 
29  static void fill( edm::ParameterSetDescription& desc);
30 
31 
32  using MVACollection = std::vector<float>;
33  using QualityMaskCollection = std::vector<unsigned char>;
34 
35  virtual void computeMVA(reco::TrackCollection const & tracks,
36  reco::BeamSpot const & beamSpot,
38  GBRForest const * forestP,
39  MVACollection & mvas) const = 0;
40 
41 
42 private:
43 
44  void beginStream(edm::StreamID) override final;
45 
46  void produce(edm::Event& evt, const edm::EventSetup& es ) override final;
47 
49  edm::EDGetTokenT<reco::TrackCollection> src_;
50  edm::EDGetTokenT<reco::BeamSpot> beamspot_;
51  edm::EDGetTokenT<reco::VertexCollection> vertices_;
52 
53 
54  // MVA
55  std::unique_ptr<GBRForest> forest_;
56  const std::string forestLabel_;
57  const std::string dbFileName_;
59 
60  // qualitycuts (loose, tight, hp)
61  float qualityCuts[3];
62 
63 };
64 
65 template<typename MVA>
67 public:
68  explicit TrackMVAClassifier( const edm::ParameterSet & cfg ) :
69  TrackMVAClassifierBase(cfg),
70  mva(cfg.getParameter<edm::ParameterSet>("mva")){}
71 
72  static void fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
74  fill(desc);
76  MVA::fillDescriptions(mvaDesc);
77  desc.add<edm::ParameterSetDescription>("mva",mvaDesc);
78  descriptions.add(MVA::name(), desc);
79  }
80 
81 
82 private:
84  reco::BeamSpot const & beamSpot,
86  GBRForest const * forestP,
87  MVACollection & mvas) const final {
88 
89  size_t current = 0;
90  for (auto const & trk : tracks) {
91  mvas[current++]= mva(trk,beamSpot,vertices,forestP);
92  }
93  }
94 
95  MVA mva;
96 };
97 
98 
99 
100 #endif // RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
101 
TrackMVAClassifierBase(const edm::ParameterSet &cfg)
tuple cfg
Definition: looper.py:293
virtual void computeMVA(reco::TrackCollection const &tracks, reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices, GBRForest const *forestP, MVACollection &mvas) const =0
edm::EDGetTokenT< reco::TrackCollection > src_
source collection label
const std::string forestLabel_
const std::string dbFileName_
std::vector< Track > TrackCollection
collection of Tracks
Definition: TrackFwd.h:14
static void fill(edm::ParameterSetDescription &desc)
std::vector< Vertex > VertexCollection
collection of Vertex objects
Definition: VertexFwd.h:9
std::vector< float > MVACollection
Definition: Event.h:16
void produce(edm::Event &evt, const edm::EventSetup &es) override final
void beginStream(edm::StreamID) override final
edm::EDGetTokenT< reco::BeamSpot > beamspot_
edm::EDGetTokenT< reco::VertexCollection > vertices_
std::unique_ptr< GBRForest > forest_
ParameterDescriptionBase * add(U const &iLabel, T const &value)
void computeMVA(reco::TrackCollection const &tracks, reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices, GBRForest const *forestP, MVACollection &mvas) const final
tuple tracks
Definition: testEve_cfg.py:39
string const
Definition: compareJSON.py:14
void add(std::string const &label, ParameterSetDescription const &psetDescription)
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
std::vector< unsigned char > QualityMaskCollection
TrackMVAClassifier(const edm::ParameterSet &cfg)
def template
Definition: svgfig.py:520