00001 import FWCore.ParameterSet.Config as cms
00002
00003
00004
00005
def applyPostfix(process, label, postfix):
    """Return the postfixed clone of a module/sequence if there is one.

    The labels are looked up in the postfixed ``patPF2PATSequence`` when the
    process has a ``patPF2PATSequence`` attribute, and in the postfixed
    ``patDefaultSequence`` otherwise.  Only the sequence that is actually
    used is inspected.

    process : the process holding the modules and sequences
    label   : label of the module/sequence without postfix
    postfix : postfix that was used when the sequence was cloned

    Returns ``process.<label+postfix>`` if ``label`` is listed in the
    inspected sequence and the postfixed object exists; otherwise
    ``process.<label>`` (after printing a warning) if that exists;
    otherwise ``None``.

    Will crash (AttributeError from getattr) if the inspected sequence has
    not been cloned with 'postfix' beforehand.
    """
    # Decide first which sequence is authoritative, so the other one is
    # neither traversed nor required to exist.
    sequenceLabel = "patDefaultSequence"
    if hasattr(process, "patPF2PATSequence"):
        sequenceLabel = "patPF2PATSequence"
    knownLabels = __labelsInSequence(process, sequenceLabel, postfix)

    if label in knownLabels and hasattr(process, label + postfix):
        return getattr(process, label + postfix)
    if hasattr(process, label):
        # Single-argument call form: prints the same text under Python 2.
        print("WARNING: called applyPostfix for module/sequence %s which is not in patDefaultSequence%s!" % (label, postfix))
        return getattr(process, label)
    return None
00019
def removeIfInSequence(process, target, sequenceLabel, postfix=""):
    """Remove ``target+postfix`` from the sequence ``sequenceLabel+postfix``.

    process       : the process holding the modules and sequences
    target        : label of the module/sequence to remove, without postfix
    sequenceLabel : label of the sequence to modify, without postfix
    postfix       : postfix shared by the sequence and the target

    Nothing is done when the target is not listed in the sequence or when
    ``target+postfix`` is not an attribute of the process.
    """
    labels = __labelsInSequence(process, sequenceLabel, postfix)
    # __labelsInSequence returns the labels with the postfix already
    # stripped, so the bare target name has to be looked up here; comparing
    # target+postfix could only ever match for an empty postfix.
    if target in labels and hasattr(process, target + postfix):
        getattr(process, sequenceLabel + postfix).remove(
            getattr(process, target + postfix)
        )
00026
def __labelsInSequence(process, sequenceLabel, postfix=""):
    """List the labels of all modules and sequences in a (postfixed) sequence.

    process       : the process holding the sequence
    sequenceLabel : label of the sequence without postfix
    postfix       : postfix of the sequence; for a non-empty postfix the
                    last len(postfix) characters are cut off every label
                    (whether or not the label really ends with the postfix)

    Returns a list of label strings: modules first, then sequences.
    Raises AttributeError if ``sequenceLabel+postfix`` is not in the process.
    """
    # Look the sequence up once and traverse it once per object kind.
    sequence = getattr(process, sequenceLabel + postfix)
    labels = [m.label() for m in listModules(sequence)]
    labels.extend([s.label() for s in listSequences(sequence)])
    if postfix != "":
        # Guarded because label[:-0] would yield an empty string.
        cut = len(postfix)
        labels = [name[:-cut] for name in labels]
    return labels
00034
class MassSearchReplaceParamVisitor(object):
    """Visitor that travels within a cms.Sequence, looks for a parameter and replaces its value.

    paramName   : name of the attribute to look for on every visited object
    paramSearch : value the attribute must compare equal to in order to be replaced
    paramValue  : value assigned to the attribute on a match
    verbose     : if True, print one line per replacement
    """

    def __init__(self, paramName, paramSearch, paramValue, verbose=False):
        self._paramName = paramName
        self._paramValue = paramValue
        self._paramSearch = paramSearch
        self._verbose = verbose

    def enter(self, visitee):
        """Replace the parameter on 'visitee' if it exists and equals the search value."""
        if not hasattr(visitee, self._paramName):
            return
        current = getattr(visitee, self._paramName)
        if current != self._paramSearch and not (current == self._paramSearch):
            return
        if not (current == self._paramSearch):
            return
        if self._verbose:
            # Single-argument call form: same output under Python 2.
            print("Replaced %s.%s: %s => %s" % (visitee, self._paramName, getattr(visitee, self._paramName), self._paramValue))
        setattr(visitee, self._paramName, self._paramValue)

    def leave(self, visitee):
        """Nothing to do when leaving a node."""
        pass
