00001 import FWCore.ParameterSet.Config as cms
00002
00003
00004
00005
def applyPostfix(process, label, postfix):
    '''Return the postfixed clone of module/sequence `label` when it belongs to the
    postfixed PAT sequence, otherwise fall back to `process.<label>`.

    The reference label list comes from patDefaultSequence<postfix>, or from
    patPF2PATSequence<postfix> when the process has a `patPF2PATSequence` attribute.
    patDefaultSequence<postfix> is always looked up first, so this fails if that
    sequence has not been cloned with `postfix` beforehand.

    Returns None when neither the postfixed clone qualifies nor `process.<label>`
    exists. A WARNING is printed when falling back to the un-postfixed object.'''
    knownLabels = __labelsInSequence(process, "patDefaultSequence", postfix)
    if hasattr(process, "patPF2PATSequence"):
        # PF2PAT processing replaces the default sequence as the reference
        knownLabels = __labelsInSequence(process, "patPF2PATSequence", postfix)

    postfixedLabel = label + postfix
    if label in knownLabels and hasattr(process, postfixedLabel):
        return getattr(process, postfixedLabel)

    if not hasattr(process, label):
        return None

    print("WARNING: called applyPostfix for module/sequence %s which is not in patDefaultSequence%s!"%(label,postfix))
    return getattr(process, label)
00019
def removeIfInSequence(process, target, sequenceLabel, postfix=""):
    """Remove `target+postfix` from sequence `sequenceLabel+postfix` if that label is
    among the modules/sequences that listModules/listSequences report for it;
    otherwise do nothing.

    The membership test uses the full (postfixed) labels. Comparing `target+postfix`
    against postfix-stripped labels, as before, never matched for a non-empty postfix,
    so nothing was ever removed in that case.
    """
    sequence = getattr(process, sequenceLabel + postfix)
    labels = [m.label() for m in listModules(sequence)]
    labels.extend(s.label() for s in listSequences(sequence))
    if target + postfix in labels:
        sequence.remove(getattr(process, target + postfix))
00026
def __labelsInSequence(process, sequenceLabel, postfix=""):
    """Return the labels of the modules and sequences found in process sequence
    `sequenceLabel+postfix`, with `postfix` removed from each label that ends with it.

    Module labels come first, then sequence labels, each in the order returned by
    listModules / listSequences. Fails (getattr) if the process has no such sequence.
    """
    sequence = getattr(process, sequenceLabel + postfix)

    def _strip(label):
        # Only remove a real suffix. Slicing [:-len(postfix)] unconditionally chopped
        # characters off labels without the postfix (and gives '' for postfix == "").
        if postfix and label.endswith(postfix):
            return label[:-len(postfix)]
        return label

    result = [_strip(m.label()) for m in listModules(sequence)]
    result.extend(_strip(s.label()) for s in listSequences(sequence))
    return result
00034
class MassSearchReplaceParamVisitor(object):
    """Visitor that travels within a cms.Sequence and, on every visited object whose
    attribute `paramName` equals `paramSearch`, sets that attribute to `paramValue`.
    Only the visited object's own attribute is inspected (no descent into nested PSets)."""
    def __init__(self,paramName,paramSearch,paramValue,verbose=False):
        self._verbose = verbose
        self._paramSearch = paramSearch
        self._paramValue = paramValue
        self._paramName = paramName

    def enter(self,visitee):
        attrName = self._paramName
        if not hasattr(visitee, attrName):
            return
        current = getattr(visitee, attrName)
        if not (current == self._paramSearch):
            return
        if self._verbose:
            print("Replaced %s.%s: %s => %s" % (visitee, attrName, current, self._paramValue))
        setattr(visitee, attrName, self._paramValue)

    def leave(self,visitee):
        pass
