00001 import FWCore.ParameterSet.Config as cms
00002
00003
00004
00005
def applyPostfix(process, label, postfix):
    '''Return the postfixed clone of module/sequence `label` when there is one.

    Candidate labels are read from patDefaultSequence+postfix; when the
    process has a patPF2PATSequence attribute they are read from
    patPF2PATSequence+postfix instead. If `label` is a candidate and the
    process has `label+postfix`, that object is returned. Otherwise, if the
    process has `label` itself, a warning is printed and the un-postfixed
    object is returned. If neither is found, None is returned.

    patDefaultSequence+postfix is always looked up first, so this raises if
    patDefaultSequence has not been cloned with 'postfix' beforehand.
    '''
    candidates = __labelsInSequence(process, "patDefaultSequence", postfix)
    if hasattr(process, "patPF2PATSequence"):
        # the PF2PAT sequence, when present, replaces the default candidate list
        candidates = __labelsInSequence(process, "patPF2PATSequence", postfix)
    clonedName = label + postfix
    if label in candidates and hasattr(process, clonedName):
        return getattr(process, clonedName)
    if not hasattr(process, label):
        return None
    print("WARNING: called applyPostfix for module/sequence %s which is not in patDefaultSequence%s!" % (label, postfix))
    return getattr(process, label)
00019
def removeIfInSequence(process, target, sequenceLabel, postfix=""):
    """Remove module `target+postfix` from sequence `sequenceLabel+postfix`.

    Nothing happens if `target+postfix` is not among the labels of the
    (postfixed) sequence.
    """
    # Ask for the labels of the postfixed sequence with an EMPTY postfix so
    # they come back unmodified. Passing `postfix` would strip it from every
    # label, and the `target + postfix` comparison below could then never
    # match for a non-empty postfix.
    labels = __labelsInSequence(process, sequenceLabel + postfix)
    if target + postfix in labels:
        getattr(process, sequenceLabel + postfix).remove(
            getattr(process, target + postfix)
            )
00026
def __labelsInSequence(process, sequenceLabel, postfix=""):
    """Return the labels of all modules and sub-sequences of a sequence.

    The sequence examined is process.<sequenceLabel+postfix>. Module labels
    come first, followed by sub-sequence labels. When `postfix` is non-empty,
    len(postfix) characters are cut from the end of every label, so that the
    labels of cloned objects are reported under their original names.
    """
    sequence = getattr(process, sequenceLabel + postfix)
    labels = [m.label() for m in listModules(sequence)]
    labels.extend([s.label() for s in listSequences(sequence)])
    # Slice to an explicit end index instead of [:-len(postfix)]: for an
    # empty postfix [:-0] would give '' for every label, which previously
    # forced a second full traversal. This form keeps the whole label then.
    cut = len(postfix)
    return [l[:len(l) - cut] for l in labels]
00034
class MassSearchReplaceParamVisitor(object):
    """Sequence visitor that rewrites one top-level parameter by value.

    Every visited node that has an attribute named `paramName` whose value
    equals `paramSearch` gets that attribute set to `paramValue`. With
    `verbose`, every replacement is printed.
    """
    def __init__(self, paramName, paramSearch, paramValue, verbose=False):
        self._paramName = paramName
        self._paramSearch = paramSearch
        self._paramValue = paramValue
        self._verbose = verbose

    def enter(self, visitee):
        attrName = self._paramName
        if not hasattr(visitee, attrName):
            return
        current = getattr(visitee, attrName)
        # written as `not (a == b)` so only __eq__ is consulted, as before
        if not (current == self._paramSearch):
            return
        if self._verbose:
            print("Replaced %s.%s: %s => %s" % (visitee, attrName, current, self._paramValue))
        setattr(visitee, attrName, self._paramValue)

    def leave(self, visitee):
        # all work is done on entry
        pass
