CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
helpers.py
Go to the documentation of this file.
1 import FWCore.ParameterSet.Config as cms
2 import sys
3 
## Helpers to perform some technically boring tasks, like looking for all modules with a given parameter
## and replacing its value with a given one
6 
def addESProducers(process,config):
    """Import a configuration fragment and attach its ESProducers to a process.

    process : the cms.Process to extend
    config  : name of the python fragment, dotted or with '/' separators
              (every '/' is turned into '.')

    Every public, labelable object of the fragment whose type_() string
    contains 'ESProducer' is attached to process under its attribute name.
    The names 'source', 'looper' and 'subProcess', sequence-like objects
    (cms._ModuleSequenceType) and parameter sets (cms.PSet, cms.VPSet) are
    skipped.
    """
    config = config.replace("/",".")
    # load the fragment, then fetch it by its full dotted name
    __import__(config)
    fragment = sys.modules[config]
    for name in dir(fragment):
        # 'from XX import *' ignores private names, and so do we
        if name.startswith('_'):
            continue
        if name in ("source", "looper", "subProcess"):
            continue
        item = getattr(fragment,name)
        # the mixin base classes are only reachable through the cms namespace
        if not isinstance(item,cms._Labelable):
            continue
        if isinstance(item,cms._ModuleSequenceType):
            continue
        # parameter sets are skipped, consistently with extendWithPostfix
        if isinstance(item,(cms.PSet,cms.VPSet)):
            continue
        if 'ESProducer' in item.type_():
            setattr(process,name,item)
17 
def loadWithPostfix(process,moduleName,postfix=''):
    """Import the configuration fragment *moduleName* ('/' or '.' separated)
    and copy its content into *process* through extendWithPostfix, which
    appends *postfix* to the labels of the copied objects."""
    dottedName = moduleName.replace("/", ".")
    # __import__ registers the fragment in sys.modules under its dotted name
    __import__(dottedName)
    extendWithPostfix(process, sys.modules[dottedName], postfix)
23 
def extendWithPostfix(process,other,postfix,items=()):
    """Look in *other* and copy the objects we can use into *process*.

    Skipped: names starting with '_', the names 'source', 'looper' and
    'subProcess', sequence-like objects (cms._ModuleSequenceType),
    cms.Schedule, cms.PSet and cms.VPSet. Every other cms._Labelable object
    is labelled with its attribute name if it has no label yet, and then:

    * postfix == '' : attached to process unchanged under its own name;
    * postfix != '' : a clone is attached under name+postfix (ESProducers
      keep their name; objects whose name contains 'TauDiscrimination' are
      additionally attached unchanged under their original name). Every
      InputTag inside the cloned sequenceable modules whose module label is
      one of the original labels is then redirected to label+postfix.

    items : accepted for backward compatibility; not used.
    """
    # temporary sequence, used only as the target of the InputTag rewrite below
    sequence = cms.Sequence()
    sequence._moduleLabels = []
    sequence.setLabel('tempSequence')
    for name in dir(other):
        #'from XX import *' ignores these, and so should we.
        if name.startswith('_'):
            continue
        item = getattr(other,name)
        if name == "source" or name == "looper" or name == "subProcess":
            continue
        elif isinstance(item,cms._ModuleSequenceType):
            continue
        elif isinstance(item,cms.Schedule):
            continue
        elif isinstance(item,cms.VPSet) or isinstance(item,cms.PSet):
            continue
        elif isinstance(item,cms._Labelable):
            if not item.hasLabel_():
                item.setLabel(name)
            if postfix != '':
                newModule = item.clone()
                if isinstance(item,cms.ESProducer):
                    # ESProducers are not renamed
                    newName = name
                else:
                    if 'TauDiscrimination' in name:
                        process.__setattr__(name,item)
                    newName = name+postfix
                process.__setattr__(newName,newModule)
                if isinstance(newModule, cms._Sequenceable) and not newName == name:
                    # remember which original labels now have a postfixed clone
                    sequence +=getattr(process,newName)
                    sequence._moduleLabels.append(item.label())
            else:
                process.__setattr__(name,item)

    if postfix != '':
        # point the clones' inputs at the other clones instead of the originals
        for label in sequence._moduleLabels:
            massSearchReplaceAnyInputTag(sequence, label, label+postfix,verbose=False,moduleLabelOnly=True)
