CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
helpers.py
Go to the documentation of this file.
1 import FWCore.ParameterSet.Config as cms
2 
## Helpers to perform some technically boring tasks, like looking for all modules with a given parameter
## and replacing its value with a given one
5 
def applyPostfix(process, label, postfix):
    """Return the postfixed clone of ``label`` if it belongs to the cloned PAT sequence.

    The reference sequence is ``patPF2PATSequence`` when the process has an
    attribute of that name, otherwise ``patDefaultSequence``. Its labels are
    read from the clone ``<sequence>+postfix``, which must already exist on the
    process (the lookup raises AttributeError otherwise).

    Returns:
        ``process.<label+postfix>`` if ``label`` is in the reference sequence and
        that clone exists; otherwise ``process.<label>`` (after printing a
        warning) if it exists; otherwise None.
    """
    # Only inspect the sequence that is actually used. Previously the
    # patDefaultSequence clone was always looked up first, which crashed for
    # processes that only have patPF2PATSequence.
    if hasattr(process, "patPF2PATSequence"):
        sequenceLabel = "patPF2PATSequence"
    else:
        sequenceLabel = "patDefaultSequence"
    defaultLabels = __labelsInSequence(process, sequenceLabel, postfix)
    if label in defaultLabels and hasattr(process, label+postfix):
        return getattr(process, label+postfix)
    if hasattr(process, label):
        # name the sequence that was really consulted in the warning
        print("WARNING: called applyPostfix for module/sequence %s which is not in %s%s!" % (label, sequenceLabel, postfix))
        return getattr(process, label)
    return None
19 
def removeIfInSequence(process, target, sequenceLabel, postfix=""):
    """Remove module ``target+postfix`` from ``sequenceLabel+postfix`` if it is a member.

    ``target`` and ``sequenceLabel`` are the labels without the postfix.
    Nothing happens when the module is not part of the sequence.
    """
    # __labelsInSequence returns labels with the postfix stripped, so the
    # membership test must use the bare target label. Comparing against
    # target+postfix never matched when postfix was non-empty.
    labels = __labelsInSequence(process, sequenceLabel, postfix)
    if target in labels:
        getattr(process, sequenceLabel+postfix).remove(
            getattr(process, target+postfix)
            )
26 
def __labelsInSequence(process, sequenceLabel, postfix=""):
    """Return the labels of all modules and sub-sequences of ``process.<sequenceLabel+postfix>``.

    Module labels come first, followed by sequence labels, each in visiting
    order. A trailing ``postfix`` is removed from every label that carries it.
    Labels without it are returned unchanged.
    """
    sequence = getattr(process, sequenceLabel+postfix)

    def strip(label):
        # Only strip when the label really ends with the postfix. Blindly
        # cutting len(postfix) characters mangles non-postfixed labels, and
        # label[:-0] would give '' for an empty postfix (which previously
        # forced a second traversal of the sequence).
        if postfix and label.endswith(postfix):
            return label[:-len(postfix)]
        return label

    result = [strip(m.label()) for m in listModules(sequence)]
    result.extend([strip(s.label()) for s in listSequences(sequence)])
    return result
34 
class MassSearchReplaceParamVisitor(object):
    """Visitor that travels within a cms.Sequence, looks for a parameter and replaces its value.

    Every visited node that has an attribute ``paramName`` equal to
    ``paramSearch`` gets that attribute set to ``paramValue``. Only attributes
    directly on the visited node are inspected; nested PSets are not searched.
    """
    def __init__(self, paramName, paramSearch, paramValue, verbose=False):
        self._paramName = paramName
        self._paramValue = paramValue
        self._paramSearch = paramSearch
        self._verbose = verbose

    def enter(self, visitee):
        if not hasattr(visitee, self._paramName):
            return
        current = getattr(visitee, self._paramName)
        if current == self._paramSearch:
            if self._verbose:
                print("Replaced %s.%s: %s => %s" % (visitee, self._paramName, current, self._paramValue))
            setattr(visitee, self._paramName, self._paramValue)

    def leave(self, visitee):
        pass