00049
class MassSearchReplaceAnyInputTagVisitor(object):
    """Visitor that travels within a cms.Sequence, looks for an InputTag and replaces it.

    It will climb down within PSets, VPSets and VInputTags to find its target.

    paramSearch     : InputTag (or plain string) to look for
    paramReplace    : InputTag (or plain string) to put in its place
    verbose         : if True, print one line per replacement
    moduleLabelOnly : if True, tags that differ from paramSearch but share its
                      module label get only their module label replaced
    """

    def __init__(self, paramSearch, paramReplace, verbose=False, moduleLabelOnly=False):
        self._paramSearch = self.standardizeInputTagFmt(paramSearch)
        self._paramReplace = self.standardizeInputTagFmt(paramReplace)
        self._moduleName = ''  # not used anywhere in this class
        self._verbose = verbose
        self._moduleLabelOnly = moduleLabelOnly

    def doIt(self, pset, base):
        """Recursively replace matching InputTags in 'pset'.

        pset : object to inspect; anything that is not a cms._Parameterizable is ignored
        base : dotted path of 'pset', used only for the verbose messages
        """
        from copy import deepcopy
        if not isinstance(pset, cms._Parameterizable):
            return
        for name in pset.parameters_().keys():
            value = getattr(pset, name)
            typeName = value.pythonTypeName()
            if typeName == 'cms.PSet':
                self.doIt(value, base + "." + name)
            elif typeName == 'cms.VPSet':
                for (i, ps) in enumerate(value):
                    self.doIt(ps, "%s.%s[%d]" % (base, name, i))
            elif typeName == 'cms.VInputTag':
                for (i, n) in enumerate(value):
                    # entries may be plain strings; bring them to cms.InputTag first
                    n = self.standardizeInputTagFmt(n)
                    if n == self._paramSearch:
                        if self._verbose:
                            print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, self._paramSearch, self._paramReplace))
                        # store a private copy so that later in-place edits of one
                        # entry cannot leak into every other replaced entry
                        value[i] = deepcopy(self._paramReplace)
                    elif self._moduleLabelOnly and n.moduleLabel == self._paramSearch.moduleLabel:
                        # work on a copy: 'n' keeps the old label, so the message
                        # below really shows old ==> new
                        nrep = deepcopy(n)
                        nrep.moduleLabel = self._paramReplace.moduleLabel
                        if self._verbose:
                            print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, n, nrep))
                        value[i] = nrep
            elif typeName == 'cms.InputTag':
                if value == self._paramSearch:
                    if self._verbose:
                        print("Replace %s.%s %s ==> %s " % (base, name, self._paramSearch, self._paramReplace))
                    setattr(pset, name, deepcopy(self._paramReplace))
                elif self._moduleLabelOnly and value.moduleLabel == self._paramSearch.moduleLabel:
                    repl = deepcopy(getattr(pset, name))
                    repl.moduleLabel = self._paramReplace.moduleLabel
                    setattr(pset, name, repl)
                    if self._verbose:
                        print("Replace %s.%s %s ==> %s " % (base, name, value, repl))

    @staticmethod
    def standardizeInputTagFmt(inputTag):
        ''' helper function to ensure that the InputTag is defined as cms.InputTag(str) and not as a plain str '''
        if not isinstance(inputTag, cms.InputTag):
            return cms.InputTag(inputTag)
        return inputTag

    def enter(self, visitee):
        """Run the replacement on 'visitee', using its label as base path for messages."""
        label = ''
        try:
            label = visitee.label()
        except AttributeError:
            label = '<Module not in a Process>'
        self.doIt(visitee, label)

    def leave(self, visitee):
        """Nothing to do when leaving a node."""
        pass
