CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
HiHelperTools.py
Go to the documentation of this file.
2 
3 ## Helpers to perform some technically boring tasks like looking for all modules with a given parameter
4 ## and replacing that to a given value
5 
def applyPostfix(process, label, postfix):
    """Return the object known to `process` as `label`, preferring its postfixed clone.

    The labels of patHeavyIonDefaultSequence+postfix are collected first; this
    lookup fails if that sequence has not been cloned with `postfix` beforehand.
    When the process also has a patPF2PATSequence attribute, the labels of
    patPF2PATSequence+postfix are used instead.

    Returns getattr(process, label+postfix) when `label` is among the collected
    labels and that attribute exists. Otherwise, when the process has an
    attribute `label`, a warning is printed and that attribute is returned.
    Returns None when neither exists.
    """
    knownLabels = __labelsInSequence(process, "patHeavyIonDefaultSequence", postfix)
    if hasattr(process, "patPF2PATSequence"):
        knownLabels = __labelsInSequence(process, "patPF2PATSequence", postfix)

    if label in knownLabels and hasattr(process, label+postfix):
        return getattr(process, label+postfix)

    if not hasattr(process, label):
        return None

    print("WARNING: called applyPostfix for module/sequence %s which is not in patHeavyIonDefaultSequence%s!"%(label,postfix))
    return getattr(process, label)
19 
def removeIfInSequence(process, target, sequenceLabel, postfix=""):
    """Remove the object `target`+`postfix` from the sequence `sequenceLabel`+`postfix`.

    process       -- object holding the sequence and the target as attributes
    target        -- label of the module/sequence to remove, without postfix
    sequenceLabel -- label of the sequence to remove it from, without postfix
    postfix       -- postfix carried by both the sequence and the target

    Nothing happens when the target is not part of the sequence.
    """
    labels = __labelsInSequence(process, sequenceLabel, postfix)
    # __labelsInSequence returns labels with the postfix already stripped, so
    # membership has to be tested with the bare target name; testing
    # target+postfix could never match for a non-empty postfix.
    if target in labels and hasattr(process, target + postfix):
        getattr(process, sequenceLabel + postfix).remove(
            getattr(process, target + postfix)
        )
26 
def __labelsInSequence(process, sequenceLabel, postfix=""):
    """Return the labels of all modules and sequences inside a sequence, postfix removed.

    process       -- object holding the sequence as attribute `sequenceLabel`+`postfix`
    sequenceLabel -- label of the sequence, without postfix
    postfix       -- postfix to strip from the end of each collected label

    Module labels come first, followed by sequence labels. Labels that do not
    end with `postfix` are returned unchanged instead of being truncated.
    """
    # Resolve the sequence once; each list* call walks the whole sequence.
    sequence = getattr(process, sequenceLabel + postfix)
    labels = [m.label() for m in listModules(sequence)]
    labels.extend([s.label() for s in listSequences(sequence)])
    if postfix == "":
        return labels
    cut = len(postfix)
    return [l[:-cut] if l.endswith(postfix) else l for l in labels]
34 
36  """Visitor that travels within a cms.Sequence, looks for a parameter and replaces its value"""
def __init__(self,paramName,paramSearch,paramValue,verbose=False):
    """Remember which parameter to look at and how to rewrite it.

    paramName   -- name of the attribute inspected on every visited object
    paramSearch -- value the attribute must equal for a replacement to happen
    paramValue  -- value written in place of paramSearch
    verbose     -- when true, enter() prints every replacement it performs
    """
    self._verbose = verbose
    self._paramName = paramName
    self._paramSearch = paramSearch
    self._paramValue = paramValue
def enter(self,visitee):
    """Overwrite the watched parameter of `visitee` when it equals the search value.

    Objects without the parameter, or with a different value, are left alone.
    """
    name = self._paramName
    if not hasattr(visitee, name):
        return
    current = getattr(visitee, name)
    # Keep the '==' comparison: the operands may define __eq__ only.
    if not current == self._paramSearch:
        return
    if self._verbose:
        print("Replaced %s.%s: %s => %s" % (visitee,name,current,self._paramValue))
    setattr(visitee, name, self._paramValue)
def leave(self,visitee):
    """Visitor hook called when leaving `visitee`; nothing needs to be done."""
    return None