49 
class MassSearchReplaceAnyInputTagVisitor(object):
    """Visitor that travels within a cms.Sequence, looks for an InputTag and replaces it.

    It climbs down into PSets, VPSets and VInputTags (tracked or untracked) to
    find its target. With ``moduleLabelOnly=True``, tags whose module label
    matches the searched one also get their module label replaced, keeping
    their instance label and process name.
    """
    def __init__(self,paramSearch,paramReplace,verbose=False,moduleLabelOnly=False):
        self._paramSearch = self.standardizeInputTagFmt(paramSearch)
        self._paramReplace = self.standardizeInputTagFmt(paramReplace)
        self._moduleName = ''
        self._verbose=verbose
        self._moduleLabelOnly=moduleLabelOnly

    def doIt(self,pset,base):
        """Recursively replace matching InputTags inside ``pset``; ``base`` is the path used in messages."""
        if not isinstance(pset, cms._Parameterizable):
            return
        for name in pset.parameterNames_():
            # if I use pset.parameters_().items() I get copies of the parameter values
            # so I can't modify the nested pset
            value = getattr(pset,name)
            typeName = value.pythonTypeName()
            if typeName in ('cms.PSet', 'cms.untracked.PSet'):
                self.doIt(value, base+"."+name)
            elif typeName in ('cms.VPSet', 'cms.untracked.VPSet'):
                for (i,ps) in enumerate(value):
                    self.doIt(ps, "%s.%s[%d]"%(base,name,i))
            elif typeName.endswith('.VInputTag'):
                self._replaceInVInputTag(value, base, name)
            elif typeName.endswith('.InputTag'):
                self._replaceInputTag(pset, name, value, typeName, base)

    def _replaceInVInputTag(self, value, base, name):
        """Replace matching entries of the VInputTag ``value`` in place."""
        from copy import deepcopy
        for (i,n) in enumerate(value):
            # VInputTag can be declared as a list of strings, so ensure that n is formatted correctly
            n = self.standardizeInputTagFmt(n)
            if n == self._paramSearch:
                if self._verbose:
                    print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, self._paramSearch, self._paramReplace))
                # copy, so that the entries do not all alias one replacement object
                value[i] = deepcopy(self._paramReplace)
            elif self._moduleLabelOnly and n.moduleLabel == self._paramSearch.moduleLabel:
                # work on a copy: n may be the very object stored in the vector
                # (or shared with other parameters), and the message needs the old value
                nrep = deepcopy(n)
                nrep.moduleLabel = self._paramReplace.moduleLabel
                if self._verbose:
                    print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, n, nrep))
                value[i] = nrep

    def _replaceInputTag(self, pset, name, value, typeName, base):
        """Replace the single InputTag parameter ``pset.<name>`` if it matches."""
        from copy import deepcopy
        if value == self._paramSearch:
            if self._verbose:
                print("Replace %s.%s %s ==> %s " % (base, name, self._paramSearch, self._paramReplace))
            if 'untracked' in typeName:
                # keep the parameter untracked by rebuilding it from the replacement's parts
                setattr(pset, name, cms.untracked.InputTag(self._paramReplace.getModuleLabel(),
                                                           self._paramReplace.getProductInstanceLabel(),
                                                           self._paramReplace.getProcessName()))
            else:
                setattr(pset, name, deepcopy(self._paramReplace))
        elif self._moduleLabelOnly and value.moduleLabel == self._paramSearch.moduleLabel:
            repl = deepcopy(getattr(pset, name))
            repl.moduleLabel = self._paramReplace.moduleLabel
            setattr(pset, name, repl)
            if self._verbose:
                print("Replace %s.%s %s ==> %s " % (base, name, value, repl))

    @staticmethod
    def standardizeInputTagFmt(inputTag):
        ''' helper function to ensure that the InputTag is defined as cms.InputTag(str) and not as a plain str '''
        if not isinstance(inputTag, cms.InputTag):
            return cms.InputTag(inputTag)
        return inputTag

    def enter(self,visitee):
        label = ''
        try:
            label = visitee.label_()
        except AttributeError:
            label = '<Module not in a Process>'
        self.doIt(visitee, label)

    def leave(self,visitee):
        pass
113 
# FIXME: the name is no longer generic enough; this gathers any node type, not only modules
class GatherAllModulesVisitor(object):
    """Visitor that travels within a cms.Sequence and collects every visited node
    that is an instance of ``gatheredInstance`` (cms._Module by default).

    The collected nodes are returned, in visiting order, by ``modules()``.
    """
    def __init__(self, gatheredInstance=cms._Module):
        self._gatheredInstance = gatheredInstance
        self._modules = []

    def enter(self, visitee):
        # keep only the node types this visitor was configured to collect
        if not isinstance(visitee, self._gatheredInstance):
            return
        self._modules.append(visitee)

    def leave(self, visitee):
        """Nothing to do when leaving a node."""

    def modules(self):
        """Return the list of collected nodes."""
        return self._modules
