CMS 3D CMS Logo

/afs/cern.ch/work/a/aaltunda/public/www/CMSSW_5_3_13_patch3/src/PhysicsTools/PatAlgos/python/tools/helpers.py

Go to the documentation of this file.
00001 import FWCore.ParameterSet.Config as cms
00002 
## Helpers to perform some technically boring tasks, like looking for all modules with a given parameter
## and replacing its value with a given one
00005 
def applyPostfix(process, label, postfix):
    """Return the postfixed clone of 'label' if it belongs to the cloned PAT sequence.

    The reference sequence is 'patPF2PATSequence' when the process has an
    attribute of that name, otherwise 'patDefaultSequence'. Its clone
    '<sequence><postfix>' must already exist on the process; otherwise the
    lookup raises AttributeError.

    process : cms.Process holding the modules and sequences
    label   : module or sequence label without postfix
    postfix : postfix that was used when the sequence was cloned

    Returns process.<label><postfix> if 'label' is in the reference sequence
    and the postfixed object exists; otherwise process.<label> (after printing
    a warning) if that exists; otherwise None.
    """
    # Resolve only the sequence that is actually used, so that a process in
    # which only patPF2PATSequence was cloned does not fail on a missing
    # patDefaultSequence<postfix>.
    if hasattr(process, "patPF2PATSequence"):
        sequenceLabel = "patPF2PATSequence"
    else:
        sequenceLabel = "patDefaultSequence"
    defaultLabels = __labelsInSequence(process, sequenceLabel, postfix)
    if label in defaultLabels and hasattr(process, label+postfix):
        return getattr(process, label+postfix)
    if hasattr(process, label):
        print("WARNING: called applyPostfix for module/sequence %s which is not in %s%s!" % (label, sequenceLabel, postfix))
        return getattr(process, label)
    return None
00019 
def removeIfInSequence(process, target,  sequenceLabel, postfix=""):
    """Remove '<target><postfix>' from the sequence '<sequenceLabel><postfix>' if it is in it.

    process       : cms.Process holding the modules and sequences
    target        : label (without postfix) of the module or sequence to remove
    sequenceLabel : label (without postfix) of the sequence to remove it from
    postfix       : postfix appended to both labels ("" for the originals)

    Does nothing if the target is not a module or sub-sequence of the sequence.
    """
    sequence = getattr(process, sequenceLabel+postfix)
    fullTarget = target+postfix
    # Compare full (postfixed) labels. Matching target+postfix against the
    # postfix-stripped labels returned by __labelsInSequence can never succeed
    # for a non-empty postfix.
    containedLabels = [m.label() for m in listModules(sequence)]
    containedLabels.extend([s.label() for s in listSequences(sequence)])
    if fullTarget in containedLabels:
        sequence.remove(getattr(process, fullTarget))
00026 
def __labelsInSequence(process, sequenceLabel, postfix=""):
    """Return the labels of all modules and sub-sequences of '<sequenceLabel><postfix>'.

    Module labels come first, then sequence labels, each in visiting order.
    When postfix is non-empty the last len(postfix) characters are cut from
    every label; it is assumed (not checked) that every label ends with it.
    """
    sequence = getattr(process, sequenceLabel+postfix)
    # Visit the sequence once per kind (the previous form visited it twice
    # more and discarded the result when postfix was empty).
    labels = [m.label() for m in listModules(sequence)]
    labels.extend([s.label() for s in listSequences(sequence)])
    if postfix != "":
        # blind strip, as before: labels not ending with postfix are truncated too
        labels = [lbl[:-len(postfix)] for lbl in labels]
    return labels
00034 
class MassSearchReplaceParamVisitor(object):
    """Sequence visitor that overwrites one parameter of the visited objects.

    Every visited object that has an attribute named paramName whose value
    equals paramSearch gets that attribute set to paramValue. Only the
    attribute on the visited object itself is checked, not nested PSets.
    """
    def __init__(self, paramName, paramSearch, paramValue, verbose=False):
        self._paramName = paramName
        self._paramSearch = paramSearch
        self._paramValue = paramValue
        self._verbose = verbose

    def enter(self, visitee):
        attrName = self._paramName
        if not hasattr(visitee, attrName):
            return
        current = getattr(visitee, attrName)
        if not current == self._paramSearch:
            return
        if self._verbose:
            print("Replaced %s.%s: %s => %s" % (visitee, attrName, current, self._paramValue))
        setattr(visitee, attrName, self._paramValue)

    def leave(self, visitee):
        # nothing to undo when leaving a node
        pass
