CMS 3D CMS Logo

helpers.py
Go to the documentation of this file.
1 from __future__ import print_function
2 import FWCore.ParameterSet.Config as cms
3 import sys
4 import six
5 
6 ## Helpers for technically boring tasks, such as finding every module that has a given parameter
7 ## and replacing that parameter with a given value.
8 
9 # Next two lines are for backward compatibility, the imported functions and
10 # classes used to be defined in this file.
11 from FWCore.ParameterSet.MassReplace import massSearchReplaceAnyInputTag, MassSearchReplaceAnyInputTagVisitor
12 from FWCore.ParameterSet.MassReplace import massSearchReplaceParam, MassSearchParamVisitor, MassSearchReplaceParamVisitor
13 
def getPatAlgosToolsTask(process):
    """Return the Task stored on ``process`` under the name ``patAlgosToolsTask``.

    If the process has no attribute of that name, a new ``cms.Task()`` is
    attached under it first and then returned.

    Raises:
        Exception: if the attribute exists but is not a ``cms.Task``.
    """
    taskName = "patAlgosToolsTask"
    if hasattr(process, taskName):
        task = getattr(process, taskName)
        if not isinstance(task, cms.Task):
            raise Exception("patAlgosToolsTask does not have type Task")
    else:
        setattr(process, taskName, cms.Task())
        # Read the task back from the process instead of keeping the local
        # object, so the caller receives what the process actually stores.
        task = getattr(process, taskName)
    return task
24 
def associatePatAlgosToolsTask(process):
    """Associate the process's ``patAlgosToolsTask`` with ``process.schedule``.

    The task is obtained through getPatAlgosToolsTask, so it is created on the
    process if it does not exist yet.
    """
    task = getPatAlgosToolsTask(process)
    process.schedule.associate(task)
28 
def addToProcessAndTask(label, module, process, task):
    """Attach ``module`` to ``process`` as attribute ``label`` and add it to ``task``.

    The object added to the task is the one read back from the process after
    the assignment, not necessarily the ``module`` argument itself.
    """
    setattr(process, label, module)
    task.add(getattr(process, label))
32 
def addESProducers(process,config):
    """Copy the ESProducer-like objects of a configuration module onto ``process``.

    ``config`` is a module path; '/' separators are converted to '.' before
    the module is imported. Every public attribute of that module that is a
    ``cms._Labelable`` (but not a sequence type, not a ``cms.PSet`` and not
    named 'source', 'looper' or 'subProcess') and whose ``type_()`` contains
    the substring 'ESProducer' is set on ``process`` under the same name.
    """
    config = config.replace("/",".")
    # __import__ returns the top-level package, so the module itself is taken
    # from sys.modules afterwards.
    __import__(config)
    loaded = sys.modules[config]
    for name in dir(loaded):
        item = getattr(loaded, name)
        if (isinstance(item, cms._Labelable)
                and not isinstance(item, cms._ModuleSequenceType)
                and not name.startswith('_')
                and name not in ("source", "looper", "subProcess")
                and not isinstance(item, cms.PSet)):
            if 'ESProducer' in item.type_():
                setattr(process, name, item)
43 
def loadWithPrefix(process,moduleName,prefix='',loadedProducersAndFilters=None):
    """Load ``moduleName`` into ``process`` with ``prefix`` and an empty postfix.

    Thin wrapper around loadWithPrePostfix; all arguments are passed through.
    """
    loadWithPrePostfix(process,moduleName,prefix,'',loadedProducersAndFilters)
46 
def loadWithPostfix(process,moduleName,postfix='',loadedProducersAndFilters=None):
    """Load ``moduleName`` into ``process`` with ``postfix`` and an empty prefix.

    Thin wrapper around loadWithPrePostfix; all arguments are passed through.
    """
    loadWithPrePostfix(process,moduleName,'',postfix,loadedProducersAndFilters)
