CMS 3D CMS Logo

/data/doxygen/doxygen-1.7.3/gen/CMSSW_4_2_8/src/PhysicsTools/PatAlgos/python/tools/helpers.py

Go to the documentation of this file.
00001 import FWCore.ParameterSet.Config as cms
00002 
00003 ## Helpers to perform some technically boring tasks like looking for all modules with a given parameter
00004 ## and replacing that to a given value
00005 
def applyPostfix(process, label, postfix):
    ''' Return the postfixed clone of a module/sequence if it is part of the PAT sequence.

    The sequence that is searched is "patPF2PATSequence"+postfix when the
    process has an attribute "patPF2PATSequence", and
    "patDefaultSequence"+postfix otherwise.
    Will crash if that sequence has not been cloned with 'postfix' beforehand.

    process : object holding the modules and sequences as attributes
    label   : label of the module/sequence without postfix
    postfix : postfix that was used when cloning the sequence

    Returns getattr(process, label+postfix) if 'label' is found in the
    searched sequence and the postfixed attribute exists; otherwise the
    un-postfixed getattr(process, label) (after printing a warning) if that
    exists; otherwise None.
    '''
    # Only look into the sequence that is actually used; looking into
    # patDefaultSequence first and throwing the result away is wasted work
    # and fails needlessly when only the PF2PAT sequence has been cloned.
    sequenceName = "patDefaultSequence"
    if hasattr(process, "patPF2PATSequence"):
        sequenceName = "patPF2PATSequence"
    knownLabels = __labelsInSequence(process, sequenceName, postfix)

    if label in knownLabels and hasattr(process, label+postfix):
        return getattr(process, label+postfix)
    if hasattr(process, label):
        # fall back to the original (un-cloned) object, but tell the user
        print("WARNING: called applyPostfix for module/sequence %s which is not in %s%s!"%(label,sequenceName,postfix))
        return getattr(process, label)
    return None
00019 
def removeIfInSequence(process, target,  sequenceLabel, postfix=""):
    """Remove process.<target+postfix> from process.<sequenceLabel+postfix> if it is in there.

    process       : object holding the modules and sequences as attributes
    target        : label (without postfix) of the module/sequence to remove
    sequenceLabel : label (without postfix) of the sequence to remove it from
    postfix       : postfix appended to both labels

    Nothing happens if the sequence does not contain an element labelled
    target+postfix.
    """
    sequence = getattr(process, sequenceLabel+postfix)
    fullLabel = target+postfix
    # Compare against the complete labels found in the sequence. Comparing
    # target+postfix against labels that already had the postfix stripped can
    # never match for a non-empty postfix, so nothing would ever be removed.
    present = [ m.label() for m in listModules(sequence) ]
    present.extend([ s.label() for s in listSequences(sequence) ])
    if fullLabel in present:
        sequence.remove(getattr(process, fullLabel))
00026     
def __labelsInSequence(process, sequenceLabel, postfix=""):
    """Return the labels of all modules and sequences inside process.<sequenceLabel+postfix>.

    The labels of the modules come first, followed by the labels of the
    nested sequences. For a non-empty postfix the last len(postfix)
    characters are cut off every label; this is done without checking that
    the label really ends with 'postfix'.
    """
    # look the sequence up once and walk it once per kind of element
    sequence = getattr(process, sequenceLabel+postfix)
    members = listModules(sequence) + listSequences(sequence)
    if postfix == "":
        # slicing with [:-0] would return empty strings, so keep the label as is
        return [ m.label() for m in members ]
    cut = len(postfix)
    return [ m.label()[:-cut] for m in members ]