00049 
class MassSearchReplaceAnyInputTagVisitor(object):
    """Visitor that travels within a cms.Sequence, looks for an InputTag and replaces it.

    It climbs down within PSets, VPSets and VInputTags (tracked or untracked)
    to find its target. With moduleLabelOnly=True, any InputTag whose module
    label equals the searched module label gets only its module label
    replaced; its other fields are kept.
    """
    def __init__(self, paramSearch, paramReplace, verbose=False, moduleLabelOnly=False):
        # plain strings are promoted to cms.InputTag so comparisons work uniformly
        self._paramSearch = self.standardizeInputTagFmt(paramSearch)
        self._paramReplace = self.standardizeInputTagFmt(paramReplace)
        self._moduleName = ''
        self._verbose = verbose
        self._moduleLabelOnly = moduleLabelOnly

    def doIt(self, pset, base):
        """Replace matching InputTags among the parameters of 'pset', recursively.

        pset : object to search; ignored unless it is a cms._Parameterizable
        base : printable path of 'pset', only used for the verbose printout
        """
        from copy import deepcopy
        if not isinstance(pset, cms._Parameterizable):
            return
        for name in pset.parameterNames_():
            # if I use pset.parameters_().items() I get copies of the parameter values
            # so I can't modify the nested pset
            value = getattr(pset, name)
            typeName = value.pythonTypeName()
            # Match on the suffix so that untracked PSets/VPSets/VInputTags are
            # searched as well (their type names contain '.untracked.').
            if typeName.endswith('.PSet'):
                self.doIt(value, base+"."+name)
            elif typeName.endswith('.VPSet'):
                for (i, ps) in enumerate(value):
                    self.doIt(ps, "%s.%s[%d]" % (base, name, i))
            elif typeName.endswith('.VInputTag'):
                for (i, n) in enumerate(value):
                    # VInputTag can be declared as a list of strings, so ensure that n is formatted correctly
                    n = self.standardizeInputTagFmt(n)
                    if n == self._paramSearch:
                        if self._verbose: print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, self._paramSearch, self._paramReplace))
                        # copy, so that no two parameters share one InputTag object
                        value[i] = deepcopy(self._paramReplace)
                    elif self._moduleLabelOnly and n.moduleLabel == self._paramSearch.moduleLabel:
                        # copy before editing: n may be the very object stored in the
                        # VInputTag (or shared elsewhere), and the printout needs the old value
                        nrep = deepcopy(n)
                        nrep.moduleLabel = self._paramReplace.moduleLabel
                        if self._verbose: print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, n, nrep))
                        value[i] = nrep
            elif typeName.endswith('.InputTag'):
                if value == self._paramSearch:
                    if self._verbose: print("Replace %s.%s %s ==> %s " % (base, name, self._paramSearch, self._paramReplace))
                    if 'untracked' in typeName:
                        # keep the untracked-ness of the original parameter
                        setattr(pset, name, cms.untracked.InputTag(self._paramReplace.getModuleLabel(),
                                                                   self._paramReplace.getProductInstanceLabel(),
                                                                   self._paramReplace.getProcessName()))
                    else:
                        setattr(pset, name, deepcopy(self._paramReplace))
                elif self._moduleLabelOnly and value.moduleLabel == self._paramSearch.moduleLabel:
                    repl = deepcopy(value)
                    repl.moduleLabel = self._paramReplace.moduleLabel
                    setattr(pset, name, repl)
                    if self._verbose: print("Replace %s.%s %s ==> %s " % (base, name, value, repl))

    @staticmethod
    def standardizeInputTagFmt(inputTag):
        ''' helper function to ensure that the InputTag is defined as cms.InputTag(str) and not as a plain str '''
        if not isinstance(inputTag, cms.InputTag):
            return cms.InputTag(inputTag)
        return inputTag

    def enter(self, visitee):
        label = ''
        try:    label = visitee.label_()
        except AttributeError: label = '<Module not in a Process>'
        self.doIt(visitee, label)

    def leave(self, visitee):
        pass