49 
class MassSearchReplaceAnyInputTagVisitor(object):
    """Visitor that travels within a cms.Sequence, looks for an InputTag and replaces it.

    It climbs down within PSets, VPSets and VInputTags to find its target.
    """

    def __init__(self,paramSearch,paramReplace,verbose=False,moduleLabelOnly=False):
        """Configure the replacement.

        paramSearch     -- InputTag (or plain string) to look for
        paramReplace    -- InputTag (or plain string) to put in its place
        verbose         -- when true, print every replacement
        moduleLabelOnly -- when true, tags that merely share the module label
                           of paramSearch get only their module label replaced
        """
        self._paramSearch = self.standardizeInputTagFmt(paramSearch)
        self._paramReplace = self.standardizeInputTagFmt(paramReplace)
        self._moduleName = ''
        self._verbose = verbose
        self._moduleLabelOnly = moduleLabelOnly

    def doIt(self,pset,base):
        """Recursively replace matching InputTags among the parameters of `pset`.

        pset -- object to inspect; ignored unless it is a cms._Parameterizable
        base -- dotted path of `pset`, used only in verbose messages
        """
        from copy import deepcopy
        if not isinstance(pset, cms._Parameterizable):
            return
        for name in pset.parameters_().keys():
            # if I use pset.parameters_().items() I get copies of the parameter values
            # so I can't modify the nested pset
            value = getattr(pset,name)
            typeName = value.pythonTypeName()
            if typeName == 'cms.PSet':
                self.doIt(value,base+"."+name)
            elif typeName == 'cms.VPSet':
                for (i,ps) in enumerate(value):
                    self.doIt(ps, "%s.%s[%d]"%(base,name,i) )
            elif typeName == 'cms.VInputTag':
                for (i,n) in enumerate(value):
                    # VInputTag can be declared as a list of strings, so ensure that n is formatted correctly
                    n = self.standardizeInputTagFmt(n)
                    if (n == self._paramSearch):
                        if self._verbose:
                            print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, self._paramSearch, self._paramReplace))
                        # Store a private copy so the entries do not all share
                        # (and later mutate) the visitor's replacement tag.
                        value[i] = deepcopy(self._paramReplace)
                    elif self._moduleLabelOnly and n.moduleLabel == self._paramSearch.moduleLabel:
                        # Copy before editing; otherwise n and nrep are the same
                        # object and the message shows the new tag on both sides.
                        nrep = deepcopy(n)
                        nrep.moduleLabel = self._paramReplace.moduleLabel
                        if self._verbose:
                            print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, n, nrep))
                        value[i] = nrep
            elif typeName == 'cms.InputTag':
                if value == self._paramSearch:
                    if self._verbose:
                        print("Replace %s.%s %s ==> %s " % (base, name, self._paramSearch, self._paramReplace))
                    setattr(pset, name, deepcopy(self._paramReplace) )
                elif self._moduleLabelOnly and value.moduleLabel == self._paramSearch.moduleLabel:
                    repl = deepcopy(value)
                    repl.moduleLabel = self._paramReplace.moduleLabel
                    setattr(pset, name, repl)
                    if self._verbose:
                        print("Replace %s.%s %s ==> %s " % (base, name, value, repl))

    @staticmethod
    def standardizeInputTagFmt(inputTag):
        """Return `inputTag` as a cms.InputTag, wrapping it when it is given in another form (e.g. a plain str)."""
        if not isinstance(inputTag, cms.InputTag):
            return cms.InputTag(inputTag)
        return inputTag

    def enter(self,visitee):
        """Run the replacement on `visitee`, using its label as the path in messages."""
        try:
            label = visitee.label()
        except AttributeError:
            label = '<Module not in a Process>'
        self.doIt(visitee, label)

    def leave(self,visitee):
        """Visitor hook called when leaving `visitee`; nothing to do."""
        pass