00034     
class MassSearchReplaceParamVisitor(object):
    """Visitor that replaces the value of a named parameter on every visited object.

    An object is modified only if it has an attribute 'paramName' whose
    current value compares equal to 'paramSearch'; the attribute is then set
    to 'paramValue'. Nested parameter sets are not inspected.
    """
    def __init__(self,paramName,paramSearch,paramValue,verbose=False):
        self._verbose     = verbose
        self._paramName   = paramName
        self._paramSearch = paramSearch
        self._paramValue  = paramValue

    def enter(self,visitee):
        """Replace the parameter on 'visitee' if it currently holds the searched value."""
        name = self._paramName
        if not hasattr(visitee, name):
            return
        current = getattr(visitee, name)
        if not (current == self._paramSearch):
            return
        if self._verbose:
            print("Replaced %s.%s: %s => %s" % (visitee,name,current,self._paramValue))
        setattr(visitee, name, self._paramValue)

    def leave(self,visitee):
        """Nothing to do when leaving a node."""
        pass
00049 
class MassSearchReplaceAnyInputTagVisitor(object):
    """Visitor that travels within a cms.Sequence and replaces one InputTag by another.

    It will climb down within PSets, VPSets and VInputTags to find its target.

    paramSearch     : InputTag (or plain string) to look for
    paramReplace    : InputTag (or plain string) to put in its place
    verbose         : print a line for every replacement
    moduleLabelOnly : if True, tags that do not match 'paramSearch' exactly but
                      have the same moduleLabel get only their moduleLabel
                      replaced
    """
    def __init__(self,paramSearch,paramReplace,verbose=False,moduleLabelOnly=False):
        self._paramSearch  = self.standardizeInputTagFmt(paramSearch)
        self._paramReplace = self.standardizeInputTagFmt(paramReplace)
        self._moduleName   = ''
        self._verbose=verbose
        self._moduleLabelOnly=moduleLabelOnly

    def doIt(self,pset,base):
        """Replace matching InputTags in 'pset' and, recursively, in its nested PSets/VPSets.

        'base' is the dotted name of 'pset'; it is only used in the verbose output.
        Objects that are not a cms._Parameterizable are left untouched.
        """
        from copy import deepcopy
        if not isinstance(pset, cms._Parameterizable):
            return
        for name in pset.parameters_().keys():
            # if I use pset.parameters_().items() I get copies of the parameter values
            # so I can't modify the nested pset
            value = getattr(pset,name)
            typeName = value.pythonTypeName()
            if typeName == 'cms.PSet':
                self.doIt(value,base+"."+name)
            elif typeName == 'cms.VPSet':
                for (i,ps) in enumerate(value):
                    self.doIt(ps, "%s.%s[%d]"%(base,name,i) )
            elif typeName == 'cms.VInputTag':
                for (i,n) in enumerate(value):
                    # VInputTag can be declared as a list of strings, so ensure that n is formatted correctly
                    n = self.standardizeInputTagFmt(n)
                    if (n == self._paramSearch):
                        if self._verbose:
                            print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, self._paramSearch, self._paramReplace))
                        # store a private copy, like the plain InputTag case below,
                        # so that the replaced entries do not all share one object
                        value[i] = deepcopy(self._paramReplace)
                    elif self._moduleLabelOnly and n.moduleLabel == self._paramSearch.moduleLabel:
                        # work on a copy: with an alias the message below would
                        # show the already modified tag on both sides
                        nrep = deepcopy(n)
                        nrep.moduleLabel = self._paramReplace.moduleLabel
                        if self._verbose:
                            print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, n, nrep))
                        value[i] = nrep
            elif typeName == 'cms.InputTag':
                if value == self._paramSearch:
                    if self._verbose:
                        print("Replace %s.%s %s ==> %s " % (base, name, self._paramSearch, self._paramReplace))
                    setattr(pset, name, deepcopy(self._paramReplace) )
                elif self._moduleLabelOnly and value.moduleLabel == self._paramSearch.moduleLabel:
                    repl = deepcopy(value)
                    repl.moduleLabel = self._paramReplace.moduleLabel
                    setattr(pset, name, repl)
                    if self._verbose:
                        print("Replace %s.%s %s ==> %s " % (base, name, value, repl))

    @staticmethod
    def standardizeInputTagFmt(inputTag):
        ''' helper function to ensure that the InputTag is defined as cms.InputTag(str) and not as a plain str '''
        if not isinstance(inputTag, cms.InputTag):
            return cms.InputTag(inputTag)
        return inputTag

    def enter(self,visitee):
        """Run the replacement on 'visitee', using its label (if it has one) for the verbose output."""
        label = ''
        try:
            label = visitee.label()
        except AttributeError:
            label = '<Module not in a Process>'
        self.doIt(visitee, label)

    def leave(self,visitee):
        """Nothing to do when leaving a node."""
        pass