00108
00109
class GatherAllModulesVisitor(object):
    """Visitor that travels within a cms.Sequence and collects, in visiting order,
    every visited object that is an instance of 'gatheredInstance' (cms._Module by default)."""

    def __init__(self, gatheredInstance=cms._Module):
        self._gatheredInstance = gatheredInstance
        self._modules = []

    def enter(self, visitee):
        """Remember 'visitee' if it has the requested type."""
        if not isinstance(visitee, self._gatheredInstance):
            return
        self._modules.append(visitee)

    def leave(self, visitee):
        """Nothing to do when leaving a node."""
        pass

    def modules(self):
        """Return the list of objects collected so far."""
        return self._modules
00122
class CloneSequenceVisitor(object):
    """Visitor that travels within a cms.Sequence, and returns a cloned version of the Sequence.

    All modules and sequences are cloned and a postfix is added.

    process : process the clones are attached to
    label   : label of the sequence that is going to be visited
    postfix : string appended to the label of every clone
    """

    def __init__(self, process, label, postfix):
        self._process = process
        self._postfix = postfix
        # Stack of sequences under construction.  An entry is either the label
        # (a string) of a sequence for which no clone exists yet, or the
        # cloned cms.Sequence itself.
        self._sequenceStack = [label]
        self._moduleLabels = []      # labels of the modules cloned so far
        self._sequenceLabels = []    # labels of the sequences cloned so far
        # label of an already cloned sequence whose content is being skipped
        self._waitForSequenceToClose = None

    def enter(self, visitee):
        """Clone modules and open (or reuse) sequences on the way down."""
        if self._waitForSequenceToClose is not None:
            return  # inside a sequence that has already been cloned
        if isinstance(visitee, cms._Module):
            label = visitee.label()
            if label in self._moduleLabels:
                # module seen before: reuse its clone
                newModule = getattr(self._process, label + self._postfix)
            else:
                self._moduleLabels.append(label)
                newModule = visitee.clone()
                setattr(self._process, label + self._postfix, newModule)
            self.__appendToTopSequence(newModule)

        if isinstance(visitee, cms.Sequence):
            if visitee.label() in self._sequenceLabels:
                # sequence seen before: push its clone and skip the content
                self._waitForSequenceToClose = visitee.label()
                self._sequenceStack.append(getattr(self._process, visitee.label() + self._postfix))
            else:
                # placeholder until the first module creates the clone
                self._sequenceStack.append(visitee.label())

    def leave(self, visitee):
        """Close a sequence: append its clone to the enclosing sequence.

        Raises StandardError if the sequence being closed is still only a
        placeholder label, i.e. no clone was ever created for it.
        """
        if not isinstance(visitee, cms.Sequence):
            return
        if self._waitForSequenceToClose is not None:
            if self._waitForSequenceToClose != visitee.label():
                # A sequence nested inside the skipped, already cloned one:
                # enter() pushed nothing for it, so nothing may be popped.
                return
            self._waitForSequenceToClose = None
        if not isinstance(self._sequenceStack[-1], cms.Sequence):
            raise StandardError("empty Sequence encountered during cloneing. sequnece stack: %s" % self._sequenceStack)
        self.__appendToTopSequence(self._sequenceStack.pop())

    def clonedSequence(self):
        """Fix the input tags of the clones and return the top entry of the stack.

        Raises StandardError if the stack does not hold exactly one entry.
        """
        if len(self._sequenceStack) != 1:
            raise StandardError("someting went wrong, the sequence stack looks like: %s" % self._sequenceStack)
        for label in self._moduleLabels:
            massSearchReplaceAnyInputTag(self._sequenceStack[-1], label, label + self._postfix, moduleLabelOnly=True, verbose=False)
        self._moduleLabels = []  # prevent the InputTag replacement from being applied twice
        return self._sequenceStack[-1]

    def __appendToTopSequence(self, visitee):
        """Append 'visitee' to the sequence on top of the stack.

        If the top entry is still a placeholder label, the cloned sequence is
        created from 'visitee' and attached to the process.  Raises
        StandardError if the process already has an attribute with the
        postfixed label.
        """
        top = self._sequenceStack[-1]
        if isinstance(top, basestring):
            oldSequenceLabel = top
            newSequenceLabel = oldSequenceLabel + self._postfix
            # Check before anything is modified, and name the sequence by its
            # known label: a freshly built cms.Sequence has no label yet.
            if hasattr(self._process, newSequenceLabel):
                raise StandardError("Cloning the sequence " + oldSequenceLabel + " would overwrite existing object.")
            newSequence = cms.Sequence(visitee)
            self._sequenceStack[-1] = newSequence
            setattr(self._process, newSequenceLabel, newSequence)
            self._sequenceLabels.append(oldSequenceLabel)
        else:
            self._sequenceStack[-1] += visitee