49 
def loadWithPrePostfix(process,moduleName,prefix='',postfix='',loadedProducersAndFilters=None):
    """Import ``moduleName`` and extend ``process`` with its contents.

    '/' separators in ``moduleName`` are converted to '.' before the import.
    The imported module is handed to extendWithPrePostfix together with
    ``prefix``, ``postfix`` and ``loadedProducersAndFilters``.
    """
    moduleName = moduleName.replace("/",".")
    # __import__ returns the top-level package, so the module itself is taken
    # from sys.modules afterwards.
    __import__(moduleName)
    extendWithPrePostfix(process,sys.modules[moduleName],prefix,postfix,loadedProducersAndFilters)
55 
def addToTask(loadedProducersAndFilters, module):
    """Add ``module`` to the given task if it is an EDProducer or EDFilter.

    Nothing happens when ``loadedProducersAndFilters`` is falsy (e.g. None)
    or when ``module`` is of any other type.
    """
    # NOTE(review): this is a truthiness test, not an `is not None` test;
    # confirm that an empty task object is never falsy.
    if loadedProducersAndFilters:
        if isinstance(module, (cms.EDProducer, cms.EDFilter)):
            loadedProducersAndFilters.add(module)
60 
def extendWithPrePostfix(process,other,prefix,postfix,loadedProducersAndFilters=None):
    """Look in other and find types which we can use

    Every public attribute of ``other`` that is a ``cms._Labelable`` is
    attached to ``process``. Attributes named 'source', 'looper' or
    'subProcess', sequence types, Tasks, Schedules, PSets and VPSets are
    skipped.

    When ``prefix`` or ``postfix`` is non-empty, a clone of each item is
    attached under ``prefix + name + postfix`` (ESProducers keep their plain
    name), and input tags inside the renamed modules are redirected to the
    renamed labels. Otherwise the item itself is attached under its own name.

    ``loadedProducersAndFilters`` is either None or the *name* of a
    ``cms.Task`` attribute of ``process``; attached EDProducers and EDFilters
    are added to that task through addToTask.

    Raises:
        Exception: if ``loadedProducersAndFilters`` names an attribute that is
            not a ``cms.Task``.
    """
    if loadedProducersAndFilters:
        task = getattr(process, loadedProducersAndFilters)
        if not isinstance(task, cms.Task):
            raise Exception("extendWithPrePostfix argument must be name of Task type object attached to the process or None")
    else:
        task = None

    # Scratch sequence collecting the renamed modules, together with their
    # original labels, for the input-tag replacement at the end.
    sequence = cms.Sequence()
    sequence._moduleLabels = []
    for name in dir(other):
        #'from XX import *' ignores these, and so should we.
        if name.startswith('_'):
            continue
        item = getattr(other,name)
        if name == "source" or name == "looper" or name == "subProcess":
            continue
        elif isinstance(item,cms._ModuleSequenceType):
            continue
        elif isinstance(item,cms.Task):
            continue
        elif isinstance(item,cms.Schedule):
            continue
        elif isinstance(item,cms.VPSet) or isinstance(item,cms.PSet):
            continue
        elif isinstance(item,cms._Labelable):
            if not item.hasLabel_():
                item.setLabel(name)
            if prefix != '' or postfix != '':
                newModule = item.clone()
                if isinstance(item,cms.ESProducer):
                    # ESProducers are cloned but keep their original name.
                    newName = name
                else:
                    if 'TauDiscrimination' in name:
                        # These modules are additionally attached under
                        # their unmodified name.
                        process.__setattr__(name,item)
                        addToTask(task, item)
                    newName = prefix+name+postfix
                process.__setattr__(newName,newModule)
                addToTask(task, newModule)
                if isinstance(newModule, cms._Sequenceable) and not newName == name:
                    sequence += getattr(process,newName)
                    sequence._moduleLabels.append(item.label())
            else:
                process.__setattr__(name,item)
                addToTask(task, item)

    if prefix != '' or postfix != '':
        # Point input tags that referred to an original label at the renamed module.
        for label in sequence._moduleLabels:
            massSearchReplaceAnyInputTag(sequence, label, prefix+label+postfix,verbose=False,moduleLabelOnly=True)