00108 
00109 #FIXME name is not generic enough now
class GatherAllModulesVisitor(object):
    """Visitor that travels within a cms.Sequence and collects the visited objects
    that are instances of 'gatheredInstance' (cms._Module by default)."""
    def __init__(self, gatheredInstance=cms._Module):
        self._gatheredInstance = gatheredInstance
        self._modules = []

    def enter(self,visitee):
        """Remember 'visitee' if it is of the requested type."""
        if not isinstance(visitee, self._gatheredInstance):
            return
        self._modules.append(visitee)

    def leave(self,visitee):
        """Nothing to do when leaving a node."""
        pass

    def modules(self):
        """Return the list of objects gathered so far, in the order they were visited."""
        return self._modules
00122 
class CloneSequenceVisitor(object):
    """Visitor that travels within a cms.Sequence, and returns a cloned version of the Sequence.
    All modules and sequences are cloned and a postfix is added.

    process : process to which the clones are attached (via setattr)
    label   : label of the sequence that is going to be visited
    postfix : string appended to the label of every clone
    """
    def __init__(self, process, label, postfix):
        self._process = process
        self._postfix = postfix
        # The stack holds, per sequence currently being visited, either its
        # label (a string, as long as no element has been added to the clone)
        # or the cloned cms.Sequence itself.
        self._sequenceStack = [label]
        self._moduleLabels = []    # labels of the modules cloned so far
        self._sequenceLabels = []  # labels of the sequences cloned so far
        self._waitForSequenceToClose = None # modules will only be cloned or added if this is None

    def enter(self,visitee):
        """Clone 'visitee' (or reuse its existing clone) and add it to the sequence being built."""
        if self._waitForSequenceToClose is not None:
            return #we are in a already cloned sequence
        if isinstance(visitee,cms._Module):
            label = visitee.label()
            if label in self._moduleLabels:
                # the module occurred before: reuse the clone made back then
                newModule = getattr(self._process, label+self._postfix)
            else:
                self._moduleLabels.append(label)
                newModule = visitee.clone()
                setattr(self._process, label+self._postfix, newModule)
            self.__appendToTopSequence(newModule)

        if isinstance(visitee,cms.Sequence):
            if visitee.label() in self._sequenceLabels: # is the sequence already cloned?
                self._waitForSequenceToClose = visitee.label()
                self._sequenceStack.append(  getattr(self._process, visitee.label()+self._postfix) )
            else:
                self._sequenceStack.append(visitee.label())#save desired label as placeholder until we have a module to create the sequence

    def leave(self,visitee):
        """Close the clone of the sequence 'visitee' and add it to its parent sequence.

        Raises StandardError if no element was added to the clone.
        """
        if not isinstance(visitee,cms.Sequence):
            return
        if self._waitForSequenceToClose is not None:
            if self._waitForSequenceToClose != visitee.label():
                # A sequence nested inside an already cloned one: enter() did
                # not push anything for it, so nothing may be popped here.
                return
            self._waitForSequenceToClose = None
        if not isinstance(self._sequenceStack[-1], cms.Sequence):
            raise StandardError("empty Sequence encountered during cloneing. sequnece stack: %s"%self._sequenceStack)
        self.__appendToTopSequence( self._sequenceStack.pop() )

    def clonedSequence(self):
        """Return the cloned sequence, with the InputTags pointing to the cloned modules.

        Raises StandardError if the sequence stack does not hold exactly one entry.
        """
        if not len(self._sequenceStack) == 1:
            raise StandardError("someting went wrong, the sequence stack looks like: %s"%self._sequenceStack)
        for label in self._moduleLabels:
            massSearchReplaceAnyInputTag(self._sequenceStack[-1], label, label+self._postfix, moduleLabelOnly=True, verbose=False)
        self._moduleLabels = [] #prevent the InputTag replacement next time this is called.
        return self._sequenceStack[-1]

    def __appendToTopSequence(self, visitee):#this is darn ugly because empty cms.Sequences are not supported
        """Add 'visitee' to the sequence on top of the stack, creating that sequence if needed.

        Raises StandardError if creating the sequence would overwrite an
        existing attribute of the process.
        """
        if isinstance(self._sequenceStack[-1], basestring):#we have the name of an empty sequence on the stack. create it!
            oldSequenceLabel = self._sequenceStack[-1]
            newSequenceLabel = oldSequenceLabel + self._postfix
            # Check before touching the stack and name the label we are about
            # to use: the new cms.Sequence is not attached to the process yet,
            # so asking it for its label is presumably not possible.
            if hasattr(self._process, newSequenceLabel):
                raise StandardError("Cloning the sequence "+newSequenceLabel+" would overwrite existing object." )
            newSequence = cms.Sequence(visitee)
            self._sequenceStack[-1] = newSequence
            setattr(self._process, newSequenceLabel, newSequence)
            self._sequenceLabels.append(oldSequenceLabel)
        else:
            self._sequenceStack[-1] += visitee