00183
class MassSearchParamVisitor(object):
    """Visitor that travels within a cms.Sequence and gathers every visited object
    whose attribute 'paramName' exists and compares equal to 'paramSearch'."""

    def __init__(self, paramName, paramSearch):
        self._modules = []
        self._paramSearch = paramSearch
        self._paramName = paramName

    def enter(self, visitee):
        """Record 'visitee' if it carries the parameter with the searched value."""
        if not hasattr(visitee, self._paramName):
            return
        found = getattr(visitee, self._paramName)
        if found == self._paramSearch:
            self._modules.append(visitee)

    def leave(self, visitee):
        """Nothing to do when leaving a node."""
        pass

    def modules(self):
        """Return the list of matching objects collected so far."""
        return self._modules
00198
00199
def massSearchReplaceParam(sequence, paramName, paramOldValue, paramValue, verbose=False):
    """Replace the value of a parameter on every object visited in a sequence.

    sequence      : object with a visit(visitor) method [e.g. a cms.Sequence]
    paramName     : name of the parameter to look for
    paramOldValue : only parameters equal to this value are replaced
    paramValue    : the new value
    verbose       : if True, print one line per replacement (default False,
                    which is the behaviour of the former four-argument form)
    """
    sequence.visit(MassSearchReplaceParamVisitor(paramName, paramOldValue, paramValue, verbose))
00202
def listModules(sequence):
    """Return the list of cms._Module objects met while visiting 'sequence'."""
    collector = GatherAllModulesVisitor(gatheredInstance=cms._Module)
    sequence.visit(collector)
    gathered = collector.modules()
    return gathered
00207
def listSequences(sequence):
    """Return the list of cms.Sequence objects met while visiting 'sequence'."""
    collector = GatherAllModulesVisitor(gatheredInstance=cms.Sequence)
    sequence.visit(collector)
    gathered = collector.modules()
    return gathered
00212
def massSearchReplaceAnyInputTag(sequence, oldInputTag, newInputTag, verbose=False, moduleLabelOnly=False):
    """Replace InputTag oldInputTag with newInputTag, at any level of nesting within PSets, VPSets, VInputTags...

    The work is delegated to a MassSearchReplaceAnyInputTagVisitor that is
    handed to sequence.visit(); 'verbose' and 'moduleLabelOnly' are passed on
    to that visitor unchanged.
    """
    replacer = MassSearchReplaceAnyInputTagVisitor(
        oldInputTag,
        newInputTag,
        verbose=verbose,
        moduleLabelOnly=moduleLabelOnly,
    )
    sequence.visit(replacer)
00216
def jetCollectionString(prefix='', algo='', type=''):
    """
    ------------------------------------------------------------------
    return the string of the jet collection module depending on the
    input values. The name is built as

        'pat'          + 'Jets' + algo + type    (empty prefix)
        prefix + 'Pat' + 'Jets' + algo + type    (any other prefix)

    so the default return value is 'patJets'.

    algo   : indicating the algorithm type of the jet [expected are
             'AK5', 'IC5', 'SC7', ...]
    type   : indicating the type of constituents of the jet [expec-
             ted are 'Calo', 'PFlow', 'JPT', ...]
    prefix : prefix indicating the type of pat collection module (ex-
             pected are '', 'selected', 'clean').
    ------------------------------------------------------------------
    """
    if prefix == '':
        head = 'pat'
    else:
        head = prefix + 'Pat'
    return head + 'Jets' + algo + type