69 
def applyPostfix(process, label, postfix):
    """Return the object attached to *process* as label+postfix.

    Raises ValueError when no such attribute exists.
    """
    fullLabel = label + postfix
    if not hasattr(process, fullLabel):
        raise ValueError("Error in <applyPostfix>: No module of name = %s attached to process !!" % fullLabel)
    return getattr(process, fullLabel)
77 
def removeIfInSequence(process, target, sequenceLabel, postfix=""):
    """Remove target+postfix from the sequence sequenceLabel+postfix of process,
    if it is among the modules or sequences listed for that sequence.

    The membership test is done on the full (postfixed) labels. Comparing
    target+postfix against postfix-stripped labels never matched, so nothing
    was removed whenever a postfix was given.
    """
    sequence = getattr(process, sequenceLabel+postfix)
    fullTarget = target+postfix
    labels = [m.label() for m in listModules(sequence)]
    labels.extend(m.label() for m in listSequences(sequence))
    if fullTarget in labels:
        sequence.remove(getattr(process, fullTarget))
84 
def __labelsInSequence(process, sequenceLabel, postfix="", keepPostFix=False):
    """Return the labels found in process.<sequenceLabel+postfix>.

    Labels of the modules (listModules) come first, followed by those of the
    sequences (listSequences). Unless keepPostFix is True, a trailing postfix
    is stripped from each label. Labels that do not end with the postfix are
    returned unchanged instead of being truncated.
    """
    sequence = getattr(process, sequenceLabel+postfix)
    labels = [m.label() for m in listModules(sequence)]
    labels.extend(m.label() for m in listSequences(sequence))
    if postfix == "" or keepPostFix:
        return labels
    cut = len(postfix)
    return [lbl[:-cut] if lbl.endswith(postfix) else lbl for lbl in labels]
92 
class MassSearchReplaceParamVisitor(object):
    """Visitor that travels within a cms.Sequence, looks for a parameter and replaces its value"""
    def __init__(self,paramName,paramSearch,paramValue,verbose=False):
        # paramName   : attribute inspected on every visited object
        # paramSearch : value that triggers the replacement
        # paramValue  : value written in its place
        # verbose     : print every replacement
        self._paramName = paramName
        self._paramValue = paramValue
        self._paramSearch = paramSearch
        self._verbose = verbose
    def enter(self,visitee):
        """Set visitee.<paramName> to paramValue if it currently equals paramSearch."""
        if hasattr(visitee,self._paramName):
            if getattr(visitee,self._paramName) == self._paramSearch:
                if self._verbose:
                    print("Replaced %s.%s: %s => %s" % (visitee,self._paramName,getattr(visitee,self._paramName),self._paramValue))
                setattr(visitee,self._paramName,self._paramValue)
    def leave(self,visitee):
        pass