108 
109 #FIXME name is not generic enough now
class GatherAllModulesVisitor(object):
    """Visitor that travels within a cms.Sequence and collects every visited
    object that is an instance of `gatheredInstance` (cms._Module by default)."""

    def __init__(self, gatheredInstance=cms._Module):
        """Start with an empty collection; `gatheredInstance` is the type to collect."""
        self._gatheredInstance = gatheredInstance
        self._modules = []

    def enter(self,visitee):
        """Record `visitee` when it is an instance of the requested type."""
        if not isinstance(visitee, self._gatheredInstance):
            return
        self._modules.append(visitee)

    def leave(self,visitee):
        """Visitor hook called when leaving `visitee`; nothing to do."""
        return None

    def modules(self):
        """Return the list of collected objects, in visiting order."""
        return self._modules
122 
124  """Visitor that travels within a cms.Sequence, and returns a cloned version of the Sequence.
125  All modules and sequences are cloned and a postfix is added"""
def __init__(self, process, label, postfix):
    """Prepare cloning into `process`.

    process -- object that receives the cloned modules and sequences
    label   -- label of the sequence being cloned; kept on the stack as a
               placeholder until the first clone can be put into it
    postfix -- string appended to the label of every clone
    """
    self._process = process
    self._postfix = postfix
    # labels of the modules / sequences that have been cloned so far
    self._moduleLabels = []
    self._sequenceLabels = []
    # entries are either a label (sequence not created yet) or a cloned cms.Sequence
    self._sequenceStack = [label]
    # modules will only be cloned or added if this is None
    self._waitForSequenceToClose = None
133 
def enter(self,visitee):
    """Clone `visitee` (module) or open a new stack entry for it (sequence).

    Nothing happens while the visitor is inside a sequence that has already
    been cloned earlier.
    """
    if self._waitForSequenceToClose is not None:
        return  # we are in an already cloned sequence

    if isinstance(visitee, cms._Module):
        moduleLabel = visitee.label()
        if moduleLabel in self._moduleLabels:
            # cloned before: reuse the clone registered on the process
            clone = getattr(self._process, moduleLabel + self._postfix)
        else:
            self._moduleLabels.append(moduleLabel)
            clone = visitee.clone()
            setattr(self._process, moduleLabel + self._postfix, clone)
        self.__appendToTopSequence(clone)

    if isinstance(visitee, cms.Sequence):
        sequenceLabel = visitee.label()
        if sequenceLabel in self._sequenceLabels:
            # already cloned: push the existing clone and skip its content
            self._waitForSequenceToClose = sequenceLabel
            self._sequenceStack.append(getattr(self._process, sequenceLabel + self._postfix))
        else:
            # save desired label as placeholder until we have a module to create the sequence
            self._sequenceStack.append(sequenceLabel)
155 
def leave(self,visitee):
    """Close the stack entry of a sequence and add it to the entry below.

    Raises StandardError when the entry on top of the stack is still a
    placeholder label, i.e. nothing was ever added to the sequence.
    """
    if not isinstance(visitee, cms.Sequence):
        return
    if self._waitForSequenceToClose == visitee.label():
        self._waitForSequenceToClose = None
    if not isinstance(self._sequenceStack[-1], cms.Sequence):
        raise StandardError("empty Sequence encountered during cloneing. sequnece stack: %s"%self._sequenceStack)
    finished = self._sequenceStack.pop()
    self.__appendToTopSequence(finished)
163 
def clonedSequence(self):
    """Return the cloned top-level sequence with its InputTags redirected to the clones.

    Raises StandardError unless exactly one entry is left on the sequence stack.
    """
    if len(self._sequenceStack) != 1:
        raise StandardError("someting went wrong, the sequence stack looks like: %s"%self._sequenceStack)
    cloned = self._sequenceStack[-1]
    for label in self._moduleLabels:
        massSearchReplaceAnyInputTag(cloned, label, label+self._postfix, moduleLabelOnly=True, verbose=False)
    # prevent the InputTag replacement next time this is called.
    self._moduleLabels = []
    return cloned