00113 
00114 #FIXME name is not generic enough now
class GatherAllModulesVisitor(object):
    """Visitor that travels within a cms.Sequence and collects every visited
    object that is an instance of gatheredInstance (cms._Module by default)."""
    def __init__(self, gatheredInstance=cms._Module):
        self._gatheredInstance = gatheredInstance
        self._modules = []

    def enter(self, visitee):
        if not isinstance(visitee, self._gatheredInstance):
            return
        self._modules.append(visitee)

    def leave(self, visitee):
        # nothing to do on the way out
        pass

    def modules(self):
        """Return the collected objects in visiting order (the internal list, not a copy)."""
        return self._modules
00127 
class CloneSequenceVisitor(object):
    """Visitor that travels within a cms.Sequence, and returns a cloned version of the Sequence.
    All modules and sequences are cloned and a postfix is added.

    Visit the source sequence with this visitor, then call clonedSequence().
    Each module is cloned once, as process.<label><postfix>; a sub-sequence
    met a second time reuses its already created clone.
    """
    def __init__(self, process, label, postfix):
        self._process = process
        self._postfix = postfix
        # Sequences under construction. An entry is either the label (string)
        # of a sequence whose clone is not created yet, or the cloned cms.Sequence.
        self._sequenceStack = [label]
        self._moduleLabels = []    # labels of the modules cloned so far
        self._sequenceLabels = []  # labels of the sequences cloned so far
        self._waitForSequenceToClose = None # modules will only be cloned or added if this is None

    def enter(self, visitee):
        if self._waitForSequenceToClose is not None:
            return #we are in a already cloned sequence
        if isinstance(visitee, cms._Module):
            label = visitee.label()
            if label in self._moduleLabels:
                newModule = getattr(self._process, label+self._postfix)
            else:
                self._moduleLabels.append(label)
                newModule = visitee.clone()
                setattr(self._process, label+self._postfix, newModule)
            self.__appendToTopSequence(newModule)

        if isinstance(visitee, cms.Sequence):
            if visitee.label() in self._sequenceLabels: # is the sequence already cloned?
                self._waitForSequenceToClose = visitee.label()
                self._sequenceStack.append(getattr(self._process, visitee.label()+self._postfix))
            else:
                self._sequenceStack.append(visitee.label())#save desired label as placeholder until we have a module to create the sequence

    def leave(self, visitee):
        if not isinstance(visitee, cms.Sequence):
            return
        if self._waitForSequenceToClose is not None:
            if self._waitForSequenceToClose != visitee.label():
                # A sequence nested inside an already cloned one: enter() pushed
                # nothing for it, so nothing may be popped here either.
                return
            self._waitForSequenceToClose = None
        if not isinstance(self._sequenceStack[-1], cms.Sequence):
            raise StandardError("empty Sequence encountered during cloneing. sequnece stack: %s"%self._sequenceStack)
        self.__appendToTopSequence(self._sequenceStack.pop())

    def clonedSequence(self):
        """Return the cloned top sequence after pointing its InputTags to the cloned modules.

        Raises StandardError if the stack does not hold exactly one entry.
        """
        if not len(self._sequenceStack) == 1:
            raise StandardError("someting went wrong, the sequence stack looks like: %s"%self._sequenceStack)
        for label in self._moduleLabels:
            massSearchReplaceAnyInputTag(self._sequenceStack[-1], label, label+self._postfix, moduleLabelOnly=True, verbose=False)
        self._moduleLabels = [] #prevent the InputTag replacement next time this is called.
        return self._sequenceStack[-1]

    def __appendToTopSequence(self, visitee):#this is darn ugly because empty cms.Sequences are not supported
        if isinstance(self._sequenceStack[-1], basestring):#we have the name of an empty sequence on the stack. create it!
            oldSequenceLabel = self._sequenceStack[-1]
            newSequenceLabel = oldSequenceLabel + self._postfix
            # Check before creating anything, and build the message from the label
            # strings: the new cms.Sequence is not attached to the process yet,
            # so asking it for its label is not reliable.
            if hasattr(self._process, newSequenceLabel):
                raise StandardError("Cloning the sequence "+oldSequenceLabel+" would overwrite existing object "+newSequenceLabel+".")
            newSequence = cms.Sequence(visitee)
            self._sequenceStack[-1] = newSequence
            setattr(self._process, newSequenceLabel, newSequence)
            self._sequenceLabels.append(oldSequenceLabel)
        else:
            self._sequenceStack[-1] += visitee