114 
def applyPostfix(process, label, postfix):
    """Return the attribute of ``process`` named ``label + postfix``.

    Raises:
        ValueError: if ``process`` has no attribute of that name.
    """
    if not hasattr(process, label + postfix):
        raise ValueError("Error in <applyPostfix>: No module of name = %s attached to process !!" % (label + postfix))
    return getattr(process, label + postfix)
122 
def removeIfInSequence(process, target, sequenceLabel, postfix=""):
    """Remove ``target + postfix`` from the sequence ``sequenceLabel + postfix``.

    Both names are looked up as attributes of ``process``. Nothing is done if
    the target's label is not among the labels found in the sequence.
    """
    labels = __labelsInSequence(process, sequenceLabel, postfix, True)
    if target+postfix in labels:
        getattr(process, sequenceLabel+postfix).remove(
            getattr(process, target+postfix)
        )
129 
def __labelsInSequence(process, sequenceLabel, postfix="", keepPostFix=False):
    """Return the labels of all modules and sequences in a sequence.

    The sequence is the attribute ``sequenceLabel + postfix`` of ``process``.
    Module labels come first, followed by sequence labels. Unless
    ``keepPostFix`` is true or ``postfix`` is empty, the last ``len(postfix)``
    characters are cut from every label; it is not checked that a label
    actually ends with ``postfix``.
    """
    sequence = getattr(process, sequenceLabel+postfix)
    labels = [ m.label() for m in listModules(sequence) ]
    labels.extend( m.label() for m in listSequences(sequence) )
    # An empty postfix must not be sliced off: label[:-0] would be ''.
    if keepPostFix or postfix == "":
        return labels
    return [ label[:-len(postfix)] for label in labels ]
141 
142 #FIXME name is not generic enough now
class GatherAllModulesVisitor(object):
    """Visitor that travels within a cms.Sequence and collects every visited object
    that is an instance of gatheredInstance (e.g. modules)."""
    def __init__(self, gatheredInstance=cms._Module):
        # Objects collected so far, in visiting order.
        self._modules = []
        # Type (or tuple of types) an object must match to be collected.
        self._gatheredInstance = gatheredInstance
    def enter(self,visitee):
        """Record ``visitee`` if it is an instance of the gathered type."""
        if isinstance(visitee,self._gatheredInstance):
            self._modules.append(visitee)
    def leave(self,visitee):
        """Nothing to do when leaving a node."""
        pass
    def modules(self):
        """Return the list of collected objects (the internal list, not a copy)."""
        return self._modules

