CMS 3D CMS Logo

MuonMvaIDEstimator.cc
Go to the documentation of this file.
9 
10 using namespace pat;
11 using namespace cms::Ort;
12 
14  randomForest_ = std::make_unique<ONNXRuntime>(weightsfile.fullPath());
15  LogDebug("MuonMvaIDEstimator") << randomForest_.get();
16 }
17 
20  desc.add<edm::FileInPath>("mvaIDTrainingFile", edm::FileInPath("RecoMuon/MuonIdentification/data/mvaID.onnx"));
21  desc.add<std::vector<std::string>>("flav_names",
22  std::vector<std::string>{
23  "probBAD",
24  "probGOOD",
25  });
26 
27  descriptions.addWithDefaultLabel(desc);
28 }
29 
32 std::vector<float> MuonMvaIDEstimator::computeMVAID(const pat::Muon &muon) const {
33  const float local_chi2 = muon.combinedQuality().chi2LocalPosition;
34  const float kink = muon.combinedQuality().trkKink;
35  const float segment_comp = muon.segmentCompatibility(arbitrationType);
36  const float n_MatchedStations = muon.numberOfMatchedStations();
37  const float pt = muon.pt();
38  const float eta = muon.eta();
39  const float global_muon = muon.isGlobalMuon();
40  float Valid_pixel;
41  float tracker_layers;
42  float validFraction;
43  if (muon.innerTrack().isNonnull()) {
44  Valid_pixel = muon.innerTrack()->hitPattern().numberOfValidPixelHits();
45  tracker_layers = muon.innerTrack()->hitPattern().trackerLayersWithMeasurement();
46  validFraction = muon.innerTrack()->validFraction();
47  } else {
48  Valid_pixel = -99.;
49  tracker_layers = -99.0;
50  validFraction = -99.0;
51  }
52  float norm_chi2;
53  float n_Valid_hits;
54  if (muon.globalTrack().isNonnull()) {
55  norm_chi2 = muon.globalTrack()->normalizedChi2();
56  n_Valid_hits = muon.globalTrack()->hitPattern().numberOfValidMuonHits();
57  } else if (muon.innerTrack().isNonnull()) {
58  norm_chi2 = muon.innerTrack()->normalizedChi2();
59  n_Valid_hits = muon.innerTrack()->hitPattern().numberOfValidMuonHits();
60  } else {
61  norm_chi2 = -99;
62  n_Valid_hits = -99;
63  }
64  const std::vector<std::string> input_names_{"float_input"};
65  std::vector<float> vars = {global_muon,
66  validFraction,
67  norm_chi2,
68  local_chi2,
69  kink,
70  segment_comp,
71  n_Valid_hits,
72  n_MatchedStations,
73  Valid_pixel,
74  tracker_layers,
75  pt,
76  eta};
77  const std::vector<std::string> flav_names_{"probBAD", "probGOOD"};
78  cms::Ort::FloatArrays input_values_;
79  input_values_.emplace_back(vars);
80  std::vector<float> outputs;
81  LogDebug("MuonMvaIDEstimator") << randomForest_.get();
82  outputs = randomForest_->run(input_names_, input_values_, {}, {"probabilities"})[0];
83  assert(outputs.size() == flav_names_.size());
84  return outputs;
85 }
MuonMvaIDEstimator(const edm::FileInPath &weightsfile)
void addWithDefaultLabel(ParameterSetDescription const &psetDescription)
std::string fullPath() const
Definition: FileInPath.cc:161
static void globalEndJob(const cms::Ort::ONNXRuntime *)
std::vector< std::vector< float > > FloatArrays
Definition: ONNXRuntime.h:23
static void fillDescriptions(edm::ConfigurationDescriptions &)
assert(be >=bs)
std::vector< float > computeMVAID(const pat::Muon &imuon) const
Definition: HeavyIon.h:7
ArbitrationType
define arbitration schemes
Definition: Muon.h:187
const reco::Muon::ArbitrationType arbitrationType
def cache(function)
Definition: utilities.py:3
vars
Definition: DeepTauId.cc:166
Analysis-level muon class.
Definition: Muon.h:51
#define LogDebug(id)