00049
class MassSearchReplaceAnyInputTagVisitor(object):
    """Visitor that travels within a cms.Sequence, looks for an InputTag and replaces it.
    It will climb down within PSets, VPSets and VInputTags to find its target.

    paramSearch / paramReplace : cms.InputTag or plain string (plain strings are
                                 converted with standardizeInputTagFmt)
    verbose                    : print one line per replacement
    moduleLabelOnly            : also rewrite InputTags that only share the module
                                 label of paramSearch; for those, just the module
                                 label is replaced (instance/process labels are kept)
    """
    def __init__(self,paramSearch,paramReplace,verbose=False,moduleLabelOnly=False):
        self._paramSearch = self.standardizeInputTagFmt(paramSearch)
        self._paramReplace = self.standardizeInputTagFmt(paramReplace)
        self._moduleName = ''
        self._verbose=verbose
        self._moduleLabelOnly=moduleLabelOnly

    def doIt(self,pset,base):
        """Replace matching InputTags among the parameters of `pset`, recursing into
        nested PSets and VPSets. `base` is the path prefix used in verbose messages.
        Objects that are not cms._Parameterizable are ignored."""
        from copy import deepcopy
        if not isinstance(pset, cms._Parameterizable):
            return
        for name in pset.parameters_().keys():
            value = getattr(pset,name)
            typeName = value.pythonTypeName()
            if typeName == 'cms.PSet':
                self.doIt(value,base+"."+name)
            elif typeName == 'cms.VPSet':
                for (i,ps) in enumerate(value):
                    self.doIt(ps, "%s.%s[%d]"%(base,name,i))
            elif typeName == 'cms.VInputTag':
                for (i,n) in enumerate(value):
                    # VInputTag entries may be declared as plain strings
                    n = self.standardizeInputTagFmt(n)
                    if n == self._paramSearch:
                        if self._verbose: print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, self._paramSearch, self._paramReplace))
                        # store a copy (as the InputTag branch does) so that entries never
                        # share - and later mutate - the visitor's own _paramReplace object
                        value[i] = deepcopy(self._paramReplace)
                    elif self._moduleLabelOnly and n.moduleLabel == self._paramSearch.moduleLabel:
                        # copy before editing: n can be the very object stored in value,
                        # and editing it in place also made the message show the new
                        # label on both sides
                        nrep = deepcopy(n)
                        nrep.moduleLabel = self._paramReplace.moduleLabel
                        if self._verbose: print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, n, nrep))
                        value[i] = nrep
            elif typeName.endswith('.InputTag'):
                if value == self._paramSearch:
                    if self._verbose: print("Replace %s.%s %s ==> %s " % (base, name, self._paramSearch, self._paramReplace))
                    if 'untracked' in typeName:
                        # rebuild as untracked so the replacement keeps the parameter's
                        # untracked-ness (_paramReplace itself is a tracked InputTag)
                        setattr(pset, name, cms.untracked.InputTag(self._paramReplace.getModuleLabel(),
                                                                   self._paramReplace.getProductInstanceLabel(),
                                                                   self._paramReplace.getProcessName()))
                    else:
                        setattr(pset, name, deepcopy(self._paramReplace))
                elif self._moduleLabelOnly and value.moduleLabel == self._paramSearch.moduleLabel:
                    repl = deepcopy(value)
                    repl.moduleLabel = self._paramReplace.moduleLabel
                    setattr(pset, name, repl)
                    if self._verbose: print("Replace %s.%s %s ==> %s " % (base, name, value, repl))

    @staticmethod
    def standardizeInputTagFmt(inputTag):
        ''' helper function to ensure that the InputTag is defined as cms.InputTag(str) and not as a plain str '''
        if not isinstance(inputTag, cms.InputTag):
            return cms.InputTag(inputTag)
        return inputTag

    def enter(self,visitee):
        label = ''
        try: label = visitee.label_()
        except AttributeError: label = '<Module not in a Process>'
        self.doIt(visitee, label)

    def leave(self,visitee):
        pass
00113
00114
class GatherAllModulesVisitor(object):
    """Visitor that travels within a cms.Sequence and collects every visited object
    that is an instance of `gatheredInstance` (cms._Module by default, e.g.
    cms.Sequence to collect sequences). modules() returns them in visiting order."""
    def __init__(self, gatheredInstance=cms._Module):
        self._gatheredInstance = gatheredInstance
        self._modules = []

    def enter(self,visitee):
        if not isinstance(visitee, self._gatheredInstance):
            return
        self._modules.append(visitee)

    def leave(self,visitee):
        pass

    def modules(self):
        return self._modules