127 
class CloneSequenceVisitor(object):
    """Visitor that travels within a cms.Sequence, and returns a cloned version of the Sequence.

    All modules and sequences are cloned and a postfix is added to their
    labels. The clones are attached to ``process`` as they are created. Call
    ``clonedSequence()`` after the visit to obtain the top-level clone; that
    call also re-points InputTags of the cloned modules to the postfixed labels.
    """
    def __init__(self, process, label, postfix):
        self._process = process
        self._postfix = postfix
        # Stack of sequences being built. Each entry is either a cms.Sequence or
        # the un-postfixed label of a sequence whose clone does not exist yet
        # (empty cms.Sequences are not supported, so a clone is only created
        # together with its first member).
        self._sequenceStack = [label]
        self._moduleLabels = []    # labels of modules cloned so far
        self._sequenceLabels = []  # labels of sequences cloned so far
        self._waitForSequenceToClose = None # modules will only be cloned or added if this is None

    def enter(self, visitee):
        if self._waitForSequenceToClose is not None:
            return # we are inside an already cloned sequence
        if isinstance(visitee, cms._Module):
            label = visitee.label()
            if label in self._moduleLabels:
                # module visited before: reuse its clone
                newModule = getattr(self._process, label+self._postfix)
            else:
                self._moduleLabels.append(label)
                newModule = visitee.clone()
                setattr(self._process, label+self._postfix, newModule)
            self.__appendToTopSequence(newModule)

        if isinstance(visitee, cms.Sequence):
            if visitee.label() in self._sequenceLabels: # is the sequence already cloned?
                # reuse the existing clone and ignore everything inside it until leave()
                self._waitForSequenceToClose = visitee.label()
                self._sequenceStack.append(getattr(self._process, visitee.label()+self._postfix))
            else:
                # save desired label as placeholder until we have a module to create the sequence
                self._sequenceStack.append(visitee.label())

    def leave(self, visitee):
        if not isinstance(visitee, cms.Sequence):
            return
        if self._waitForSequenceToClose is not None:
            if self._waitForSequenceToClose != visitee.label():
                # Sub-sequence of an already cloned sequence: enter() pushed
                # nothing for it, so nothing may be popped here (previously this
                # popped the reused clone too early and corrupted the stack).
                return
            self._waitForSequenceToClose = None
        if not isinstance(self._sequenceStack[-1], cms.Sequence):
            raise RuntimeError("empty Sequence encountered during cloneing. sequnece stack: %s"%self._sequenceStack)
        self.__appendToTopSequence(self._sequenceStack.pop())

    def clonedSequence(self):
        """Return the cloned top-level sequence, re-pointing InputTags to the cloned modules."""
        if len(self._sequenceStack) != 1:
            raise RuntimeError("someting went wrong, the sequence stack looks like: %s"%self._sequenceStack)
        for label in self._moduleLabels:
            massSearchReplaceAnyInputTag(self._sequenceStack[-1], label, label+self._postfix, moduleLabelOnly=True, verbose=False)
        self._moduleLabels = [] #prevent the InputTag replacement next time this is called.
        return self._sequenceStack[-1]

    def __appendToTopSequence(self, visitee):
        """Append ``visitee`` to the sequence on top of the stack.

        If the top of the stack is still a placeholder label, the postfixed
        clone is created here with ``visitee`` as its first member and attached
        to the process.
        """
        if not isinstance(self._sequenceStack[-1], cms.Sequence):
            oldSequenceLabel = self._sequenceStack[-1]
            newSequenceLabel = oldSequenceLabel + self._postfix
            # Check before touching the stack. The message is built from the known
            # labels, because a freshly created Sequence has no label yet.
            if hasattr(self._process, newSequenceLabel):
                raise RuntimeError("Cloning the sequence "+oldSequenceLabel+" would overwrite existing object "+newSequenceLabel+".")
            newSequence = cms.Sequence(visitee)
            self._sequenceStack[-1] = newSequence
            setattr(self._process, newSequenceLabel, newSequence)
            self._sequenceLabels.append(oldSequenceLabel)
        else:
            self._sequenceStack[-1] += visitee