00049
class MassSearchReplaceAnyInputTagVisitor(object):
    """Visitor that travels within a cms.Sequence, looks for an InputTag and replaces it.

    It climbs down within PSets, VPSets and VInputTags (tracked and untracked)
    to find its target. With moduleLabelOnly=True, InputTags whose module
    label equals that of the searched tag get only their module label
    replaced; instance and process names are kept.
    """
    def __init__(self, paramSearch, paramReplace, verbose=False, moduleLabelOnly=False):
        self._paramSearch = self.standardizeInputTagFmt(paramSearch)
        self._paramReplace = self.standardizeInputTagFmt(paramReplace)
        self._moduleName = ''
        self._verbose = verbose
        self._moduleLabelOnly = moduleLabelOnly

    def doIt(self, pset, base):
        """Replace matching InputTags inside `pset`, recursing into nested parameter sets.

        `base` is the dotted path of `pset`, used only in verbose messages.
        """
        if not isinstance(pset, cms._Parameterizable):
            return
        # parameterNames_() gives just the names; parameters_() would
        # deep-copy every parameter value only for the copies to be discarded.
        for name in pset.parameterNames_():
            # getattr returns the live parameter, so nested PSets and
            # VInputTags can be modified in place
            value = getattr(pset, name)
            typeName = value.pythonTypeName()
            if typeName in ('cms.PSet', 'cms.untracked.PSet'):
                self.doIt(value, base + "." + name)
            elif typeName in ('cms.VPSet', 'cms.untracked.VPSet'):
                for (i, ps) in enumerate(value):
                    self.doIt(ps, "%s.%s[%d]" % (base, name, i))
            elif typeName in ('cms.VInputTag', 'cms.untracked.VInputTag'):
                self._replaceInVInputTag(value, base, name)
            elif typeName.endswith('.InputTag'):
                self._replaceInputTag(pset, name, value, base, typeName)

    def _replaceInVInputTag(self, value, base, name):
        """Replace matching entries of the VInputTag `value` in place."""
        from copy import deepcopy
        for (i, n) in enumerate(value):
            # VInputTag can be declared as a list of strings, so ensure that n is formatted correctly
            n = self.standardizeInputTagFmt(n)
            if n == self._paramSearch:
                if self._verbose: print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, self._paramSearch, self._paramReplace))
                # store a private copy, as the scalar InputTag branch does, so
                # that no two parameters end up sharing one InputTag object
                value[i] = deepcopy(self._paramReplace)
            elif self._moduleLabelOnly and n.moduleLabel == self._paramSearch.moduleLabel:
                # copy before editing: n may be the very object held by the
                # list (or shared elsewhere), and the message needs the old value
                nrep = deepcopy(n)
                nrep.moduleLabel = self._paramReplace.moduleLabel
                if self._verbose: print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, n, nrep))
                value[i] = nrep

    def _replaceInputTag(self, pset, name, value, base, typeName):
        """Replace the single InputTag parameter `pset.<name>` (current value `value`) if it matches."""
        from copy import deepcopy
        if value == self._paramSearch:
            if self._verbose: print("Replace %s.%s %s ==> %s " % (base, name, self._paramSearch, self._paramReplace))
            if 'untracked' in typeName:
                # rebuild as untracked so the parameter keeps its tracking status
                setattr(pset, name, cms.untracked.InputTag(self._paramReplace.getModuleLabel(),
                                                         self._paramReplace.getProductInstanceLabel(),
                                                         self._paramReplace.getProcessName()))
            else:
                setattr(pset, name, deepcopy(self._paramReplace))
        elif self._moduleLabelOnly and value.moduleLabel == self._paramSearch.moduleLabel:
            repl = deepcopy(value)
            repl.moduleLabel = self._paramReplace.moduleLabel
            setattr(pset, name, repl)
            if self._verbose: print("Replace %s.%s %s ==> %s " % (base, name, value, repl))

    @staticmethod
    def standardizeInputTagFmt(inputTag):
        ''' helper function to ensure that the InputTag is defined as cms.InputTag(str) and not as a plain str '''
        if not isinstance(inputTag, cms.InputTag):
            return cms.InputTag(inputTag)
        return inputTag

    def enter(self, visitee):
        try:
            label = visitee.label_()
        except (AttributeError, RuntimeError):
            # objects not attached to a Process have no label.
            # NOTE(review): some FWCore versions signal this from label_()
            # with RuntimeError rather than AttributeError — confirm
            label = '<Module not in a Process>'
        self.doIt(visitee, label)

    def leave(self, visitee):
        pass