107 
class MassSearchReplaceAnyInputTagVisitor(object):
    """Visitor that travels within a cms.Sequence, looks for a parameter and replace its value
    It will climb down within PSets, VPSets and VInputTags to find its target"""
    def __init__(self,paramSearch,paramReplace,verbose=False,moduleLabelOnly=False,skipLabelTest=False):
        # paramSearch / paramReplace : InputTag (or plain string) to look for / to put in its place
        # moduleLabelOnly : also rewrite tags whose module label equals paramSearch's module
        #                   label, changing only their module label
        # skipLabelTest   : do not call visitee.label_() to build the message prefix
        self._paramSearch = self.standardizeInputTagFmt(paramSearch)
        self._paramReplace = self.standardizeInputTagFmt(paramReplace)
        self._moduleName = ''
        self._verbose=verbose
        self._moduleLabelOnly=moduleLabelOnly
        self._skipLabelTest=skipLabelTest
    def doIt(self,pset,base):
        """Replace matching InputTags in pset, recursing into PSets and VPSets.

        base is the dotted path of pset, only used in verbose messages.
        """
        from copy import deepcopy
        if not isinstance(pset, cms._Parameterizable):
            return
        for name in pset.parameterNames_():
            # if I use pset.parameters_().items() I get copies of the parameter values
            # so I can't modify the nested pset
            value = getattr(pset,name)
            typeName = value.pythonTypeName()
            if typeName == 'cms.PSet':
                self.doIt(value,base+"."+name)
            elif typeName == 'cms.VPSet':
                for (i,ps) in enumerate(value):
                    self.doIt(ps, "%s.%s[%d]"%(base,name,i) )
            elif typeName == 'cms.VInputTag':
                for (i,n) in enumerate(value):
                    # VInputTag can be declared as a list of strings, so ensure that n is formatted correctly
                    n = self.standardizeInputTagFmt(n)
                    if n == self._paramSearch:
                        if self._verbose:
                            print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, self._paramSearch, self._paramReplace))
                        # store a copy so no two parameters share one InputTag object
                        value[i] = deepcopy(self._paramReplace)
                    elif self._moduleLabelOnly and n.moduleLabel == self._paramSearch.moduleLabel:
                        # copy before editing: n may be the very object held in value[i]
                        nrep = deepcopy(n)
                        nrep.moduleLabel = self._paramReplace.moduleLabel
                        if self._verbose:
                            print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, n, nrep))
                        value[i] = nrep
            elif typeName.endswith('.InputTag'):
                if value == self._paramSearch:
                    if self._verbose:
                        print("Replace %s.%s %s ==> %s " % (base, name, self._paramSearch, self._paramReplace))
                    if 'untracked' in typeName:
                        # keep the parameter untracked
                        setattr(pset, name, cms.untracked.InputTag(self._paramReplace.getModuleLabel(),
                                                                   self._paramReplace.getProductInstanceLabel(),
                                                                   self._paramReplace.getProcessName()))
                    else:
                        setattr(pset, name, deepcopy(self._paramReplace) )
                elif self._moduleLabelOnly and value.moduleLabel == self._paramSearch.moduleLabel:
                    repl = deepcopy(getattr(pset, name))
                    repl.moduleLabel = self._paramReplace.moduleLabel
                    setattr(pset, name, repl)
                    if self._verbose:
                        print("Replace %s.%s %s ==> %s " % (base, name, value, repl))


    @staticmethod
    def standardizeInputTagFmt(inputTag):
        ''' helper function to ensure that the InputTag is defined as cms.InputTag(str) and not as a plain str '''
        if not isinstance(inputTag, cms.InputTag):
            return cms.InputTag(inputTag)
        return inputTag

    def enter(self,visitee):
        """Run the replacement on visitee, using its label (if available) as message prefix."""
        label = ''
        if (not self._skipLabelTest):
            try: label = visitee.label_()
            except AttributeError: label = '<Module not in a Process>'
        else:
            label = '<Module label not tested>'
        self.doIt(visitee, label)
    def leave(self,visitee):
        pass
175 
#FIXME name is not generic enough now
class GatherAllModulesVisitor(object):
    """Visitor that travels within a cms.Sequence and collects, in the order
    they are entered, all visited objects that are instances of
    gatheredInstance (cms._Module by default, i.e. modules)."""
    def __init__(self, gatheredInstance=cms._Module):
        self._gatheredInstance = gatheredInstance
        self._modules = []

    def enter(self, visitee):
        # ignore everything that is not of the requested kind
        if not isinstance(visitee, self._gatheredInstance):
            return
        self._modules.append(visitee)

    def leave(self, visitee):
        # nothing to do on the way back up
        pass

    def modules(self):
        """Return the list of collected objects (the internal list, not a copy)."""
        return self._modules
