CMS 3D CMS Logo

collectionMerger.py
Go to the documentation of this file.
2 from PhysicsTools.NanoAODTools.postprocessing.framework.datamodel import Collection
3 import ROOT
4 import numpy as np
5 import itertools
# Ask PyROOT not to interpret this process's command-line arguments as its own
# options (see the PyROOT PyConfig settings).
ROOT.PyConfig.IgnoreCommandLineOptions = True
7 
8 _rootLeafType2rootBranchType = {
9  'UChar_t': 'b',
10  'Char_t': 'B',
11  'UInt_t': 'i',
12  'Int_t': 'I',
13  'Float_t': 'F',
14  'Double_t': 'D',
15  'ULong64_t': 'l',
16  'Long64_t': 'L',
17  'Bool_t': 'O'
18 }
19 
20 
class collectionMerger(Module):
    """Merge several collections into a single sorted output collection.

    For every event, the objects of all ``input`` collections are gathered,
    optionally filtered with a per-collection selector, sorted with
    ``sortkey`` and optionally truncated to ``maxObjects``. Every branch
    found for any input collection is written as ``<output>_<branch>``
    (counted by ``n<output>``); objects that come from a collection lacking
    that branch are filled with ``0``.
    """

    def __init__(self,
                 input,
                 output,
                 sortkey=lambda x: x.pt,
                 reverse=True,
                 selector=None,
                 maxObjects=None):
        """
        Args:
            input: list of input collection names, e.g. ``["Electron", "Muon"]``.
            output: name of the merged output collection.
            sortkey: function of one object giving the sort key (default: ``pt``).
            reverse: forwarded to ``list.sort``; ``True`` sorts in descending order.
            selector: optional mapping ``{collection_name: predicate(obj)}``;
                collections absent from the mapping keep all their objects.
                An empty or missing mapping disables selection.
            maxObjects: if truthy, keep only the first ``maxObjects`` objects
                (after selection and sorting); ``None`` or ``0`` keeps all.
        """
        self.input = input
        self.output = output
        self.nInputs = len(self.input)
        # Objects travel as (obj, collection_index, object_index) tuples;
        # sort on the object itself.
        self.sortkey = lambda obj_j_i: sortkey(obj_j_i[0])
        self.reverse = reverse
        # One predicate per input collection, in input order (None = no selection).
        self.selector = [(selector[coll] if coll in selector else
                          (lambda x: True))
                         for coll in self.input] if selector else None
        # Save only the first maxObjects objects passing the selection.
        self.maxObjects = maxObjects
        # Leaf type name keyed by prefix-stripped branch name; filled by
        # filterBranchNames.
        self.branchType = {}

    def beginJob(self):
        pass

    def endJob(self):
        pass

    @staticmethod
    def _activeBranches(tree):
        """Return the branches of ``tree`` whose branch status is enabled."""
        brlist = tree.GetListOfBranches()
        candidates = (brlist.At(i) for i in range(brlist.GetEntries()))
        return [br for br in candidates if tree.GetBranchStatus(br.GetName())]

    def beginFile(self, inputFile, outputFile, inputTree, wrappedOutputTree):
        """Find the input collections' branches and book the merged output branches."""
        # Look at active branches of both the input and the output tree
        # (presumably so branches added by earlier modules are also merged).
        branches = (self._activeBranches(inputTree) +
                    self._activeBranches(wrappedOutputTree._tree))

        # Per input collection: branch names with "<collection>_" stripped.
        self.brlist_sep = [
            self.filterBranchNames(branches, x) for x in self.input
        ]
        # Union of all stripped names. Sorted so the output branch order does
        # not depend on str hash randomization (set iteration order).
        self.brlist_all = sorted(set(itertools.chain(*self.brlist_sep)))

        # is_there[b][j] is True when input collection j provides brlist_all[b].
        sep_sets = [set(names) for names in self.brlist_sep]
        self.is_there = np.zeros(shape=(len(self.brlist_all), self.nInputs),
                                 dtype=bool)
        for bridx, br in enumerate(self.brlist_all):
            for j, names in enumerate(sep_sets):
                if br in names:
                    self.is_there[bridx][j] = True

        # Create output branches, all counted by n<output>.
        self.out = wrappedOutputTree
        for br in self.brlist_all:
            self.out.branch("%s_%s" % (self.output, br),
                            _rootLeafType2rootBranchType[self.branchType[br]],
                            lenVar="n%s" % self.output)

    def endFile(self, inputFile, outputFile, inputTree, wrappedOutputTree):
        pass

    def filterBranchNames(self, branches, collection):
        """Return names of ``branches`` in ``collection``, prefix stripped.

        A branch belongs to the collection when its name starts with
        ``"<collection>_"``. As a side effect, the leaf type name of each
        matching branch is stored in ``self.branchType`` under the stripped
        name; a later match with the same stripped name overwrites it.
        """
        prefix = collection + '_'
        out = []
        for br in branches:
            name = br.GetName()
            if not name.startswith(prefix):
                continue
            # Strip only the leading prefix: str.replace would also remove any
            # later occurrence of "<collection>_" inside the name.
            stripped = name[len(prefix):]
            out.append(stripped)
            self.branchType[stripped] = br.FindLeaf(name).GetTypeName()
        return out

    def analyze(self, event):
        """process event, return True (go to next module) or False (fail, go to next event)"""
        coll = [Collection(event, x) for x in self.input]
        # (object, collection index, index within that collection)
        objects = [(coll[j][i], j, i) for j in range(self.nInputs)
                   for i in range(len(coll[j]))]
        if self.selector:
            objects = [
                obj_j_i for obj_j_i in objects
                if self.selector[obj_j_i[1]](obj_j_i[0])
            ]
        objects.sort(key=self.sortkey, reverse=self.reverse)
        if self.maxObjects:
            objects = objects[:self.maxObjects]
        for bridx, br in enumerate(self.brlist_all):
            present = self.is_there[bridx]
            # Collections without this branch contribute a 0 placeholder.
            values = [getattr(obj, br) if present[j] else 0
                      for obj, j, _ in objects]
            self.out.fillBranch("%s_%s" % (self.output, br), values)
        return True
123 
124 
# Module factories: each name is bound to a zero-argument lambda, so the
# collectionMerger is only constructed when the factory is called and not
# when this file is imported.

lepMerger = lambda: collectionMerger(input=["Electron", "Muon"], output="Lepton")

# Keep only the two leading leptons among electrons with pt > 20 and muons
# with pt > 40.
lepMerger_exampleSelection = lambda: collectionMerger(
    input=["Electron", "Muon"],
    output="Lepton",
    maxObjects=2,
    selector={
        "Electron": lambda x: x.pt > 20,
        "Muon": lambda x: x.pt > 40,
    },
)
def endFile(self, inputFile, outputFile, inputTree, wrappedOutputTree)
def __init__(self, input, output, sortkey=lambda x:x.pt, reverse=True, selector=None, maxObjects=None)
def beginFile(self, inputFile, outputFile, inputTree, wrappedOutputTree)
def filterBranchNames(self, branches, collection)