155 
class CloneSequenceVisitor(object):
    """Visitor that travels within a cms.Sequence, and returns a cloned version of the Sequence.
    All modules and sequences are cloned and a postfix is added"""
    def __init__(self, process, label, postfix, removePostfix="", noClones = [], addToTask = False):
        """Create the (initially empty) cloned sequence and attach it to ``process``.

        ``label`` is the label of the sequence being cloned; the clone is
        attached under the name produced by _newLabel(label). Labels listed in
        ``noClones`` are reused instead of cloned. The default list is only
        read here, never modified. When ``addToTask`` is true, newly cloned
        modules are also added to the process's patAlgosToolsTask.
        """
        self._process = process
        self._postfix = postfix
        self._removePostfix = removePostfix
        self._noClones = noClones
        self._addToTask = addToTask
        # Labels of the original modules that have been cloned so far.
        self._moduleLabels = []
        self._clonedSequence = cms.Sequence()
        setattr(process, self._newLabel(label), self._clonedSequence)
        if addToTask:
            self._patAlgosToolsTask = getPatAlgosToolsTask(process)

    def enter(self, visitee):
        """Append the counterpart of a visited module to the cloned sequence.

        Objects that are not ``cms._Module`` instances are ignored.
        """
        if isinstance(visitee, cms._Module):
            label = visitee.label()
            newModule = None
            if label in self._noClones: #keep unchanged
                newModule = getattr(self._process, label)
            elif label in self._moduleLabels: # has the module already been cloned ?
                newModule = getattr(self._process, self._newLabel(label))
            else:
                self._moduleLabels.append(label)
                newModule = visitee.clone()
                setattr(self._process, self._newLabel(label), newModule)
                if self._addToTask:
                    self._patAlgosToolsTask.add(getattr(self._process, self._newLabel(label)))
            self.__appendToTopSequence(newModule)

    def leave(self, visitee):
        """Nothing to do when leaving a node."""
        pass

    def clonedSequence(self):
        """Return the cloned sequence after redirecting its input tags to the clones."""
        for label in self._moduleLabels:
            massSearchReplaceAnyInputTag(self._clonedSequence, label, self._newLabel(label), moduleLabelOnly=True, verbose=False)
        self._moduleLabels = [] # prevent the InputTag replacement next time the 'clonedSequence' function is called.
        return self._clonedSequence

    def _newLabel(self, label):
        """Return ``label`` with removePostfix stripped (if set) and postfix appended.

        Raises:
            Exception: if a removePostfix is set and ``label`` does not end with it.
        """
        if self._removePostfix != "":
            if label.endswith(self._removePostfix):
                label = label[0:-len(self._removePostfix)]
            else:
                raise Exception("Tried to remove postfix %s from label %s, but it wasn't there" % (self._removePostfix, label))
        return label + self._postfix

    def __appendToTopSequence(self, visitee):
        """Append ``visitee`` to the cloned sequence."""
        self._clonedSequence += visitee

206 
def listModules(sequence):
    """Return the list of ``cms._Module`` objects met while visiting ``sequence``."""
    visitor = GatherAllModulesVisitor(gatheredInstance=cms._Module)
    sequence.visit(visitor)
    return visitor.modules()
211 
def listSequences(sequence):
    """Return the list of ``cms.Sequence`` objects met while visiting ``sequence``."""
    visitor = GatherAllModulesVisitor(gatheredInstance=cms.Sequence)
    sequence.visit(visitor)
    return visitor.modules()
216 
def jetCollectionString(prefix='', algo='', type=''):
    """
    ------------------------------------------------------------------
    return the string of the jet collection module depending on the
    input values. The result is '<prefix>PatJets<algo><type>', or
    'patJets<algo><type>' when no prefix is given; the default return
    value is therefore 'patJets'.

    algo   : indicating the algorithm type of the jet [expected are
             'AK5', 'IC5', 'SC7', ...]
    type   : indicating the type of constituents of the jet [expec-
             ted are 'Calo', 'PFlow', 'JPT', ...]
    prefix : prefix indicating the type of pat collection module (ex-
             pected are '', 'selected', 'clean').
    ------------------------------------------------------------------
    """
    if prefix == '':
        collection = 'pat'
    else:
        collection = prefix + 'Pat'
    return collection + 'Jets' + algo + type
240 
def contains(sequence, moduleName):
    """
    ------------------------------------------------------------------
    return True if 'moduleName' occurs in the string representation
    of 'sequence' and False otherwise. This version is not so nice as
    it also returns True for any substr of the name of a contained
    module.

    sequence   : sequence [e.g. process.patDefaultSequence]
    moduleName : module name as a string
    ------------------------------------------------------------------
    """
    return moduleName in str(sequence)