00113
00114
class GatherAllModulesVisitor(object):
    """Visitor that travels within a cms.Sequence and collects, in visiting
    order, every visited object that is an instance of `gatheredInstance`
    (modules by default)."""
    def __init__(self, gatheredInstance=cms._Module):
        self._gatheredInstance = gatheredInstance
        self._modules = []

    def enter(self, visitee):
        matches = isinstance(visitee, self._gatheredInstance)
        if matches:
            self._modules.append(visitee)

    def leave(self, visitee):
        # collection happens on entry only
        pass

    def modules(self):
        """Return the gathered objects (the internal list itself, not a copy)."""
        return self._modules
00127
class CloneSequenceVisitor(object):
    """Visitor that travels within a cms.Sequence, and returns a cloned version of the Sequence.
    All modules and sequences are cloned and a postfix is added"""
    def __init__(self, process, label, postfix):
        self._process = process
        self._postfix = postfix
        # Stack of sequences under construction. An entry is the original
        # label (a string) until the first element is added, because an empty
        # cms.Sequence cannot be created; it is then replaced by the clone.
        self._sequenceStack = [label]
        # labels of original modules / sequences that have already been cloned
        self._moduleLabels = []
        self._sequenceLabels = []
        # label of an already-cloned sequence whose content is being skipped
        self._waitForSequenceToClose = None

    def enter(self, visitee):
        if self._waitForSequenceToClose is not None:
            # inside a sequence that was cloned before; its clone is complete
            return
        if isinstance(visitee, cms._Module):
            label = visitee.label()
            if label in self._moduleLabels:
                # module seen before: reuse the clone made the first time
                newModule = getattr(self._process, label + self._postfix)
            else:
                self._moduleLabels.append(label)
                newModule = visitee.clone()
                setattr(self._process, label + self._postfix, newModule)
            self.__appendToTopSequence(newModule)

        if isinstance(visitee, cms.Sequence):
            if visitee.label() in self._sequenceLabels:
                # already cloned: push the finished clone and skip its content
                self._waitForSequenceToClose = visitee.label()
                self._sequenceStack.append(getattr(self._process, visitee.label() + self._postfix))
            else:
                # placeholder until a first element lets us create the clone
                self._sequenceStack.append(visitee.label())

    def leave(self, visitee):
        if isinstance(visitee, cms.Sequence):
            if self._waitForSequenceToClose is not None:
                if self._waitForSequenceToClose != visitee.label():
                    # A sequence nested inside the skipped one: enter() pushed
                    # nothing for it, so nothing must be popped here.
                    return
                self._waitForSequenceToClose = None
            if not isinstance(self._sequenceStack[-1], cms.Sequence):
                raise StandardError("empty Sequence encountered during cloneing. sequnece stack: %s" % self._sequenceStack)
            self.__appendToTopSequence(self._sequenceStack.pop())

    def clonedSequence(self):
        """Return the cloned top-level sequence, with InputTags pointing at the cloned modules."""
        if not len(self._sequenceStack) == 1:
            raise StandardError("someting went wrong, the sequence stack looks like: %s" % self._sequenceStack)
        for label in self._moduleLabels:
            massSearchReplaceAnyInputTag(self._sequenceStack[-1], label, label + self._postfix, moduleLabelOnly=True, verbose=False)
        self._moduleLabels = []
        return self._sequenceStack[-1]

    def __appendToTopSequence(self, visitee):
        """Append `visitee` to the sequence on top of the stack.

        If the top entry is still a label placeholder, the cloned sequence is
        created with `visitee` as its first element and attached to the
        process as label+postfix.
        """
        if isinstance(self._sequenceStack[-1], basestring):
            oldSequenceLabel = self._sequenceStack.pop()
            newSequenceLabel = oldSequenceLabel + self._postfix
            if hasattr(self._process, newSequenceLabel):
                # Use the known labels: the freshly built cms.Sequence has no
                # label yet, so asking it for one cannot build this message.
                raise StandardError("Cloning the sequence " + oldSequenceLabel + " would overwrite existing object " + newSequenceLabel + ".")
            newSequence = cms.Sequence(visitee)
            self._sequenceStack.append(newSequence)
            setattr(self._process, newSequenceLabel, newSequence)
            self._sequenceLabels.append(oldSequenceLabel)
        else:
            self._sequenceStack[-1] += visitee