00127
class CloneSequenceVisitor(object):
    """Visitor that travels within a cms.Sequence, and returns a cloned version of the Sequence.

    Each module the visit reaches is cloned once and registered on the process under
    its new label (see _newLabel). Every visit of a module appends its clone to one
    new top-level sequence, which is registered under the new sequence label.
    Sequences themselves are not recreated: enter() only handles cms._Module objects.
    clonedSequence() then redirects InputTags inside the clones to the new labels."""
    def __init__(self, process, label, postfix, removePostfix=""):
        self._process = process
        self._postfix = postfix
        self._removePostfix = removePostfix
        # original labels of the modules cloned so far, in visiting order
        self._moduleLabels = []
        self._clonedSequence = cms.Sequence()
        setattr(process, self._newLabel(label), self._clonedSequence)

    def enter(self, visitee):
        if not isinstance(visitee, cms._Module):
            return
        label = visitee.label()
        newLabel = self._newLabel(label)
        if label in self._moduleLabels:
            # already cloned on an earlier visit: reuse that clone
            newModule = getattr(self._process, newLabel)
        else:
            self._moduleLabels.append(label)
            newModule = visitee.clone()
            setattr(self._process, newLabel, newModule)
        self.__appendToTopSequence(newModule)

    def leave(self, visitee):
        pass

    def clonedSequence(self):
        """Rewrite module-label InputTag references to the cloned labels and return
        the cloned sequence. Clears the list of cloned labels afterwards."""
        for label in self._moduleLabels:
            massSearchReplaceAnyInputTag(self._clonedSequence, label, self._newLabel(label), moduleLabelOnly=True, verbose=False)
        self._moduleLabels = []
        return self._clonedSequence

    def _newLabel(self, label):
        """Return `label` with removePostfix stripped (when set) and postfix appended.
        Raises ValueError if removePostfix is set but `label` does not end with it."""
        if self._removePostfix != "":
            if not label.endswith(self._removePostfix):
                # ValueError is a StandardError subclass on Python 2 (so existing
                # handlers still catch it) and, unlike StandardError, exists on Python 3
                raise ValueError("Tried to remove postfix %s from label %s, but it wasn't there" % (self._removePostfix, label))
            label = label[:-len(self._removePostfix)]
        return label + self._postfix

    def __appendToTopSequence(self, visitee):
        self._clonedSequence += visitee
00170
class MassSearchParamVisitor(object):
    """Visitor that travels within a cms.Sequence and collects every visited object
    whose attribute `paramName` compares equal to `paramSearch`; see modules()."""
    def __init__(self,paramName,paramSearch):
        self._modules = []
        self._paramSearch = paramSearch
        self._paramName = paramName

    def enter(self,visitee):
        attrName = self._paramName
        if hasattr(visitee, attrName) and getattr(visitee, attrName) == self._paramSearch:
            self._modules.append(visitee)

    def leave(self,visitee):
        pass

    def modules(self):
        return self._modules
00185
00186
def massSearchReplaceParam(sequence,paramName,paramOldValue,paramValue,verbose=False):
    """Set `paramName` to `paramValue` on every object visited in `sequence` whose
    `paramName` currently equals `paramOldValue` (top-level attributes only)."""
    replacer = MassSearchReplaceParamVisitor(paramName, paramOldValue, paramValue, verbose)
    sequence.visit(replacer)
00189
def listModules(sequence):
    """Return the cms._Module objects reached when visiting `sequence`, in visiting order."""
    collector = GatherAllModulesVisitor(gatheredInstance=cms._Module)
    sequence.visit(collector)
    found = collector.modules()
    return found
00194
def listSequences(sequence):
    """Return the cms.Sequence objects reached when visiting `sequence`, in visiting order."""
    collector = GatherAllModulesVisitor(gatheredInstance=cms.Sequence)
    sequence.visit(collector)
    found = collector.modules()
    return found
00199
def massSearchReplaceAnyInputTag(sequence, oldInputTag, newInputTag,verbose=False,moduleLabelOnly=False) :
    """Replace InputTag oldInputTag with newInputTag, at any level of nesting within PSets, VPSets, VInputTags...

    With moduleLabelOnly=True, InputTags that merely share oldInputTag's module label
    get that module label replaced as well."""
    replacer = MassSearchReplaceAnyInputTagVisitor(oldInputTag, newInputTag,
                                                   verbose=verbose,
                                                   moduleLabelOnly=moduleLabelOnly)
    sequence.visit(replacer)
00203
def jetCollectionString(prefix='', algo='', type=''):
    """
    ------------------------------------------------------------------
    return the label of the pat jet collection module built from the
    input values:

        'pat'          + 'Jets' + algo + type   if prefix is ''
        prefix + 'Pat' + 'Jets' + algo + type   otherwise

    With all defaults the return value is 'patJets'; e.g.
    algo='AK5', type='Calo' gives 'patJetsAK5Calo', and with
    prefix='selected' it gives 'selectedPatJetsAK5Calo'.

    algo   : indicating the algorithm type of the jet [expected are
             'AK5', 'IC5', 'SC7', ...]
    type   : indicating the type of constituents of the jet [expec-
             ted are 'Calo', 'PFlow', 'JPT', ...]
    prefix : prefix indicating the type of pat collection module (ex-
             pected are '', 'selected', 'clean').
    ------------------------------------------------------------------
    """
    if prefix == '':
        head = 'pat'
    else:
        head = prefix + 'Pat'
    return head + 'Jets' + algo + type