00188 
class MassSearchParamVisitor(object):
    """Sequence visitor that records every visited object whose attribute
    paramName equals paramSearch; the matches are available via modules()."""
    def __init__(self, paramName, paramSearch):
        self._paramName = paramName
        self._paramSearch = paramSearch
        self._modules = []

    def enter(self, visitee):
        attrName = self._paramName
        isMatch = hasattr(visitee, attrName) and getattr(visitee, attrName) == self._paramSearch
        if isMatch:
            self._modules.append(visitee)

    def leave(self, visitee):
        # nothing to do on the way out
        pass

    def modules(self):
        """Return the matching objects in visiting order (the internal list, not a copy)."""
        return self._modules
00203 
00204 
def massSearchReplaceParam(sequence,paramName,paramOldValue,paramValue,verbose=False):
    """Set paramName to paramValue on every object visited in 'sequence' whose
    paramName currently equals paramOldValue (nested PSets are not searched)."""
    replacer = MassSearchReplaceParamVisitor(paramName, paramOldValue, paramValue, verbose)
    sequence.visit(replacer)
00207 
def listModules(sequence):
    """Return the cms._Module objects met while visiting 'sequence', in visiting order."""
    gatherer = GatherAllModulesVisitor(gatheredInstance=cms._Module)
    sequence.visit(gatherer)
    found = gatherer.modules()
    return found
00212 
def listSequences(sequence):
    """Return the cms.Sequence objects met while visiting 'sequence', in visiting order."""
    gatherer = GatherAllModulesVisitor(gatheredInstance=cms.Sequence)
    sequence.visit(gatherer)
    found = gatherer.modules()
    return found
00217 
def massSearchReplaceAnyInputTag(sequence, oldInputTag, newInputTag,verbose=False,moduleLabelOnly=False) :
    """Replace InputTag oldInputTag with newInputTag, at any level of nesting within PSets, VPSets, VInputTags...

    With moduleLabelOnly=True only the module label of matching tags is swapped.
    """
    replacer = MassSearchReplaceAnyInputTagVisitor(oldInputTag, newInputTag,
                                                   verbose=verbose,
                                                   moduleLabelOnly=moduleLabelOnly)
    sequence.visit(replacer)
00221 
def jetCollectionString(prefix='', algo='', type=''):
    """
    ------------------------------------------------------------------
    return the label of the pat jet collection module built from the
    input values: '<prefix>PatJets<algo><type>', or 'patJets<algo><type>'
    when no prefix is given. With the default (empty) arguments the
    return value is 'patJets'; e.g. prefix='selected', algo='AK5',
    type='PF' gives 'selectedPatJetsAK5PF'.

    algo   : indicating the algorithm type of the jet [expected are
             'AK5', 'IC5', 'SC7', ...]
    type   : indicating the type of constituents of the jet [expec-
             ted are 'Calo', 'PFlow', 'JPT', ...]
    prefix : prefix indicating the type of pat collection module (ex-
             pected are '', 'selected', 'clean').
    ------------------------------------------------------------------
    """
    # 'type' shadows the builtin but is kept: callers pass it by keyword.
    if prefix == '':
        head = 'pat'
    else:
        head = prefix + 'Pat'
    return head + 'Jets' + algo + type
00245 
def contains(sequence, moduleName):
    """
    ------------------------------------------------------------------
    return True if the string 'moduleName' occurs anywhere in the
    string representation of 'sequence' and False otherwise. This is
    a plain substring test, so it also returns True when moduleName
    is only part of the name of a contained module.

    sequence   : sequence [e.g. process.patDefaultSequence]
    moduleName : module name as a string
    ------------------------------------------------------------------
    """
    text = sequence.__str__()
    return moduleName in text