189 
class CloneSequenceVisitor(object):
    """Visitor that travels within a cms.Sequence, and returns a cloned version of the Sequence.
    All modules and sequences are cloned and a postfix is added"""
    def __init__(self, process, label, postfix, removePostfix=""):
        # the (initially empty) cloned sequence is registered on process right away
        self._process = process
        self._postfix = postfix
        self._removePostfix = removePostfix
        self._moduleLabels = []
        self._clonedSequence = cms.Sequence()
        setattr(process, self._newLabel(label), self._clonedSequence)

    def enter(self, visitee):
        """Clone every visited module once and append the clone to the cloned sequence."""
        if isinstance(visitee, cms._Module):
            label = visitee.label()
            if label in self._moduleLabels: # has the module already been cloned ?
                newModule = getattr(self._process, self._newLabel(label))
            else:
                self._moduleLabels.append(label)
                newModule = visitee.clone()
                setattr(self._process, self._newLabel(label), newModule)
            self.__appendToTopSequence(newModule)

    def leave(self, visitee):
        pass

    def clonedSequence(self):
        """Return the cloned sequence, with InputTags pointing at cloned modules renamed."""
        for label in self._moduleLabels:
            massSearchReplaceAnyInputTag(self._clonedSequence, label, self._newLabel(label), moduleLabelOnly=True, verbose=False)
        self._moduleLabels = [] # prevent the InputTag replacement next time the 'clonedSequence' function is called.
        return self._clonedSequence

    def _newLabel(self, label):
        """Strip removePostfix from label (ValueError if absent) and append postfix."""
        if self._removePostfix != "":
            if label.endswith(self._removePostfix):
                label = label[:-len(self._removePostfix)]
            else:
                # ValueError: StandardError does not exist in Python 3, and in
                # Python 2 ValueError is a StandardError subclass, so existing handlers still work
                raise ValueError("Tried to remove postfix %s from label %s, but it wasn't there" % (self._removePostfix, label))
        return label + self._postfix

    def __appendToTopSequence(self, visitee):
        self._clonedSequence += visitee
232 
class MassSearchParamVisitor(object):
    """Visitor that travels within a cms.Sequence, looks for a parameter and returns a list of modules that have it"""
    def __init__(self,paramName,paramSearch):
        # paramName   : attribute inspected on every visited object
        # paramSearch : value the attribute must equal for the object to be collected
        self._paramName = paramName
        self._paramSearch = paramSearch
        self._modules = []
    def enter(self,visitee):
        """Collect visitee when visitee.<paramName> equals paramSearch."""
        if hasattr(visitee,self._paramName):
            if getattr(visitee,self._paramName) == self._paramSearch:
                self._modules.append(visitee)
    def leave(self,visitee):
        pass
    def modules(self):
        """Return the list of collected objects."""
        return self._modules
247 
248 
def massSearchReplaceParam(sequence,paramName,paramOldValue,paramValue,verbose=False):
    """Visit sequence and, on every visited object whose attribute paramName
    equals paramOldValue, set that attribute to paramValue."""
    replacer = MassSearchReplaceParamVisitor(paramName, paramOldValue, paramValue, verbose)
    sequence.visit(replacer)
251 
def listModules(sequence):
    """Return the cms._Module instances met while visiting sequence, in visiting order."""
    collector = GatherAllModulesVisitor(gatheredInstance=cms._Module)
    sequence.visit(collector)
    found = collector.modules()
    return found
256 
def listSequences(sequence):
    """Return the cms.Sequence instances met while visiting sequence, in visiting order."""
    collector = GatherAllModulesVisitor(gatheredInstance=cms.Sequence)
    sequence.visit(collector)
    found = collector.modules()
    return found