00227
def contains(sequence, moduleName):
    """
    ------------------------------------------------------------------
    return True if the string moduleName occurs anywhere in the string
    form of 'sequence' and False otherwise. This is a plain substring
    test, so it also returns True when moduleName is only part of the
    name of a contained module (e.g. 'a' matches a module 'ab').

    sequence   : sequence [e.g. process.patDefaultSequence]
    moduleName : module name as a string
    ------------------------------------------------------------------
    """
    rendered = sequence.__str__()
    return moduleName in rendered
00241
00242
00243
def cloneProcessingSnippet(process, sequence, postfix, removePostfix=""):
    """
    ------------------------------------------------------------------
    copy a sequence plus the modules therein under new labels
    (removePostfix stripped, postfix appended) and register them on
    the process; input tags pointing to the copied modules are
    redirected to the copies. An empty postfix returns the original
    sequence unchanged.
    ------------------------------------------------------------------
    """
    if postfix == "":
        # nothing to rename: hand back the original sequence untouched
        return sequence
    cloner = CloneSequenceVisitor(process, sequence.label(), postfix, removePostfix)
    sequence.visit(cloner)
    return cloner.clonedSequence()
00258
if __name__=="__main__":
    import unittest
    class TestModuleCommand(unittest.TestCase):
        def setUp(self):
            """Nothing to do """
            pass
        def testCloning(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("b", src=cms.InputTag("a"))
            p.c = cms.EDProducer("c", src=cms.InputTag("b","instance"))
            p.s = cms.Sequence(p.a*p.b*p.c *p.a)
            cloneProcessingSnippet(p, p.s, "New")
            # NOTE(review): exact-string comparison with dumpPython(); it is sensitive
            # to dumpPython's whitespace and attribute ordering - confirm against the
            # FWCore version in use.
            self.assertEqual(p.dumpPython(),'import FWCore.ParameterSet.Config as cms\n\nprocess = cms.Process("test")\n\nprocess.a = cms.EDProducer("a",\n src = cms.InputTag("gen")\n)\n\n\nprocess.c = cms.EDProducer("c",\n src = cms.InputTag("b","instance")\n)\n\n\nprocess.cNew = cms.EDProducer("c",\n src = cms.InputTag("bNew","instance")\n)\n\n\nprocess.bNew = cms.EDProducer("b",\n src = cms.InputTag("aNew")\n)\n\n\nprocess.aNew = cms.EDProducer("a",\n src = cms.InputTag("gen")\n)\n\n\nprocess.b = cms.EDProducer("b",\n src = cms.InputTag("a")\n)\n\n\nprocess.s = cms.Sequence(process.a*process.b*process.c*process.a)\n\n\nprocess.sNew = cms.Sequence(process.aNew+process.bNew+process.cNew)\n\n\n')
        def testContains(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s1 = cms.Sequence(p.a*p.b*p.c)
            p.s2 = cms.Sequence(p.b*p.c)
            # assertTrue instead of the deprecated assert_ alias
            self.assertTrue( contains(p.s1, "a") )
            self.assertTrue( not contains(p.s2, "a") )
        def testJetCollectionString(self):
            # jetCollectionString puts 'Jets' right after the 'pat'/'<prefix>Pat' part
            # and appends algo and type after it; the former expectations
            # ('patFooBarJets') contradicted the implementation.
            self.assertEqual(jetCollectionString(algo = 'Foo', type = 'Bar'), 'patJetsFooBar')
            self.assertEqual(jetCollectionString(prefix = 'prefix', algo = 'Foo', type = 'Bar'), 'prefixPatJetsFooBar')
        def testListModules(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s = cms.Sequence(p.a*p.b*p.c)
            self.assertEqual([p.a,p.b,p.c], listModules(p.s))
        def testMassSearchReplaceParam(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("c"))
                                )
            p.s = cms.Sequence(p.a*p.b*p.c)
            massSearchReplaceParam(p.s,"src",cms.InputTag("b"),"a")
            self.assertEqual(cms.InputTag("a"),p.c.src)
            self.assertNotEqual(cms.InputTag("a"),p.c.nested.src)

    unittest.main()