00183         
class MassSearchParamVisitor(object):
    """Visitor that collects every visited object whose parameter 'paramName'
    compares equal to 'paramSearch'."""
    def __init__(self,paramName,paramSearch):
        self._modules     = []
        self._paramSearch = paramSearch
        self._paramName   = paramName

    def enter(self,visitee):
        """Remember 'visitee' if it has the parameter and the parameter holds the searched value."""
        if not hasattr(visitee, self._paramName):
            return
        if getattr(visitee, self._paramName) == self._paramSearch:
            self._modules.append(visitee)

    def leave(self,visitee):
        """Nothing to do when leaving a node."""
        pass

    def modules(self):
        """Return the list of matching objects found so far, in the order they were visited."""
        return self._modules
00198     
00199     
def massSearchReplaceParam(sequence,paramName,paramOldValue,paramValue,verbose=False):
    """Set parameter 'paramName' to 'paramValue' on every element visited in
    'sequence' whose current value equals 'paramOldValue'."""
    replacer = MassSearchReplaceParamVisitor(paramName,paramOldValue,paramValue,verbose)
    sequence.visit(replacer)
00202 
def listModules(sequence):
    """Return the list of cms._Module instances encountered when visiting 'sequence'."""
    collector = GatherAllModulesVisitor(gatheredInstance=cms._Module)
    sequence.visit(collector)
    return collector.modules()
00207 
def listSequences(sequence):
    """Return the list of cms.Sequence instances encountered when visiting 'sequence'."""
    collector = GatherAllModulesVisitor(gatheredInstance=cms.Sequence)
    sequence.visit(collector)
    return collector.modules()
00212 
def massSearchReplaceAnyInputTag(sequence, oldInputTag, newInputTag,verbose=False,moduleLabelOnly=False) :
    """Replace InputTag oldInputTag with newInputTag, at any level of nesting within PSets, VPSets, VInputTags..."""
    replacer = MassSearchReplaceAnyInputTagVisitor(
        oldInputTag,
        newInputTag,
        verbose=verbose,
        moduleLabelOnly=moduleLabelOnly,
    )
    sequence.visit(replacer)
00216     
def jetCollectionString(prefix='', algo='', type=''):
    """
    ------------------------------------------------------------------
    return the label of the jet collection module built from the input
    values as <prefix>Pat + 'Jets' + <algo> + <type>, or, for an empty
    prefix, 'pat' + 'Jets' + <algo> + <type>. With all arguments left
    at their defaults the return value is 'patJets'.

    algo   : indicating the algorithm type of the jet [expected are
             'AK5', 'IC5', 'SC7', ...]
    type   : indicating the type of constituents of the jet [expec-
             ted are 'Calo', 'PFlow', 'JPT', ...]
    prefix : prefix indicating the type of pat collection module (ex-
             pected are '', 'selected', 'clean').
    ------------------------------------------------------------------
    """
    if prefix == '':
        head = 'pat'
    else:
        head = prefix + 'Pat'
    return head + 'Jets' + algo + type