254 
255 
256 
def cloneProcessingSnippet(process, sequence, postfix, removePostfix="", noClones = [], addToTask = False):
    """
    ------------------------------------------------------------------
    copy a sequence plus the modules and sequences therein
    both are renamed by getting a postfix
    input tags are automatically adjusted

    returns the cloned sequence, or 'sequence' itself (nothing is
    cloned) when 'postfix' is the empty string. 'noClones' is only
    passed on to the visitor and is not modified here.
    ------------------------------------------------------------------
    """
    result = sequence
    if not postfix == "":
        visitor = CloneSequenceVisitor(process, sequence.label(), postfix, removePostfix, noClones, addToTask)
        sequence.visit(visitor)
        result = visitor.clonedSequence()
    return result
271 
def listDependencyChain(process, module, sources, verbose=False):
    """
    Walk up the dependencies of a module to find any that depend on any of the listed sources

    Starting from ``module``, the input tags of its parameters (including
    nested PSets, VPSets and VInputTags) are followed recursively through the
    modules attached to ``process``. Only tags with an empty process name or
    the name of ``process`` are considered.

    Returns None if no module attached to ``process`` depends on one of
    ``sources``; otherwise a tuple ``(task, sequence)`` where ``task`` is a
    cms.Task holding those modules and ``sequence`` is ``cms.Sequence(task)``.

    Raises:
        RuntimeError: if the computed module order is inconsistent with the
            dependencies.
    """
    def allDirectInputModules(moduleOrPSet,moduleName,attrName):
        """Return the set of module labels referenced directly by the parameters.

        ``attrName`` is the dotted path of ``moduleOrPSet``; it is only used in
        the verbose messages.
        """
        ret = set()
        for name,value in moduleOrPSet.parameters_().items():
            typeName = value.pythonTypeName()
            # Full path of this parameter; extends the enclosing path so that
            # messages for deeply nested parameters stay complete.
            path = attrName+"."+name
            if typeName == 'cms.PSet':
                ret.update(allDirectInputModules(value,moduleName,path))
            elif typeName == 'cms.VPSet':
                for (i,ps) in enumerate(value):
                    ret.update(allDirectInputModules(ps,moduleName,"%s[%d]"%(path,i)))
            elif typeName == 'cms.VInputTag':
                inputs = [ MassSearchReplaceAnyInputTagVisitor.standardizeInputTagFmt(it) for it in value ]
                inputLabels = [ tag.moduleLabel for tag in inputs if tag.processName == '' or tag.processName == process.name_() ]
                ret.update(inputLabels)
                if verbose and inputLabels: print("%s depends on %s via %s" % (moduleName, inputLabels, path))
            elif typeName.endswith('.InputTag'):
                if value.processName == '' or value.processName == process.name_():
                    ret.add(value.moduleLabel)
                    if verbose: print("%s depends on %s via %s" % (moduleName, value.moduleLabel, path))
        ret.discard("")
        return ret
    def fillDirectDepGraphs(root,fwdepgraph,revdepgraph):
        """Fill the forward (label -> inputs) and reverse (label -> users) graphs."""
        if root.label_() in fwdepgraph: return
        deps = allDirectInputModules(root,root.label_(),root.label_())
        fwdepgraph[root.label_()] = []
        for d in deps:
            fwdepgraph[root.label_()].append(d)
            if d not in revdepgraph: revdepgraph[d] = []
            revdepgraph[d].append(root.label_())
            # Labels that are not attributes of the process are graph leaves.
            depmodule = getattr(process,d,None)
            if depmodule:
                fillDirectDepGraphs(depmodule,fwdepgraph,revdepgraph)
        return (fwdepgraph,revdepgraph)
    fwdepgraph, revdepgraph = fillDirectDepGraphs(module, {}, {})
    def flattenRevDeps(flatgraph, revdepgraph, tip):
        """Make a graph that for each module lists all the ones that depend on it, directly or indirectly"""
        # don't do it multiple times for the same module
        if tip in flatgraph: return
        # if nobody depends on this module, there's nothing to do
        if tip not in revdepgraph: return
        # assemble my dependencies, in a depth-first approach
        mydeps = set()
        # start taking the direct dependencies of this module
        for d in revdepgraph[tip]:
            # process them
            flattenRevDeps(flatgraph, revdepgraph, d)
            # then add them and their processed dependencies to our deps
            mydeps.add(d)
            if d in flatgraph:
                mydeps.update(flatgraph[d])
        flatgraph[tip] = mydeps
    flatdeps = {}
    allmodules = set()
    for s in sources:
        flattenRevDeps(flatdeps, revdepgraph, s)
        if s in flatdeps: allmodules.update(flatdeps[s])
    livemodules = [ a for a in allmodules if hasattr(process,a) ]
    if not livemodules: return None
    # Order the labels so that a module comes before the modules that depend on it.
    modulelist = [livemodules.pop()]
    for candidate in livemodules:
        for i,m in enumerate(modulelist):
            if candidate in flatdeps and m in flatdeps[candidate]:
                modulelist.insert(i, candidate)
                break
        if candidate not in modulelist:
            modulelist.append(candidate)
    # Validate
    for i,m1 in enumerate(modulelist):
        for j,m2 in enumerate(modulelist):
            if j <= i: continue
            if m2 in flatdeps and m1 in flatdeps[m2]:
                raise RuntimeError("BAD ORDER %s BEFORE %s" % (m1,m2))
    modules = [ getattr(process,p) for p in modulelist ]
    task = cms.Task()
    for mod in modules:
        task.add(mod)
    return task,cms.Sequence(task)