188 
class MassSearchParamVisitor(object):
    """Visitor that travels within a cms.Sequence, looks for a parameter and returns a list of modules that have it.

    A visited node is collected when it has an attribute ``paramName`` equal
    to ``paramSearch``. The collected nodes are returned, in visiting order,
    by ``modules()``.
    """
    def __init__(self, paramName, paramSearch):
        self._paramName = paramName
        self._paramSearch = paramSearch
        self._modules = []

    def enter(self, visitee):
        if hasattr(visitee, self._paramName):
            if getattr(visitee, self._paramName) == self._paramSearch:
                self._modules.append(visitee)

    def leave(self, visitee):
        pass

    def modules(self):
        return self._modules
203 
204 
def massSearchReplaceParam(sequence,paramName,paramOldValue,paramValue,verbose=False):
    """Set ``paramName`` to ``paramValue`` on every visited node of ``sequence``
    whose own ``paramName`` attribute equals ``paramOldValue``."""
    replacer = MassSearchReplaceParamVisitor(paramName, paramOldValue, paramValue, verbose)
    sequence.visit(replacer)
207 
def listModules(sequence):
    """Return, in visiting order, every cms._Module reached while visiting ``sequence``."""
    collector = GatherAllModulesVisitor(gatheredInstance=cms._Module)
    sequence.visit(collector)
    found = collector.modules()
    return found
212 
def listSequences(sequence):
    """Return, in visiting order, every cms.Sequence reached while visiting ``sequence``."""
    collector = GatherAllModulesVisitor(gatheredInstance=cms.Sequence)
    sequence.visit(collector)
    found = collector.modules()
    return found
217 
def massSearchReplaceAnyInputTag(sequence, oldInputTag, newInputTag,verbose=False,moduleLabelOnly=False) :
    """Replace InputTag oldInputTag with newInputTag, at any level of nesting within PSets, VPSets, VInputTags...

    With ``moduleLabelOnly=True``, tags that only share the module label of
    ``oldInputTag`` get their module label replaced by that of ``newInputTag``.
    """
    replacer = MassSearchReplaceAnyInputTagVisitor(oldInputTag, newInputTag,
                                                   verbose=verbose,
                                                   moduleLabelOnly=moduleLabelOnly)
    sequence.visit(replacer)
221 
def jetCollectionString(prefix='', algo='', type=''):
    """
    ------------------------------------------------------------------
    return the label of the jet collection module depending on the
    input values: 'patJets<algo><type>' when no prefix is given and
    '<prefix>PatJets<algo><type>' otherwise. The default return value
    is 'patJets'; e.g. algo='AK5', type='Calo' gives 'patJetsAK5Calo'
    and prefix='selected', algo='AK5', type='PF' gives
    'selectedPatJetsAK5PF'.

    algo   : indicating the algorithm type of the jet [expected are
             'AK5', 'IC5', 'SC7', ...]
    type   : indicating the type of constituents of the jet [expec-
             ted are 'Calo', 'PFlow', 'JPT', ...]
    prefix : prefix indicating the type of pat collection module (ex-
             pected are '', 'selected', 'clean').
    ------------------------------------------------------------------
    """
    # An empty prefix yields 'pat'; any other prefix gets 'Pat' appended.
    # ('type' shadows the builtin, but it is part of the keyword interface.)
    if prefix == '':
        head = 'pat'
    else:
        head = prefix + 'Pat'
    return head + 'Jets' + algo + type
245 
def contains(sequence, moduleName):
    """
    ------------------------------------------------------------------
    return True if 'moduleName' occurs anywhere in the string
    representation of 'sequence' and False otherwise. This version
    is not so nice as it also returns True for any substring of the
    name of a contained module.

    sequence   : sequence [e.g. process.patDefaultSequence]
    moduleName : module name as a string
    ------------------------------------------------------------------
    """
    return moduleName in str(sequence)
259 
260 
261 
def cloneProcessingSnippet(process, sequence, postfix):
    """
    ------------------------------------------------------------------
    copy a sequence plus the modules and sequences therein;
    both are renamed by appending a postfix, and input tags are
    automatically adjusted to point to the copies.
    An empty postfix returns the original sequence unchanged.
    ------------------------------------------------------------------
    """
    if postfix == "":
        # nothing to rename: hand back the original sequence
        return sequence
    cloner = CloneSequenceVisitor(process, sequence.label(), postfix)
    sequence.visit(cloner)
    return cloner.clonedSequence()