00188
class MassSearchParamVisitor(object):
    """Sequence visitor that collects, in visiting order, every node having an
    attribute `paramName` whose value equals `paramSearch`."""
    def __init__(self, paramName, paramSearch):
        self._paramName = paramName
        self._paramSearch = paramSearch
        self._modules = []

    def enter(self, visitee):
        attrName = self._paramName
        if hasattr(visitee, attrName) and getattr(visitee, attrName) == self._paramSearch:
            self._modules.append(visitee)

    def leave(self, visitee):
        # matching is done on entry only
        pass

    def modules(self):
        """Return the matching nodes found so far (the internal list)."""
        return self._modules
00203
00204
def massSearchReplaceParam(sequence,paramName,paramOldValue,paramValue,verbose=False):
    """Set attribute `paramName` to `paramValue` on every node visited in
    `sequence` whose current value equals `paramOldValue` (top-level
    attributes only). `verbose` prints each replacement."""
    replacer = MassSearchReplaceParamVisitor(paramName, paramOldValue, paramValue, verbose)
    sequence.visit(replacer)
00207
def _gatherInstances(sequence, instanceType):
    """Visit `sequence` and return every visited object that is an instance of `instanceType`."""
    gatherer = GatherAllModulesVisitor(gatheredInstance=instanceType)
    sequence.visit(gatherer)
    return gatherer.modules()

def listModules(sequence):
    """Return the modules (cms._Module instances) visited in `sequence`."""
    return _gatherInstances(sequence, cms._Module)

def listSequences(sequence):
    """Return the sub-sequences (cms.Sequence instances) visited in `sequence`."""
    return _gatherInstances(sequence, cms.Sequence)
00217
def massSearchReplaceAnyInputTag(sequence, oldInputTag, newInputTag,verbose=False,moduleLabelOnly=False) :
    """Replace InputTag oldInputTag with newInputTag, at any level of nesting within PSets, VPSets, VInputTags...

    With moduleLabelOnly=True, tags that share oldInputTag's module label get
    only that label swapped. `verbose` prints every replacement.
    """
    replacer = MassSearchReplaceAnyInputTagVisitor(oldInputTag, newInputTag,
                                                   verbose=verbose,
                                                   moduleLabelOnly=moduleLabelOnly)
    sequence.visit(replacer)
00221
def jetCollectionString(prefix='', algo='', type=''):
    """
    ------------------------------------------------------------------
    return the label of the pat jet collection module built from the
    input values:

      prefix == '' : 'pat' + 'Jets' + algo + type
      otherwise    : prefix + 'Pat' + 'Jets' + algo + type

    With all arguments left at their defaults the result is 'patJets'.
    For example algo='AK5', type='PF' gives 'patJetsAK5PF'; adding
    prefix='selected' gives 'selectedPatJetsAK5PF'.

    algo   : indicating the algorithm type of the jet [expected are
             'AK5', 'IC5', 'SC7', ...]
    type   : indicating the type of constituents of the jet [expec-
             ted are 'Calo', 'PFlow', 'JPT', ...]
    prefix : prefix indicating the type of pat collection module (ex-
             pected are '', 'selected', 'clean').
    ------------------------------------------------------------------
    """
    # after a prefix, 'pat' is capitalised to keep the label camelCase
    head = 'pat' if prefix == '' else prefix + 'Pat'
    return head + 'Jets' + algo + type