353 
def addKeepStatement(process, oldKeep, newKeeps, verbose=False):
    """Add new keep statements to any PoolOutputModule of the process that has the old keep statements

    For every entry of ``process.outputModules`` whose ``type_()`` is
    'PoolOutputModule' and that has an ``outputCommands`` attribute containing
    ``oldKeep``, ``newKeeps`` is appended to ``outputCommands``. With
    ``verbose`` set, the added statements are printed.
    """
    for name,out in process.outputModules.items():
        if out.type_() == 'PoolOutputModule' and hasattr(out, "outputCommands"):
            if oldKeep in out.outputCommands:
                out.outputCommands += newKeeps
                if verbose:
                    print("Adding the following keep statements to output module %s: " % name)
                    for k in newKeeps: print("\t'%s'," % k)
363 
364 
if __name__=="__main__":
    import unittest
    class TestModuleCommand(unittest.TestCase):
        """Self-tests for the helper functions; run when the file is executed as a script."""
        def setUp(self):
            """Nothing to do """
            pass
        def testCloning(self):
            # Module 'a' appears twice in the sequence: it must be cloned once
            # and the clone must appear twice in the cloned sequence.
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("b", src=cms.InputTag("a"))
            p.c = cms.EDProducer("c", src=cms.InputTag("b","instance"))
            p.s = cms.Sequence(p.a*p.b*p.c *p.a)
            cloneProcessingSnippet(p, p.s, "New", addToTask = True)
            self.assertEqual(p.dumpPython(),
"""import FWCore.ParameterSet.Config as cms

process = cms.Process("test")

process.a = cms.EDProducer("a",
    src = cms.InputTag("gen")
)


process.aNew = cms.EDProducer("a",
    src = cms.InputTag("gen")
)


process.b = cms.EDProducer("b",
    src = cms.InputTag("a")
)


process.bNew = cms.EDProducer("b",
    src = cms.InputTag("aNew")
)


process.c = cms.EDProducer("c",
    src = cms.InputTag("b","instance")
)


process.cNew = cms.EDProducer("c",
    src = cms.InputTag("bNew","instance")
)


process.patAlgosToolsTask = cms.Task(process.aNew, process.bNew, process.cNew)


process.s = cms.Sequence(process.a+process.b+process.c+process.a)


process.sNew = cms.Sequence(process.aNew+process.bNew+process.cNew+process.aNew)


""")
        def testContains(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s1 = cms.Sequence(p.a*p.b*p.c)
            p.s2 = cms.Sequence(p.b*p.c)
            # assertTrue replaces the assert_ alias, which newer Python versions no longer provide.
            self.assertTrue( contains(p.s1, "a") )
            self.assertTrue( not contains(p.s2, "a") )
        def testJetCollectionString(self):
            self.assertEqual(jetCollectionString(algo = 'Foo', type = 'Bar'), 'patJetsFooBar')
            self.assertEqual(jetCollectionString(prefix = 'prefix', algo = 'Foo', type = 'Bar'), 'prefixPatJetsFooBar')
        def testListModules(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s = cms.Sequence(p.a*p.b*p.c)
            self.assertEqual([p.a,p.b,p.c], listModules(p.s))

    unittest.main()
def extendWithPrePostfix(process, other, prefix, postfix, loadedProducersAndFilters=None)
Definition: helpers.py:61
def cloneProcessingSnippet(process, sequence, postfix, removePostfix="", noClones=[], addToTask=False)
Definition: helpers.py:257
def loadWithPrePostfix(process, moduleName, prefix='', postfix='', loadedProducersAndFilters=None)
Definition: helpers.py:50
def contains(sequence, moduleName)
Definition: helpers.py:241
def leave(self, visitee)
Definition: helpers.py:187
def massSearchReplaceAnyInputTag(sequence, oldInputTag, newInputTag, verbose=False, moduleLabelOnly=False, skipLabelTest=False)
Definition: MassReplace.py:73
def enter(self, visitee)
Definition: helpers.py:148
def addToProcessAndTask(label, module, process, task)
Definition: helpers.py:29
def _newLabel(self, label)
Definition: helpers.py:196
def testJetCollectionString(self)
Definition: helpers.py:432
S & print(S &os, JobReport::InputFile const &f)
Definition: JobReport.cc:66
def removeIfInSequence(process, target, sequenceLabel, postfix="")
Definition: helpers.py:123
void find(edm::Handle< EcalRecHitCollection > &hits, DetId thisDet, std::vector< EcalRecHitCollection::const_iterator > &hit, bool debug=false)
Definition: FindCaloHit.cc:20
def __init__(self, process, label, postfix, removePostfix="", noClones=[], addToTask=False)
Definition: helpers.py:159
def addKeepStatement(process, oldKeep, newKeeps, verbose=False)
Definition: helpers.py:354
def leave(self, visitee)
Definition: helpers.py:151
def testListModules(self)
Definition: helpers.py:435
def applyPostfix(process, label, postfix)
Definition: helpers.py:115
def listModules(sequence)
Definition: helpers.py:207
def addESProducers(process, config)
Definition: helpers.py:33
def loadWithPrefix(process, moduleName, prefix='', loadedProducersAndFilters=None)
Definition: helpers.py:44
def __init__(self, gatheredInstance=cms._Module)
Definition: helpers.py:145
def addToTask(loadedProducersAndFilters, module)
Definition: helpers.py:56
def listSequences(sequence)
Definition: helpers.py:212
def __labelsInSequence(process, sequenceLabel, postfix="", keepPostFix=False)
Definition: helpers.py:130
def listDependencyChain(process, module, sources, verbose=False)
Definition: helpers.py:272
def enter(self, visitee)
Definition: helpers.py:171
def jetCollectionString(prefix='', algo='', type='')
Definition: helpers.py:217
def remove(d, key, TELL=False)
Definition: MatrixUtil.py:212
def associatePatAlgosToolsTask(process)
Definition: helpers.py:25
def __appendToTopSequence(self, visitee)
Definition: helpers.py:204
def getPatAlgosToolsTask(process)
Definition: helpers.py:14
dbl *** dir
Definition: mlp_gen.cc:35
def loadWithPostfix(process, moduleName, postfix='', loadedProducersAndFilters=None)
Definition: helpers.py:47
def testContains(self)
Definition: helpers.py:423