171 
def __appendToTopSequence(self, visitee):
    """Add `visitee` to the sequence on top of the stack, creating that sequence if needed.

    This is darn ugly because empty cms.Sequences are not supported: a sequence
    is represented on the stack by its label until the first object can be put
    into it.

    Raises StandardError when creating the sequence would overwrite an
    attribute that already exists on the process.
    """
    if not isinstance(self._sequenceStack[-1], basestring):
        self._sequenceStack[-1] += visitee
        return

    # we have the name of an empty sequence on the stack. create it!
    oldSequenceLabel = self._sequenceStack[-1]
    newSequenceLabel = oldSequenceLabel + self._postfix
    # Check before touching the stack, and name the clash by the label itself:
    # a freshly built sequence is not attached to the process yet, so asking it
    # for its label cannot produce the name needed for the message.
    if hasattr(self._process, newSequenceLabel):
        raise StandardError("Cloning the sequence "+newSequenceLabel+" would overwrite existing object." )
    newSequence = cms.Sequence(visitee)
    self._sequenceStack[-1] = newSequence
    setattr(self._process, newSequenceLabel, newSequence)
    self._sequenceLabels.append(oldSequenceLabel)
183 
185  """Visitor that travels within a cms.Sequence, looks for a parameter and returns a list of modules that have it"""
def __init__(self,paramName,paramSearch):
    """Remember the parameter name and value to look for; start with no matches."""
    self._modules = []
    self._paramName = paramName
    self._paramSearch = paramSearch
def enter(self,visitee):
    """Collect `visitee` when it has the watched parameter set to the search value."""
    name = self._paramName
    if not hasattr(visitee, name):
        return
    # Keep the '==' comparison: the operands may define __eq__ only.
    if not getattr(visitee, name) == self._paramSearch:
        return
    self._modules.append(visitee)
def leave(self,visitee):
    """Visitor hook called when leaving `visitee`; nothing needs to be done."""
    return None
def modules(self):
    """Return the list of objects collected so far."""
    collected = self._modules
    return collected
198 
199 
def massSearchReplaceParam(sequence,paramName,paramOldValue,paramValue):
    """Set parameter `paramName` to `paramValue` on every object visited in
    `sequence` whose current value equals `paramOldValue`."""
    replacer = MassSearchReplaceParamVisitor(paramName,paramOldValue,paramValue)
    sequence.visit(replacer)
202 
def listModules(sequence):
    """Return the list of cms._Module instances met while visiting `sequence`."""
    gatherer = GatherAllModulesVisitor(gatheredInstance=cms._Module)
    sequence.visit(gatherer)
    found = gatherer.modules()
    return found
207 
def listSequences(sequence):
    """Return the list of cms.Sequence instances met while visiting `sequence`."""
    gatherer = GatherAllModulesVisitor(gatheredInstance=cms.Sequence)
    sequence.visit(gatherer)
    found = gatherer.modules()
    return found
212 
def massSearchReplaceAnyInputTag(sequence, oldInputTag, newInputTag,verbose=False,moduleLabelOnly=False) :
    """Replace InputTag oldInputTag with newInputTag, at any level of nesting within PSets, VPSets, VInputTags...

    verbose and moduleLabelOnly are handed unchanged to the
    MassSearchReplaceAnyInputTagVisitor that performs the replacement.
    """
    replacer = MassSearchReplaceAnyInputTagVisitor(
        oldInputTag,
        newInputTag,
        verbose=verbose,
        moduleLabelOnly=moduleLabelOnly,
    )
    sequence.visit(replacer)
216 
def jetCollectionString(prefix='', algo='', type=''):
    """
    ------------------------------------------------------------------
    return the string of the jet collection module depending on the
    input values. The result is '<head>Jets<algo><type>', where <head>
    is 'pat' for an empty prefix and '<prefix>Pat' otherwise; with all
    arguments left at their defaults it is 'patJets'.

    algo   : indicating the algorithm type of the jet [expected are
             'AK5', 'IC5', 'SC7', ...]
    type   : indicating the type of constituents of the jet [expec-
             ted are 'Calo', 'PFlow', 'JPT', ...]
    prefix : prefix indicating the type of pat collection module (ex-
             pected are '', 'selected', 'clean').
    ------------------------------------------------------------------
    """
    head = 'pat' if prefix == '' else prefix + 'Pat'
    return head + 'Jets' + algo + type
240 
def contains(sequence, moduleName):
    """
    ------------------------------------------------------------------
    return True if 'moduleName' occurs in the string representation
    of 'sequence' and False otherwise. This is a plain substring
    test, so it also returns True for any substring of the name of a
    contained module.

    sequence   : sequence [e.g. process.patHeavyIonDefaultSequence]
    moduleName : module name as a string
    ------------------------------------------------------------------
    """
    return moduleName in str(sequence)