00240
def contains(sequence, moduleName):
    """
    ------------------------------------------------------------------
    return True if a module with name 'moduleName' is contained in
    the given sequence and False otherwise. The name is searched as a
    whole identifier in the string representation of the sequence: a
    match must not be preceded or followed by a letter, a digit or an
    underscore, so 'a' is not found in a sequence that only contains
    'ab'.

    sequence   : sequence [e.g. process.patDefaultSequence]
    moduleName : module name as a string
    ------------------------------------------------------------------
    """
    import re
    # escape the name so that it is matched literally, and forbid identifier
    # characters on both sides to rule out matches inside longer names
    wholeName = r'(?<![A-Za-z0-9_])' + re.escape(moduleName) + r'(?![A-Za-z0-9_])'
    return re.search(wholeName, sequence.__str__()) is not None
00254
00255
00256
def cloneProcessingSnippet(process, sequence, postfix):
    """
    ------------------------------------------------------------------
    copy a sequence plus the modules and sequences therein
    both are renamed by getting a postfix
    input tags are automatically adjusted

    for an empty postfix nothing is cloned and the given sequence
    itself is returned
    ------------------------------------------------------------------
    """
    if postfix == "":
        return sequence
    cloner = CloneSequenceVisitor(process, sequence.label(), postfix)
    sequence.visit(cloner)
    return cloner.clonedSequence()
00271
if __name__=="__main__":
    import unittest
    class TestModuleCommand(unittest.TestCase):
        """Unit tests for the helper functions of this file."""
        def setUp(self):
            """Nothing to do """
            pass
        def testCloning(self):
            """A cloned sequence gets postfixed modules with adjusted input tags."""
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("b", src=cms.InputTag("a"))
            p.c = cms.EDProducer("c", src=cms.InputTag("b","instance"))
            p.s = cms.Sequence(p.a*p.b*p.c *p.a)
            cloneProcessingSnippet(p, p.s, "New")
            # NOTE(review): the whitespace inside the expected dump is kept exactly
            # as found here; confirm it against the actual output of dumpPython().
            self.assertEqual(p.dumpPython(),'import FWCore.ParameterSet.Config as cms\n\nprocess = cms.Process("test")\n\nprocess.a = cms.EDProducer("a",\n src = cms.InputTag("gen")\n)\n\n\nprocess.c = cms.EDProducer("c",\n src = cms.InputTag("b","instance")\n)\n\n\nprocess.cNew = cms.EDProducer("c",\n src = cms.InputTag("bNew","instance")\n)\n\n\nprocess.bNew = cms.EDProducer("b",\n src = cms.InputTag("aNew")\n)\n\n\nprocess.aNew = cms.EDProducer("a",\n src = cms.InputTag("gen")\n)\n\n\nprocess.b = cms.EDProducer("b",\n src = cms.InputTag("a")\n)\n\n\nprocess.s = cms.Sequence(process.a*process.b*process.c*process.a)\n\n\nprocess.sNew = cms.Sequence(process.aNew+process.bNew+process.cNew)\n\n\n')
        def testContains(self):
            """contains() finds module 'a' in s1 but not in s2."""
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s1 = cms.Sequence(p.a*p.b*p.c)
            p.s2 = cms.Sequence(p.b*p.c)
            self.assertTrue( contains(p.s1, "a") )
            self.assertTrue( not contains(p.s2, "a") )
        def testJetCollectionString(self):
            """The name is built as [prefix+'Pat' or 'pat'] + 'Jets' + algo + type."""
            self.assertEqual(jetCollectionString(algo = 'Foo', type = 'Bar'), 'patJetsFooBar')
            self.assertEqual(jetCollectionString(prefix = 'prefix', algo = 'Foo', type = 'Bar'), 'prefixPatJetsFooBar')
        def testListModules(self):
            """listModules() returns the modules of a sequence in order."""
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s = cms.Sequence(p.a*p.b*p.c)
            self.assertEqual([p.a,p.b,p.c], listModules(p.s))
        def testMassSearchReplaceParam(self):
            """Only top-level parameters are replaced, nested PSets are left alone."""
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("c"))
                                )
            p.s = cms.Sequence(p.a*p.b*p.c)
            massSearchReplaceParam(p.s,"src",cms.InputTag("b"),"a")
            self.assertEqual(cms.InputTag("a"),p.c.src)
            self.assertNotEqual(cms.InputTag("a"),p.c.nested.src)

    unittest.main()