00240 
def contains(sequence, moduleName):
    """
    ------------------------------------------------------------------
    return True if 'moduleName' occurs in the string representation
    of 'sequence' and False otherwise. As this is a plain substring
    search it also returns True for any substring of the name of a
    contained module.

    sequence   : sequence [e.g. process.patDefaultSequence]
    moduleName : module name as a string
    ------------------------------------------------------------------
    """
    asText = sequence.__str__()
    return moduleName in asText
00254 
00255 
00256 
def cloneProcessingSnippet(process, sequence, postfix):
    """
    ------------------------------------------------------------------
    copy a sequence plus the modules and sequences therein
    both are renamed by getting a postfix
    input tags are automatically adjusted

    for an empty postfix nothing is cloned and 'sequence' itself is
    returned
    ------------------------------------------------------------------
    """
    if postfix == "":
        return sequence
    cloner = CloneSequenceVisitor(process, sequence.label(), postfix)
    sequence.visit(cloner)
    return cloner.clonedSequence()
00271 
if __name__=="__main__":
    # Self-test of the helpers above; only run when the file is executed directly.
    import unittest
    class TestModuleCommand(unittest.TestCase):
        def setUp(self):
            """Nothing to do """
            pass
        def testCloning(self):
            # module 'a' occurs twice in the sequence but must be cloned only once
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("b", src=cms.InputTag("a"))
            p.c = cms.EDProducer("c", src=cms.InputTag("b","instance"))
            p.s = cms.Sequence(p.a*p.b*p.c *p.a)
            cloneProcessingSnippet(p, p.s, "New")
            self.assertEqual(p.dumpPython(),'import FWCore.ParameterSet.Config as cms\n\nprocess = cms.Process("test")\n\nprocess.a = cms.EDProducer("a",\n    src = cms.InputTag("gen")\n)\n\n\nprocess.c = cms.EDProducer("c",\n    src = cms.InputTag("b","instance")\n)\n\n\nprocess.cNew = cms.EDProducer("c",\n    src = cms.InputTag("bNew","instance")\n)\n\n\nprocess.bNew = cms.EDProducer("b",\n    src = cms.InputTag("aNew")\n)\n\n\nprocess.aNew = cms.EDProducer("a",\n    src = cms.InputTag("gen")\n)\n\n\nprocess.b = cms.EDProducer("b",\n    src = cms.InputTag("a")\n)\n\n\nprocess.s = cms.Sequence(process.a*process.b*process.c*process.a)\n\n\nprocess.sNew = cms.Sequence(process.aNew+process.bNew+process.cNew)\n\n\n')
        def testContains(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s1 = cms.Sequence(p.a*p.b*p.c)
            p.s2 = cms.Sequence(p.b*p.c)
            self.assertTrue( contains(p.s1, "a") )
            self.assertTrue( not contains(p.s2, "a") )
        def testJetCollectionString(self):
            # jetCollectionString appends 'Jets' before algo and type
            self.assertEqual(jetCollectionString(algo = 'Foo', type = 'Bar'), 'patJetsFooBar')
            self.assertEqual(jetCollectionString(prefix = 'prefix', algo = 'Foo', type = 'Bar'), 'prefixPatJetsFooBar')
        def testListModules(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s = cms.Sequence(p.a*p.b*p.c)
            self.assertEqual([p.a,p.b,p.c], listModules(p.s))
        def testMassSearchReplaceParam(self):
            # only the top-level 'src' is replaced, the one in the nested PSet is not
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("c"))
                                )
            p.s = cms.Sequence(p.a*p.b*p.c)
            massSearchReplaceParam(p.s,"src",cms.InputTag("b"),"a")
            self.assertEqual(cms.InputTag("a"),p.c.src)
            self.assertNotEqual(cms.InputTag("a"),p.c.nested.src)

    unittest.main()