CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
MTDTrackQualityMVA.cc
Go to the documentation of this file.
2 
// Constructor body: configures the TMVA-based evaluator stored in mva_.
// NOTE(review): the opening line of the constructor and the definition of
// vars_array (listing lines 3 and 7) are missing from this extract; confirm
// against the full file.
// TMVA option string and method name handed to the evaluator below.
4  std::string options("!Color:Silent");
5  std::string method("BDT");
6 
// Element count of the vars_array C array, then copy all of its entries into vars_.
8  int nvars = sizeof(vars_array) / sizeof(vars_array[0]);
9  vars_.assign(vars_array, vars_array + nvars);
10 
// Create the evaluator and initialize it from weights_file with the input
// variable names (vars_) and the spectator names (spec_vars_).
// NOTE(review): the two trailing booleans presumably select the GBR forest and
// GradBoost handling, per the inline comment — confirm against TMVAEvaluator::initialize.
11  mva_ = std::make_unique<TMVAEvaluator>();
12  mva_->initialize(options, method, weights_file, vars_, spec_vars_, true, false); //use GBR, GradBoost
13 }
14 
// Call operator: evaluates the track-quality MVA for one track.
// NOTE(review): the first line of the signature (listing line 15, declaring
// `trk`) is missing from this extract; confirm against the full file.
//
// Each ValueMap argument is indexed by the track reference `trk`.
// Returns -1 when the track is below the pT threshold or when tmtds[trk] is not
// positive; otherwise returns the transformed MVA response computed at the end.
16  const edm::ValueMap<int>& npixBarrels,
17  const edm::ValueMap<int>& npixEndcaps,
18  const edm::ValueMap<float>& btl_chi2s,
19  const edm::ValueMap<float>& btl_time_chi2s,
20  const edm::ValueMap<float>& etl_chi2s,
21  const edm::ValueMap<float>& etl_time_chi2s,
22  const edm::ValueMap<float>& tmtds,
23  const edm::ValueMap<float>& trk_lengths) const {
// Input values for the evaluator, keyed by the variable names held in vars_.
24  std::map<std::string, float> vars;
25 
26  //---training performed only above 0.5 GeV
27  constexpr float minPtForMVA = 0.5;
28  if (trk->pt() < minPtForMVA)
29  return -1;
30 
31  //---training performed only for tracks with MTD hits
// A positive tmtds[trk] is the condition used here to decide the track is evaluated.
32  if (tmtds[trk] > 0) {
// Track quantities read directly from the track object.
33  vars.emplace(vars_[int(VarID::pt)], trk->pt());
34  vars.emplace(vars_[int(VarID::eta)], trk->eta());
35  vars.emplace(vars_[int(VarID::phi)], trk->phi());
36  vars.emplace(vars_[int(VarID::chi2)], trk->chi2());
37  vars.emplace(vars_[int(VarID::ndof)], trk->ndof());
38  vars.emplace(vars_[int(VarID::numberOfValidHits)], trk->numberOfValidHits());
// Pixel hit counts come from the caller-supplied maps rather than from the track.
39  vars.emplace(vars_[int(VarID::numberOfValidPixelBarrelHits)], npixBarrels[trk]);
40  vars.emplace(vars_[int(VarID::numberOfValidPixelEndcapHits)], npixEndcaps[trk]);
// BTL/ETL matching chi2 values: use -1 when the map does not contain the
// track's product ID, so the map is only indexed when the lookup is valid.
41  vars.emplace(vars_[int(VarID::btlMatchChi2)], btl_chi2s.contains(trk.id()) ? btl_chi2s[trk] : -1);
42  vars.emplace(vars_[int(VarID::btlMatchTimeChi2)], btl_time_chi2s.contains(trk.id()) ? btl_time_chi2s[trk] : -1);
43  vars.emplace(vars_[int(VarID::etlMatchChi2)], etl_chi2s.contains(trk.id()) ? etl_chi2s[trk] : -1);
44  vars.emplace(vars_[int(VarID::etlMatchTimeChi2)], etl_time_chi2s.contains(trk.id()) ? etl_time_chi2s[trk] : -1);
45  vars.emplace(vars_[int(VarID::mtdt)], tmtds[trk]);
46  vars.emplace(vars_[int(VarID::path_len)], trk_lengths[trk]);
// With x the raw evaluator response, this returns 1 / (1 + sqrt(2 / (1 + x) - 1)):
// x = 1 gives 1, and the value tends to 0 as x approaches -1.
// NOTE(review): assumes the raw response lies in [-1, 1]; x > 1 would make the
// sqrt argument negative — confirm the evaluator's output range.
47  return 1. / (1 + sqrt(2 / (1 + mva_->evaluate(vars, false)) - 1)); //return values between 0-1 (probability)
48  } else
49  return -1;
50 }
#define MTDBDTVAR_STRING(STRING)
bool contains(ProductID id) const
Definition: ValueMap.h:155
float operator()(const reco::TrackRef &trk, const edm::ValueMap< int > &npixBarrels, const edm::ValueMap< int > &npixEndcaps, const edm::ValueMap< float > &btl_chi2s, const edm::ValueMap< float > &btl_time_chi2s, const edm::ValueMap< float > &etl_chi2s, const edm::ValueMap< float > &etl_time_chi2s, const edm::ValueMap< float > &tmtds, const edm::ValueMap< float > &trk_lengths) const
#define MTDTRACKQUALITYMVA_VARS(MTDBDTVAR)
T sqrt(T t)
Definition: SSEVec.h:19
std::vector< std::string > spec_vars_
MTDTrackQualityMVA(std::string weights_file)
std::vector< std::string > vars_
vars
Definition: DeepTauId.cc:164
std::unique_ptr< TMVAEvaluator > mva_