00259 
00260 
00261 
def cloneProcessingSnippet(process, sequence, postfix):
    """
    ------------------------------------------------------------------
    copy a sequence plus the modules and sequences therein;
    both are renamed by appending 'postfix', and input tags are
    automatically adjusted to point to the copies.
    With an empty postfix the original sequence is returned as is.
    ------------------------------------------------------------------
    """
    if postfix == "":
        # nothing to rename: hand back the original sequence
        return sequence
    cloner = CloneSequenceVisitor(process, sequence.label(), postfix)
    sequence.visit(cloner)
    return cloner.clonedSequence()
00276 
if __name__=="__main__":
    import unittest
    class TestModuleCommand(unittest.TestCase):
        def setUp(self):
            """Nothing to do """
            pass
        def testCloning(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("b", src=cms.InputTag("a"))
            p.c = cms.EDProducer("c", src=cms.InputTag("b","instance"))
            p.s = cms.Sequence(p.a*p.b*p.c *p.a)
            cloneProcessingSnippet(p, p.s, "New")
            self.assertEqual(p.dumpPython(),'import FWCore.ParameterSet.Config as cms\n\nprocess = cms.Process("test")\n\nprocess.a = cms.EDProducer("a",\n    src = cms.InputTag("gen")\n)\n\n\nprocess.c = cms.EDProducer("c",\n    src = cms.InputTag("b","instance")\n)\n\n\nprocess.cNew = cms.EDProducer("c",\n    src = cms.InputTag("bNew","instance")\n)\n\n\nprocess.bNew = cms.EDProducer("b",\n    src = cms.InputTag("aNew")\n)\n\n\nprocess.aNew = cms.EDProducer("a",\n    src = cms.InputTag("gen")\n)\n\n\nprocess.b = cms.EDProducer("b",\n    src = cms.InputTag("a")\n)\n\n\nprocess.s = cms.Sequence(process.a*process.b*process.c*process.a)\n\n\nprocess.sNew = cms.Sequence(process.aNew+process.bNew+process.cNew)\n\n\n')
        def testContains(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s1 = cms.Sequence(p.a*p.b*p.c)
            p.s2 = cms.Sequence(p.b*p.c)
            self.assertTrue( contains(p.s1, "a") )
            self.assertFalse( contains(p.s2, "a") )
        def testJetCollectionString(self):
            # jetCollectionString builds '<prefix>PatJets<algo><type>' ('patJets...' without prefix)
            self.assertEqual(jetCollectionString(algo = 'Foo', type = 'Bar'), 'patJetsFooBar')
            self.assertEqual(jetCollectionString(prefix = 'prefix', algo = 'Foo', type = 'Bar'), 'prefixPatJetsFooBar')
        def testListModules(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s = cms.Sequence(p.a*p.b*p.c)
            self.assertEqual([p.a,p.b,p.c], listModules(p.s))
        def testMassSearchReplaceParam(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("c"))
                                )
            p.s = cms.Sequence(p.a*p.b*p.c)
            massSearchReplaceParam(p.s,"src",cms.InputTag("b"),"a")
            self.assertEqual(cms.InputTag("a"),p.c.src)
            self.assertNotEqual(cms.InputTag("a"),p.c.nested.src)
        def testMassSearchReplaceAnyInputTag(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("b"), src2 = cms.InputTag("c")),
                                 nestedv = cms.VPSet(cms.PSet(src = cms.InputTag("b")), cms.PSet(src = cms.InputTag("d"))),
                                 vec = cms.VInputTag(cms.InputTag("a"), cms.InputTag("b"), cms.InputTag("c"), cms.InputTag("d"))
                                )
            p.s = cms.Sequence(p.a*p.b*p.c)
            massSearchReplaceAnyInputTag(p.s, cms.InputTag("b"), cms.InputTag("new"))
            self.assertNotEqual(cms.InputTag("new"), p.b.src)
            self.assertEqual(cms.InputTag("new"), p.c.src)
            self.assertEqual(cms.InputTag("new"), p.c.nested.src)
            self.assertNotEqual(cms.InputTag("new"), p.c.nested.src2)
            self.assertEqual(cms.InputTag("new"), p.c.nestedv[0].src)
            self.assertNotEqual(cms.InputTag("new"), p.c.nestedv[1].src)
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[0])
            self.assertEqual(cms.InputTag("new"), p.c.vec[1])
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[2])
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[3])


    unittest.main()