254 
255 
256 
def cloneProcessingSnippet(process, sequence, postfix):
    """
    ------------------------------------------------------------------
    copy a sequence plus the modules and sequences therein
    both are renamed by getting a postfix
    input tags are automatically adjusted

    with an empty postfix nothing is cloned and 'sequence' itself is
    returned
    ------------------------------------------------------------------
    """
    if postfix == "":
        return sequence
    cloner = CloneSequenceVisitor(process,sequence.label(),postfix)
    sequence.visit(cloner)
    return cloner.clonedSequence()
271 
if __name__=="__main__":
    import unittest
    class TestModuleCommand(unittest.TestCase):
        """Self-test of the helpers in this file; run by executing the file directly."""
        def setUp(self):
            """Nothing to do """
            pass
        def testCloning(self):
            """Cloning a sequence adds postfixed modules and redirects their InputTags."""
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("b", src=cms.InputTag("a"))
            p.c = cms.EDProducer("c", src=cms.InputTag("b","instance"))
            p.s = cms.Sequence(p.a*p.b*p.c *p.a)
            cloneProcessingSnippet(p, p.s, "New")
            self.assertEqual(p.dumpPython(),'import FWCore.ParameterSet.Config as cms\n\nprocess = cms.Process("test")\n\nprocess.a = cms.EDProducer("a",\n src = cms.InputTag("gen")\n)\n\n\nprocess.c = cms.EDProducer("c",\n src = cms.InputTag("b","instance")\n)\n\n\nprocess.cNew = cms.EDProducer("c",\n src = cms.InputTag("bNew","instance")\n)\n\n\nprocess.bNew = cms.EDProducer("b",\n src = cms.InputTag("aNew")\n)\n\n\nprocess.aNew = cms.EDProducer("a",\n src = cms.InputTag("gen")\n)\n\n\nprocess.b = cms.EDProducer("b",\n src = cms.InputTag("a")\n)\n\n\nprocess.s = cms.Sequence(process.a*process.b*process.c*process.a)\n\n\nprocess.sNew = cms.Sequence(process.aNew+process.bNew+process.cNew)\n\n\n')
        def testContains(self):
            """contains() is a substring test on the sequence's string form."""
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s1 = cms.Sequence(p.a*p.b*p.c)
            p.s2 = cms.Sequence(p.b*p.c)
            self.assertTrue( contains(p.s1, "a") )
            self.assertFalse( contains(p.s2, "a") )
        def testJetCollectionString(self):
            """jetCollectionString() builds '<head>Jets<algo><type>'."""
            # Expected values follow what jetCollectionString actually builds:
            # 'Jets' comes before the algorithm and constituent type.
            self.assertEqual(jetCollectionString(algo = 'Foo', type = 'Bar'), 'patJetsFooBar')
            self.assertEqual(jetCollectionString(prefix = 'prefix', algo = 'Foo', type = 'Bar'), 'prefixPatJetsFooBar')
        def testListModules(self):
            """listModules() returns the modules of a sequence in order."""
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s = cms.Sequence(p.a*p.b*p.c)
            self.assertEqual([p.a,p.b,p.c], listModules(p.s))
        def testMassSearchReplaceParam(self):
            """massSearchReplaceParam() rewrites top-level parameters only, not nested PSets."""
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("c"))
                                 )
            p.s = cms.Sequence(p.a*p.b*p.c)
            massSearchReplaceParam(p.s,"src",cms.InputTag("b"),"a")
            self.assertEqual(cms.InputTag("a"),p.c.src)
            self.assertNotEqual(cms.InputTag("a"),p.c.nested.src)

    unittest.main()
def cloneProcessingSnippet
void find(edm::Handle< EcalRecHitCollection > &hits, DetId thisDet, std::vector< EcalRecHitCollection::const_iterator > &hit, bool debug=false)
Definition: FindCaloHit.cc:7
def massSearchReplaceAnyInputTag
def massSearchReplaceParam
def removeIfInSequence
list object
Definition: dbtoconf.py:77
perl if(1 lt scalar(@::datatypes))
Definition: edlooper.cc:31
def applyPostfix
Helpers to perform some technically boring tasks like looking for all modules with a given parameter ...
Definition: HiHelperTools.py:6
def __labelsInSequence