CMS 3D CMS Logo

dataLoader.py
Go to the documentation of this file.
from __future__ import print_function
import itertools
import sys

import ROOT
# Make the BTag calibration classes available in ROOT: if ROOT does not know
# BTagEntry yet, compile and load BTagCalibrationStandalone.cpp from the
# working directory via ROOT's '.L <file>+' command.
try:
    ROOT.BTagEntry
except AttributeError:
    ROOT.gROOT.ProcessLine('.L BTagCalibrationStandalone.cpp+')

# Abort if BTagEntry is still unavailable after the load attempt.
# sys.exit is used instead of the exit() helper, which is only installed by
# the site module for interactive sessions.
try:
    ROOT.BTagEntry
except AttributeError:
    print('ROOT.BTagEntry is needed! Please copy '
          'BTagCalibrationStandalone.[h|cpp] to the working directory. Exit.',
          file=sys.stderr)
    sys.exit(-1)

# Module-level switches read by DataLoader.__init__ and get_data_csv: when
# true, loaders are built per operating point / per jet flavour instead of
# a single loader covering all of them.
separate_by_op = False
separate_by_flav = False
18 
19 
21  def __init__(self, csv_data, measurement_type, operating_point, flavour):
22  self.meas_type = measurement_type
23  self.op = operating_point
24  self.flav = flavour
25 
26  # list of entries
27  ens = []
28  for l in csv_data:
29  if not l.strip():
30  continue # skip empty lines
31  try:
32  e = ROOT.BTagEntry(l)
33  if (e.params.measurementType == measurement_type
34  and ((not separate_by_op)
35  or e.params.operatingPoint == operating_point)
36  and ((not separate_by_flav)
37  or e.params.jetFlavor == flavour)
38  ):
39  ens.append(e)
40  except TypeError:
41  raise RuntimeError("Error: can not interpret line: " + l)
42  self.entries = ens
43 
44  if not ens:
45  return
46 
47  # fixed data
48  self.ops = set(e.params.operatingPoint for e in ens)
49  self.flavs = set(e.params.jetFlavor for e in ens)
50  self.syss = set(e.params.sysType for e in ens)
51  self.etas = set((e.params.etaMin, e.params.etaMax) for e in ens)
52  self.pts = set((e.params.ptMin, e.params.ptMax) for e in ens)
53  self.discrs = set((e.params.discrMin, e.params.discrMax)
54  for e in ens
55  if e.params.operatingPoint == 3)
56 
57  self.ETA_MIN = -2.4
58  self.ETA_MAX = 2.4
59  self.PT_MIN = min(e.params.ptMin for e in ens)
60  self.PT_MAX = max(e.params.ptMax for e in ens)
61  if any(e.params.operatingPoint == 3 for e in ens):
62  self.DISCR_MIN = min(
63  e.params.discrMin
64  for e in ens
65  if e.params.operatingPoint == 3
66  )
67  self.DISCR_MAX = max(
68  e.params.discrMax
69  for e in ens
70  if e.params.operatingPoint == 3
71  )
72  else:
73  self.DISCR_MIN = 0.
74  self.DISCR_MAX = 1.
75 
76  # test points for variable data (using bound +- epsilon)
77  eps = 1e-4
78  eta_test_points = list(itertools.ifilter(
79  lambda x: self.ETA_MIN < x < self.ETA_MAX,
80  itertools.chain(
81  (a + eps for a, _ in self.etas),
82  (a - eps for a, _ in self.etas),
83  (b + eps for _, b in self.etas),
84  (b - eps for _, b in self.etas),
85  (self.ETA_MIN + eps, self.ETA_MAX - eps),
86  )
87  ))
88  abseta_test_points = list(itertools.ifilter(
89  lambda x: 0. < x < self.ETA_MAX,
90  itertools.chain(
91  (a + eps for a, _ in self.etas),
92  (a - eps for a, _ in self.etas),
93  (b + eps for _, b in self.etas),
94  (b - eps for _, b in self.etas),
95  (eps, self.ETA_MAX - eps),
96  )
97  ))
98  pt_test_points = list(itertools.ifilter(
99  lambda x: self.PT_MIN < x < self.PT_MAX,
100  itertools.chain(
101  (a + eps for a, _ in self.pts),
102  (a - eps for a, _ in self.pts),
103  (b + eps for _, b in self.pts),
104  (b - eps for _, b in self.pts),
105  (self.PT_MIN + eps, self.PT_MAX - eps),
106  )
107  ))
108  discr_test_points = list(itertools.ifilter(
109  lambda x: self.DISCR_MIN < x < self.DISCR_MAX,
110  itertools.chain(
111  (a + eps for a, _ in self.discrs),
112  (a - eps for a, _ in self.discrs),
113  (b + eps for _, b in self.discrs),
114  (b - eps for _, b in self.discrs),
115  (self.DISCR_MIN + eps, self.DISCR_MAX - eps),
116  )
117  ))
118  # use sets
119  self.eta_test_points = set(round(f, 5) for f in eta_test_points)
120  self.abseta_test_points = set(round(f, 5) for f in abseta_test_points)
121  self.pt_test_points = set(round(f, 5) for f in pt_test_points)
122  self.discr_test_points = set(round(f, 5) for f in discr_test_points)
123 
124  def print_data(self):
125  print("\nFound operating points:")
126  print(self.ops)
127 
128  print("\nFound jet flavors:")
129  print(self.flavs)
130 
131  print("\nFound sys types (need at least 'central', 'up', 'down'; " \
132  "also 'up_SYS'/'down_SYS' compatibility is checked):")
133  print(self.syss)
134 
135  print("\nFound eta ranges: (need everything covered from %g or 0. " \
136  "up to %g):" % (self.ETA_MIN, self.ETA_MAX))
137  print(self.etas)
138 
139  print("\nFound pt ranges: (need everything covered from %g " \
140  "to %g):" % (self.PT_MIN, self.PT_MAX))
141  print(self.pts)
142 
143  print("\nFound discr ranges: (only needed for operatingPoint==3, " \
144  "covered from %g to %g):" % (self.DISCR_MIN, self.DISCR_MAX))
145  print(self.discrs)
146 
147  print("\nTest points for eta (bounds +- epsilon):")
148  print(self.eta_test_points)
149 
150  print("\nTest points for pt (bounds +- epsilon):")
151  print(self.pt_test_points)
152 
153  print("\nTest points for discr (bounds +- epsilon):")
155  print("")
156 
157 
def get_data_csv(csv_data):
    """Build DataLoader objects from CSV lines.

    One loader is built per measurement type, and optionally per operating
    point and per jet flavour.

    The measurement types, operating points and flavours are taken from
    lines with exactly 11 comma-separated fields:

    - column 1 is the measurement type;
    - column 0 is the operating point;
    - column 3 is the jet flavour.

    Each line is split at most 10 times, so commas inside the last field
    cannot inflate the field count. Whitespace around the commas does not
    matter.

    Operating points and flavours are only separated when the module-level
    switches ``separate_by_op`` / ``separate_by_flav`` are true. Otherwise
    the placeholder ``'all'`` is passed, and DataLoader ignores it.

    Args:
        csv_data: list of CSV lines without the header.

    Returns:
        list of DataLoader objects that found at least one entry, in sorted
        (measurement type, operating point, flavour) order.
    """
    # parse every line once; keep only well-formed 11-field rows
    rows = [fields
            for fields in (line.split(',', 10) for line in csv_data)
            if len(fields) == 11]

    meas_types = sorted(set(f[1].strip() for f in rows))
    ops = (sorted(set(int(f[0]) for f in rows))
           if separate_by_op else ['all'])
    flavs = (sorted(set(int(f[3]) for f in rows))
             if separate_by_flav else ['all'])

    # make loaders and filter empty ones
    loaders = (
        DataLoader(csv_data, mt, op, fl)
        for mt in meas_types
        for op in ops
        for fl in flavs
    )
    return [d for d in loaders if d.entries]
189 
190 
def get_data(filename):
    """Load a BTag calibration CSV file and turn it into DataLoader objects.

    The first line of the file must be the header, i.e. it must contain
    "OperatingPoint". The header is discarded and the remaining lines are
    handed to get_data_csv.

    Args:
        filename: path of the CSV file to read.

    Returns:
        The list built by get_data_csv. Returns False, after printing a
        message, when the file is empty or lacks the expected header.
    """
    with open(filename) as csv_file:
        lines = csv_file.readlines()

    if lines and "OperatingPoint" in lines[0]:
        body = lines[1:]  # everything after the header row
        return get_data_csv(body)

    print("Data file does not contain typical header: %s. Exit" % filename)
    return False
def __init__(self, csv_data, measurement_type, operating_point, flavour)
Definition: dataLoader.py:21
bool any(const std::vector< T > &v, const T &what)
Definition: ECalSD.cc:37
def get_data_csv(csv_data)
Definition: dataLoader.py:158
void print(TMatrixD &m, const char *label=nullptr, bool mathematicaFormat=false)
Definition: Utilities.cc:47
def get_data(filename)
Definition: dataLoader.py:191
def exit(msg="")