00245
def contains(sequence, moduleName):
    """
    ------------------------------------------------------------------
    return True if the text 'moduleName' occurs anywhere in the string
    representation of 'sequence' and False otherwise. This is a plain
    substring test, so it also returns True when 'moduleName' is only
    part of the name of a contained module.

    sequence   : sequence [e.g. process.patDefaultSequence]
    moduleName : module name as a string
    ------------------------------------------------------------------
    """
    rendered = sequence.__str__()
    return moduleName in rendered
00259
00260
00261
def cloneProcessingSnippet(process, sequence, postfix):
    """
    ------------------------------------------------------------------
    copy a sequence plus the modules and sequences therein; the copies
    are attached to the process under their original labels with
    'postfix' appended, and input tags inside the copy are redirected
    to the copied modules. The cloned sequence is returned.
    With an empty postfix nothing is copied and 'sequence' itself is
    returned.
    ------------------------------------------------------------------
    """
    if postfix == "":
        return sequence
    cloner = CloneSequenceVisitor(process, sequence.label(), postfix)
    sequence.visit(cloner)
    return cloner.clonedSequence()
00276
if __name__=="__main__":
    import unittest
    class TestModuleCommand(unittest.TestCase):
        def setUp(self):
            """Nothing to do """
            pass
        def testCloning(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("b", src=cms.InputTag("a"))
            p.c = cms.EDProducer("c", src=cms.InputTag("b","instance"))
            p.s = cms.Sequence(p.a*p.b*p.c *p.a)
            cloneProcessingSnippet(p, p.s, "New")
            self.assertEqual(p.dumpPython(),'import FWCore.ParameterSet.Config as cms\n\nprocess = cms.Process("test")\n\nprocess.a = cms.EDProducer("a",\n src = cms.InputTag("gen")\n)\n\n\nprocess.c = cms.EDProducer("c",\n src = cms.InputTag("b","instance")\n)\n\n\nprocess.cNew = cms.EDProducer("c",\n src = cms.InputTag("bNew","instance")\n)\n\n\nprocess.bNew = cms.EDProducer("b",\n src = cms.InputTag("aNew")\n)\n\n\nprocess.aNew = cms.EDProducer("a",\n src = cms.InputTag("gen")\n)\n\n\nprocess.b = cms.EDProducer("b",\n src = cms.InputTag("a")\n)\n\n\nprocess.s = cms.Sequence(process.a*process.b*process.c*process.a)\n\n\nprocess.sNew = cms.Sequence(process.aNew+process.bNew+process.cNew)\n\n\n')
        def testContains(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s1 = cms.Sequence(p.a*p.b*p.c)
            p.s2 = cms.Sequence(p.b*p.c)
            # assertTrue/assertFalse instead of the deprecated assert_ alias
            self.assertTrue( contains(p.s1, "a") )
            self.assertFalse( contains(p.s2, "a") )
        def testJetCollectionString(self):
            # jetCollectionString builds '<prefix>Pat' (or 'pat') + 'Jets' + algo + type
            self.assertEqual(jetCollectionString(algo = 'Foo', type = 'Bar'), 'patJetsFooBar')
            self.assertEqual(jetCollectionString(prefix = 'prefix', algo = 'Foo', type = 'Bar'), 'prefixPatJetsFooBar')
        def testListModules(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s = cms.Sequence(p.a*p.b*p.c)
            self.assertEqual([p.a,p.b,p.c], listModules(p.s))
        def testMassSearchReplaceParam(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("c"))
                                )
            p.s = cms.Sequence(p.a*p.b*p.c)
            massSearchReplaceParam(p.s,"src",cms.InputTag("b"),"a")
            self.assertEqual(cms.InputTag("a"),p.c.src)
            self.assertNotEqual(cms.InputTag("a"),p.c.nested.src)

    unittest.main()