CMS 3D CMS Logo

TrackMVAClassifier.h
Go to the documentation of this file.
1 #ifndef RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
2 #define RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
3 
7 
13 
15 
17 
18 #include <vector>
19 #include <memory>
20 
22 public:
24  ~TrackMVAClassifierBase() override;
25 
26  using MVACollection = std::vector<float>;
27  using QualityMaskCollection = std::vector<unsigned char>;
28 
29  //Collection with pairs <MVAOutput, isReliable>
30  using MVAPairCollection = std::vector<std::pair<float, bool>>;
31 
32 protected:
33  static void fill(edm::ParameterSetDescription& desc);
34 
35  virtual void initEvent(const edm::EventSetup& es) = 0;
36 
37  virtual void computeMVA(reco::TrackCollection const& tracks,
38  reco::BeamSpot const& beamSpot,
40  MVAPairCollection& mvas) const = 0;
41 
42 private:
43  void produce(edm::Event& evt, const edm::EventSetup& es) final;
44 
49 
51 
52  // MVA
53 
54  // qualitycuts (loose, tight, hp)
55  float qualityCuts[3];
56 };
57 
59  template <typename EventCache>
60  struct ComputeMVA {
61  template <typename MVA>
62  void operator()(MVA const& mva,
64  reco::BeamSpot const& beamSpot,
67  EventCache cache;
68 
69  size_t current = 0;
70  for (auto const& trk : tracks) {
71  mvas[current++] = mva(trk, beamSpot, vertices, cache);
72  }
73  }
74  };
75 
76  template <>
77  struct ComputeMVA<void> {
78  template <typename MVA>
79  void operator()(MVA const& mva,
81  reco::BeamSpot const& beamSpot,
84  size_t current = 0;
85  for (auto const& trk : tracks) {
86  //BDT outputs are considered always reliable. Hence "true"
87  std::pair<float, bool> output(mva(trk, beamSpot, vertices), true);
88  mvas[current++] = output;
89  }
90  }
91  };
92 } // namespace trackMVAClassifierImpl
93 
94 template <typename MVA, typename EventCache = void>
96 public:
98  : TrackMVAClassifierBase(cfg), mva(cfg.getParameter<edm::ParameterSet>("mva")) {}
99 
102  fill(desc);
104  MVA::fillDescriptions(mvaDesc);
105  desc.add<edm::ParameterSetDescription>("mva", mvaDesc);
106  descriptions.add(MVA::name(), desc);
107  }
108 
109 private:
110  void beginStream(edm::StreamID) final { mva.beginStream(); }
111 
112  void initEvent(const edm::EventSetup& es) final { mva.initEvent(es); }
113 
115  reco::BeamSpot const& beamSpot,
117  MVAPairCollection& mvas) const final {
119  computer(mva, tracks, beamSpot, vertices, mvas);
120  }
121 
122  MVA mva;
123 };
124 
125 #endif // RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
ConfigurationDescriptions.h
edm::StreamID
Definition: StreamID.h:30
PDWG_EXOHSCP_cff.tracks
tracks
Definition: PDWG_EXOHSCP_cff.py:28
edm::ParameterSetDescription::add
ParameterDescriptionBase * add(U const &iLabel, T const &value)
Definition: ParameterSetDescription.h:95
pwdgSkimBPark_cfi.beamSpot
beamSpot
Definition: pwdgSkimBPark_cfi.py:5
trackMVAClassifierImpl::ComputeMVA::operator()
void operator()(MVA const &mva, reco::TrackCollection const &tracks, reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices, TrackMVAClassifierBase::MVAPairCollection &mvas)
Definition: TrackMVAClassifier.h:62
TrackMVAClassifier::TrackMVAClassifier
TrackMVAClassifier(const edm::ParameterSet &cfg)
Definition: TrackMVAClassifier.h:97
convertSQLitetoXML_cfg.output
output
Definition: convertSQLitetoXML_cfg.py:32
edm::EDGetTokenT< reco::TrackCollection >
edm
HLT enums.
Definition: AlignableModifier.h:19
reco::VertexCollection
std::vector< Vertex > VertexCollection
collection of Vertex objects
Definition: VertexFwd.h:9
edm::ParameterSetDescription
Definition: ParameterSetDescription.h:52
EDProducer.h
TrackMVAClassifierBase
Definition: TrackMVAClassifier.h:21
beam_dqm_sourceclient-live_cfg.mva
mva
Definition: beam_dqm_sourceclient-live_cfg.py:116
GBRForest.h
TrackMVAClassifier::initEvent
void initEvent(const edm::EventSetup &es) final
Definition: TrackMVAClassifier.h:112
TrackMVAClassifierBase::QualityMaskCollection
std::vector< unsigned char > QualityMaskCollection
Definition: TrackMVAClassifier.h:27
trackMVAClassifierImpl::ComputeMVA< void >::operator()
void operator()(MVA const &mva, reco::TrackCollection const &tracks, reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices, TrackMVAClassifierBase::MVAPairCollection &mvas)
Definition: TrackMVAClassifier.h:79
TrackMVAClassifier
Definition: TrackMVAClassifier.h:95
TrackMVAClassifierBase::ignoreVertices_
bool ignoreVertices_
Definition: TrackMVAClassifier.h:50
TrackFwd.h
BeamSpot.h
edm::ConfigurationDescriptions::add
void add(std::string const &label, ParameterSetDescription const &psetDescription)
Definition: ConfigurationDescriptions.cc:57
fillDescriptions
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
reco::BeamSpot
Definition: BeamSpot.h:21
TrackMVAClassifierBase::initEvent
virtual void initEvent(const edm::EventSetup &es)=0
TrackMVAClassifierBase::qualityCuts
float qualityCuts[3]
Definition: TrackMVAClassifier.h:55
ParameterSetDescription.h
utilities.cache
def cache(function)
Definition: utilities.py:3
TrackMVAClassifier::fillDescriptions
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
Definition: TrackMVAClassifier.h:100
edm::ConfigurationDescriptions
Definition: ConfigurationDescriptions.h:28
TrackMVAClassifierBase::TrackMVAClassifierBase
TrackMVAClassifierBase(const edm::ParameterSet &cfg)
Definition: TrackMVAClassifierBase.cc:22
edm::ParameterSet
Definition: ParameterSet.h:36
Event.h
ParameterSet
Definition: Functions.h:16
TrackMVAClassifierBase::fill
static void fill(edm::ParameterSetDescription &desc)
Definition: TrackMVAClassifierBase.cc:10
TrackMVAClassifier::mva
MVA mva
Definition: TrackMVAClassifier.h:122
edm::stream::EDProducer
Definition: EDProducer.h:38
edm::EventSetup
Definition: EventSetup.h:57
trackMVAClassifierImpl
Definition: TrackMVAClassifier.h:58
InputTag.h
looper.cfg
cfg
Definition: looper.py:297
TrackMVAClassifierBase::MVAPairCollection
std::vector< std::pair< float, bool > > MVAPairCollection
Definition: TrackMVAClassifier.h:30
VertexFwd.h
TrackMVAClassifierBase::computeMVA
virtual void computeMVA(reco::TrackCollection const &tracks, reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices, MVAPairCollection &mvas) const =0
TrackMVAClassifierBase::MVACollection
std::vector< float > MVACollection
Definition: TrackMVAClassifier.h:26
TrackMVAClassifierBase::beamspot_
edm::EDGetTokenT< reco::BeamSpot > beamspot_
Definition: TrackMVAClassifier.h:47
TrackMVAClassifierBase::src_
edm::EDGetTokenT< reco::TrackCollection > src_
source collection label
Definition: TrackMVAClassifier.h:46
TrackMVAClassifierBase::~TrackMVAClassifierBase
~TrackMVAClassifierBase() override
Definition: TrackMVAClassifierBase.cc:20
HLT_2018_cff.computer
computer
Definition: HLT_2018_cff.py:50520
Skims_PA_cff.name
name
Definition: Skims_PA_cff.py:17
funct::void
TEMPL(T2) struct Divides void
Definition: Factorize.h:29
ParameterSet.h
trackMVAClassifierImpl::ComputeMVA
Definition: TrackMVAClassifier.h:60
TrackMVAClassifierBase::produce
void produce(edm::Event &evt, const edm::EventSetup &es) final
Definition: TrackMVAClassifierBase.cc:35
TrackMVAClassifier::computeMVA
void computeMVA(reco::TrackCollection const &tracks, reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices, MVAPairCollection &mvas) const final
Definition: TrackMVAClassifier.h:114
TrackMVAClassifierBase::vertices_
edm::EDGetTokenT< reco::VertexCollection > vertices_
Definition: TrackMVAClassifier.h:48
edm::Event
Definition: Event.h:73
reco::TrackCollection
std::vector< Track > TrackCollection
collection of Tracks
Definition: TrackFwd.h:14
TrackMVAClassifier::beginStream
void beginStream(edm::StreamID) final
Definition: TrackMVAClassifier.h:110
pwdgSkimBPark_cfi.vertices
vertices
Definition: pwdgSkimBPark_cfi.py:7