261 
def massSearchReplaceAnyInputTag(sequence, oldInputTag, newInputTag,verbose=False,moduleLabelOnly=False,skipLabelTest=False) :
    """Visit sequence and replace the InputTag oldInputTag by newInputTag wherever
    it appears, including inside nested PSets, VPSets and VInputTags."""
    replacer = MassSearchReplaceAnyInputTagVisitor(
        oldInputTag,
        newInputTag,
        verbose=verbose,
        moduleLabelOnly=moduleLabelOnly,
        skipLabelTest=skipLabelTest,
    )
    sequence.visit(replacer)
265 
def jetCollectionString(prefix='', algo='', type=''):
    """
    ------------------------------------------------------------------
    return the label of the pat jet collection module built from the
    input values: '<prefix>PatJets<algo><type>', or 'patJets<algo><type>'
    when no prefix is given. With all defaults the result is 'patJets';
    e.g. algo='AK5', type='Calo' gives 'patJetsAK5Calo'.

    algo   : indicating the algorithm type of the jet [expected are
             'AK5', 'IC5', 'SC7', ...]
    type   : indicating the type of constituents of the jet [expec-
             ted are 'Calo', 'PFlow', 'JPT', ...]
    prefix : prefix indicating the type of pat collection module (ex-
             pected are '', 'selected', 'clean').
    ------------------------------------------------------------------
    """
    # 'type' shadows the builtin but is part of the keyword interface
    head = 'pat' if prefix == '' else prefix + 'Pat'
    return head + 'Jets' + algo + type
289 
def contains(sequence, moduleName):
    """
    ------------------------------------------------------------------
    return True if the text representation of 'sequence' contains the
    string 'moduleName', False otherwise. As this is a plain substring
    search it also returns True when moduleName is only a part of the
    name of a contained module (e.g. 'a' matches 'ab').

    sequence   : sequence [e.g. process.patDefaultSequence]
    moduleName : module name as a string
    ------------------------------------------------------------------
    """
    text = sequence.__str__()
    return moduleName in text
303 
304 
305 
def cloneProcessingSnippet(process, sequence, postfix, removePostfix=""):
    """
    ------------------------------------------------------------------
    copy a sequence plus the modules and sequences therein;
    the copies are renamed by appending postfix (after stripping
    removePostfix, if given) and their input tags are adjusted to
    point at the copies. With an empty postfix the original sequence
    is returned untouched.
    ------------------------------------------------------------------
    """
    if postfix == "":
        return sequence
    cloner = CloneSequenceVisitor(process, sequence.label(), postfix, removePostfix)
    sequence.visit(cloner)
    return cloner.clonedSequence()