276 
if __name__=="__main__":
    import unittest
    class TestModuleCommand(unittest.TestCase):
        def setUp(self):
            """Nothing to do """
            pass

        def testCloning(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("b", src=cms.InputTag("a"))
            p.c = cms.EDProducer("c", src=cms.InputTag("b","instance"))
            p.s = cms.Sequence(p.a*p.b*p.c *p.a)
            cloneProcessingSnippet(p, p.s, "New")
            self.assertEqual(p.dumpPython(),'import FWCore.ParameterSet.Config as cms\n\nprocess = cms.Process("test")\n\nprocess.a = cms.EDProducer("a",\n src = cms.InputTag("gen")\n)\n\n\nprocess.c = cms.EDProducer("c",\n src = cms.InputTag("b","instance")\n)\n\n\nprocess.cNew = cms.EDProducer("c",\n src = cms.InputTag("bNew","instance")\n)\n\n\nprocess.bNew = cms.EDProducer("b",\n src = cms.InputTag("aNew")\n)\n\n\nprocess.aNew = cms.EDProducer("a",\n src = cms.InputTag("gen")\n)\n\n\nprocess.b = cms.EDProducer("b",\n src = cms.InputTag("a")\n)\n\n\nprocess.s = cms.Sequence(process.a*process.b*process.c*process.a)\n\n\nprocess.sNew = cms.Sequence(process.aNew+process.bNew+process.cNew)\n\n\n')

        def testContains(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s1 = cms.Sequence(p.a*p.b*p.c)
            p.s2 = cms.Sequence(p.b*p.c)
            self.assertTrue( contains(p.s1, "a") )
            self.assertFalse( contains(p.s2, "a") )

        def testJetCollectionString(self):
            # the label is built as 'pat' (or prefix+'Pat') + 'Jets' + algo + type
            self.assertEqual(jetCollectionString(algo = 'Foo', type = 'Bar'), 'patJetsFooBar')
            self.assertEqual(jetCollectionString(prefix = 'prefix', algo = 'Foo', type = 'Bar'), 'prefixPatJetsFooBar')

        def testListModules(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s = cms.Sequence(p.a*p.b*p.c)
            self.assertEqual([p.a,p.b,p.c], listModules(p.s))

        def testMassSearchReplaceParam(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("c"))
                                )
            p.s = cms.Sequence(p.a*p.b*p.c)
            massSearchReplaceParam(p.s,"src",cms.InputTag("b"),"a")
            self.assertEqual(cms.InputTag("a"),p.c.src)
            self.assertNotEqual(cms.InputTag("a"),p.c.nested.src)

        def testMassSearchReplaceAnyInputTag(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("b"), src2 = cms.InputTag("c")),
                                 nestedv = cms.VPSet(cms.PSet(src = cms.InputTag("b")), cms.PSet(src = cms.InputTag("d"))),
                                 vec = cms.VInputTag(cms.InputTag("a"), cms.InputTag("b"), cms.InputTag("c"), cms.InputTag("d"))
                                )
            p.s = cms.Sequence(p.a*p.b*p.c)
            massSearchReplaceAnyInputTag(p.s, cms.InputTag("b"), cms.InputTag("new"))
            self.assertNotEqual(cms.InputTag("new"), p.b.src)
            self.assertEqual(cms.InputTag("new"), p.c.src)
            self.assertEqual(cms.InputTag("new"), p.c.nested.src)
            self.assertNotEqual(cms.InputTag("new"), p.c.nested.src2)
            self.assertEqual(cms.InputTag("new"), p.c.nestedv[0].src)
            self.assertNotEqual(cms.InputTag("new"), p.c.nestedv[1].src)
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[0])
            self.assertEqual(cms.InputTag("new"), p.c.vec[1])
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[2])
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[3])

    unittest.main()
def listModules
Definition: helpers.py:208
def listSequences
Definition: helpers.py:213
def cloneProcessingSnippet
Definition: helpers.py:262
void find(edm::Handle< EcalRecHitCollection > &hits, DetId thisDet, std::vector< EcalRecHitCollection::const_iterator > &hit, bool debug=false)
Definition: FindCaloHit.cc:7
def jetCollectionString
Definition: helpers.py:222
def applyPostfix
Helpers to perform some technically boring tasks like looking for all modules with a given parameter ...
Definition: helpers.py:6
def massSearchReplaceAnyInputTag
Definition: helpers.py:218
def massSearchReplaceParam
Definition: helpers.py:205
list object
Definition: dbtoconf.py:77
perl if(1 lt scalar(@::datatypes))
Definition: edlooper.cc:31
def removeIfInSequence
Definition: helpers.py:20
def __labelsInSequence
Definition: helpers.py:27
def testMassSearchReplaceAnyInputTag
Definition: helpers.py:321
def contains
Definition: helpers.py:246