320 
if __name__=="__main__":
    import unittest
    class TestModuleCommand(unittest.TestCase):
        def setUp(self):
            """Nothing to do """
            pass
        def testCloning(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("b", src=cms.InputTag("a"))
            p.c = cms.EDProducer("c", src=cms.InputTag("b","instance"))
            p.s = cms.Sequence(p.a*p.b*p.c *p.a)
            cloneProcessingSnippet(p, p.s, "New")
            # NOTE(review): the whitespace inside this expected dump and the module
            # order should be confirmed against the actual dumpPython() output.
            self.assertEqual(p.dumpPython(),'import FWCore.ParameterSet.Config as cms\n\nprocess = cms.Process("test")\n\nprocess.a = cms.EDProducer("a",\n src = cms.InputTag("gen")\n)\n\n\nprocess.c = cms.EDProducer("c",\n src = cms.InputTag("b","instance")\n)\n\n\nprocess.cNew = cms.EDProducer("c",\n src = cms.InputTag("bNew","instance")\n)\n\n\nprocess.bNew = cms.EDProducer("b",\n src = cms.InputTag("aNew")\n)\n\n\nprocess.aNew = cms.EDProducer("a",\n src = cms.InputTag("gen")\n)\n\n\nprocess.b = cms.EDProducer("b",\n src = cms.InputTag("a")\n)\n\n\nprocess.s = cms.Sequence(process.a*process.b*process.c*process.a)\n\n\nprocess.sNew = cms.Sequence(process.aNew+process.bNew+process.cNew)\n\n\n')
        def testContains(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s1 = cms.Sequence(p.a*p.b*p.c)
            p.s2 = cms.Sequence(p.b*p.c)
            self.assertTrue( contains(p.s1, "a") )
            self.assertFalse( contains(p.s2, "a") )
        def testJetCollectionString(self):
            # jetCollectionString builds <prefix>PatJets<algo><type> ('patJets...' without prefix)
            self.assertEqual(jetCollectionString(algo = 'Foo', type = 'Bar'), 'patJetsFooBar')
            self.assertEqual(jetCollectionString(prefix = 'prefix', algo = 'Foo', type = 'Bar'), 'prefixPatJetsFooBar')
        def testListModules(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s = cms.Sequence(p.a*p.b*p.c)
            self.assertEqual([p.a,p.b,p.c], listModules(p.s))
        def testMassSearchReplaceParam(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("c"))
                                )
            p.s = cms.Sequence(p.a*p.b*p.c)
            massSearchReplaceParam(p.s,"src",cms.InputTag("b"),"a")
            self.assertEqual(cms.InputTag("a"),p.c.src)
            self.assertNotEqual(cms.InputTag("a"),p.c.nested.src)
        def testMassSearchReplaceAnyInputTag(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("b"), src2 = cms.InputTag("c")),
                                 nestedv = cms.VPSet(cms.PSet(src = cms.InputTag("b")), cms.PSet(src = cms.InputTag("d"))),
                                 vec = cms.VInputTag(cms.InputTag("a"), cms.InputTag("b"), cms.InputTag("c"), cms.InputTag("d"))
                                )
            p.s = cms.Sequence(p.a*p.b*p.c)
            massSearchReplaceAnyInputTag(p.s, cms.InputTag("b"), cms.InputTag("new"))
            self.assertNotEqual(cms.InputTag("new"), p.b.src)
            self.assertEqual(cms.InputTag("new"), p.c.src)
            self.assertEqual(cms.InputTag("new"), p.c.nested.src)
            self.assertNotEqual(cms.InputTag("new"), p.c.nested.src2)
            self.assertEqual(cms.InputTag("new"), p.c.nestedv[0].src)
            self.assertNotEqual(cms.InputTag("new"), p.c.nestedv[1].src)
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[0])
            self.assertEqual(cms.InputTag("new"), p.c.vec[1])
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[2])
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[3])

    unittest.main()
def listModules
Definition: helpers.py:252
def listSequences
Definition: helpers.py:257
def addESProducers
Helpers to perform some technically boring tasks like looking for all modules with a given parameter ...
Definition: helpers.py:7
def cloneProcessingSnippet
Definition: helpers.py:306
void find(edm::Handle< EcalRecHitCollection > &hits, DetId thisDet, std::vector< EcalRecHitCollection::const_iterator > &hit, bool debug=false)
Definition: FindCaloHit.cc:7
def jetCollectionString
Definition: helpers.py:266
def applyPostfix
Definition: helpers.py:70
def massSearchReplaceAnyInputTag
Definition: helpers.py:262
def massSearchReplaceParam
Definition: helpers.py:249
def extendWithPostfix
Definition: helpers.py:24
def loadWithPostfix
Definition: helpers.py:18
list object
Definition: dbtoconf.py:77
if(dp >Float(M_PI)) dp-
def removeIfInSequence
Definition: helpers.py:78
dbl *** dir
Definition: mlp_gen.cc:35
def __labelsInSequence
Definition: helpers.py:85
def testMassSearchReplaceAnyInputTag
Definition: helpers.py:365
def contains
Definition: helpers.py:290