CMS 3D CMS Logo

/data/refman/pasoursint/CMSSW_6_1_2_SLHC2/src/PhysicsTools/PatAlgos/python/tools/helpers.py

Go to the documentation of this file.
00001 import FWCore.ParameterSet.Config as cms
00002 
00003 ## Helpers to perform some technically boring tasks like looking for all modules with a given parameter
00004 ## and replacing that to a given value
00005 
def applyPostfix(process, label, postfix):
    '''Return the postfixed clone of module/sequence 'label' if it belongs to
    the cloned PAT sequence, otherwise fall back to the original object.

    The reference sequence is 'patPF2PATSequence' + postfix when the process
    has a 'patPF2PATSequence' attribute, and 'patDefaultSequence' + postfix
    otherwise. Will crash (AttributeError) if that reference sequence has not
    been cloned with 'postfix' beforehand.

    Returns the module/sequence label + postfix when 'label' is in the
    reference sequence and that clone exists; otherwise prints a warning and
    returns process.<label> if it exists; otherwise returns None.'''
    # Only inspect the sequence that is actually used, so a missing
    # patDefaultSequence clone does not matter when PF2PAT is in use.
    if hasattr(process, "patPF2PATSequence"):
        referenceSequence = "patPF2PATSequence"
    else:
        referenceSequence = "patDefaultSequence"
    defaultLabels = __labelsInSequence(process, referenceSequence, postfix)
    if label in defaultLabels and hasattr(process, label + postfix):
        return getattr(process, label + postfix)
    if hasattr(process, label):
        print("WARNING: called applyPostfix for module/sequence %s which is not in patDefaultSequence%s!" % (label, postfix))
        return getattr(process, label)
    return None
00019 
def removeIfInSequence(process, target, sequenceLabel, postfix=""):
    """Remove module 'target' + postfix from sequence 'sequenceLabel' + postfix
    of 'process' if the sequence contains it; otherwise do nothing.

    Raises AttributeError if the sequence 'sequenceLabel' + postfix does not
    exist on the process.
    """
    # __labelsInSequence returns labels with the postfix already cut off, so
    # the membership test must use the bare target label (testing
    # target + postfix could never match for a non-empty postfix).
    labels = __labelsInSequence(process, sequenceLabel, postfix)
    # The postfix cut is applied to every label, so guard against a
    # coincidental match that has no postfixed module behind it.
    if target in labels and hasattr(process, target + postfix):
        getattr(process, sequenceLabel + postfix).remove(
            getattr(process, target + postfix)
            )
00026     
def __labelsInSequence(process, sequenceLabel, postfix=""):
    """Return the labels of all modules and sub-sequences contained in the
    sequence 'sequenceLabel' + postfix of 'process'.

    Module labels come first, followed by sequence labels. When 'postfix' is
    non-empty, the last len(postfix) characters are cut from every label
    (whether or not the label actually ends with the postfix), so the result
    holds the un-postfixed base labels.
    """
    sequence = getattr(process, sequenceLabel + postfix)
    labels = [m.label() for m in listModules(sequence)]
    labels.extend([s.label() for s in listSequences(sequence)])
    if postfix:
        # Only slice for a non-empty postfix: label[:-0] would be "".
        cut = len(postfix)
        labels = [l[:-cut] for l in labels]
    return labels
00034     
class MassSearchReplaceParamVisitor(object):
    """Sequence visitor that sets attribute 'paramName' to 'paramValue' on every
    visited node whose current value of that attribute equals 'paramSearch'.

    Only the node's own attribute is examined; nested PSets are not searched.
    """
    def __init__(self, paramName, paramSearch, paramValue, verbose=False):
        self._paramName = paramName
        self._paramSearch = paramSearch
        self._paramValue = paramValue
        self._verbose = verbose

    def enter(self, visitee):
        attrName = self._paramName
        if not hasattr(visitee, attrName):
            return
        current = getattr(visitee, attrName)
        if not (current == self._paramSearch):
            return
        if self._verbose:
            print("Replaced %s.%s: %s => %s" % (visitee, attrName, current, self._paramValue))
        setattr(visitee, attrName, self._paramValue)

    def leave(self, visitee):
        pass
00049 
class MassSearchReplaceAnyInputTagVisitor(object):
    """Visitor that travels within a cms.Sequence, looks for an InputTag and
    replaces it with another one.

    It climbs down within PSets, VPSets and VInputTags to find its target.
    An InputTag equal to 'paramSearch' is replaced by a copy of 'paramReplace'.
    With moduleLabelOnly=True, an InputTag that only shares the module label
    of 'paramSearch' gets just its module label replaced, keeping its other
    fields. Both arguments may be given as plain strings or as cms.InputTag.
    """
    def __init__(self, paramSearch, paramReplace, verbose=False, moduleLabelOnly=False):
        self._paramSearch = self.standardizeInputTagFmt(paramSearch)
        self._paramReplace = self.standardizeInputTagFmt(paramReplace)
        self._moduleName = ''
        self._verbose = verbose
        self._moduleLabelOnly = moduleLabelOnly

    def doIt(self, pset, base):
        """Replace matching InputTags inside 'pset', recursing into nested
        PSets and VPSets. 'base' is the dotted parameter path used in the
        verbose messages. Objects that are not cms._Parameterizable are
        ignored."""
        from copy import deepcopy
        if not isinstance(pset, cms._Parameterizable):
            return
        for name in pset.parameters_().keys():
            # if I use pset.parameters_().items() I get copies of the parameter
            # values, so I can't modify the nested pset
            value = getattr(pset, name)
            typeName = value.pythonTypeName()
            if typeName == 'cms.PSet':
                self.doIt(value, base + "." + name)
            elif typeName == 'cms.VPSet':
                for i, ps in enumerate(value):
                    self.doIt(ps, "%s.%s[%d]" % (base, name, i))
            elif typeName == 'cms.VInputTag':
                for i, n in enumerate(value):
                    # VInputTag can be declared as a list of strings, so ensure
                    # that n is a cms.InputTag before comparing it.
                    n = self.standardizeInputTagFmt(n)
                    if n == self._paramSearch:
                        if self._verbose:
                            print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, self._paramSearch, self._paramReplace))
                        # Store a copy (as the scalar InputTag branch does) so
                        # no two parameters share the same InputTag object.
                        value[i] = deepcopy(self._paramReplace)
                    elif self._moduleLabelOnly and n.moduleLabel == self._paramSearch.moduleLabel:
                        # Copy before editing: n may be the very InputTag object
                        # held by the VInputTag, and the verbose message must
                        # show the tag as it was before the replacement.
                        nrep = deepcopy(n)
                        nrep.moduleLabel = self._paramReplace.moduleLabel
                        if self._verbose:
                            print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, n, nrep))
                        value[i] = nrep
            elif typeName.endswith('.InputTag'):
                if value == self._paramSearch:
                    if self._verbose:
                        print("Replace %s.%s %s ==> %s " % (base, name, self._paramSearch, self._paramReplace))
                    if 'untracked' in typeName:
                        # Keep the parameter untracked: rebuild the tag with
                        # the replacement's labels.
                        setattr(pset, name, cms.untracked.InputTag(self._paramReplace.getModuleLabel(),
                                                                   self._paramReplace.getProductInstanceLabel(),
                                                                   self._paramReplace.getProcessName()))
                    else:
                        setattr(pset, name, deepcopy(self._paramReplace))
                elif self._moduleLabelOnly and value.moduleLabel == self._paramSearch.moduleLabel:
                    # Copy the existing tag (preserving its trackedness and
                    # other labels) and swap only the module label.
                    repl = deepcopy(value)
                    repl.moduleLabel = self._paramReplace.moduleLabel
                    setattr(pset, name, repl)
                    if self._verbose:
                        print("Replace %s.%s %s ==> %s " % (base, name, value, repl))

    @staticmethod
    def standardizeInputTagFmt(inputTag):
        ''' helper function to ensure that the InputTag is defined as cms.InputTag(str) and not as a plain str '''
        if not isinstance(inputTag, cms.InputTag):
            return cms.InputTag(inputTag)
        return inputTag

    def enter(self, visitee):
        """Run the replacement on a visited node, labelled for verbose output."""
        label = ''
        try:
            label = visitee.label_()
        except AttributeError:
            label = '<Module not in a Process>'
        self.doIt(visitee, label)

    def leave(self, visitee):
        pass
00113 
00114 #FIXME name is not generic enough now
class GatherAllModulesVisitor(object):
    """Visitor that travels within a cms.Sequence and collects, in visiting
    order, every node that is an instance of 'gatheredInstance' (e.g. modules;
    cms._Module by default)."""
    def __init__(self, gatheredInstance=cms._Module):
        self._gatheredInstance = gatheredInstance
        self._modules = []

    def enter(self, visitee):
        if not isinstance(visitee, self._gatheredInstance):
            return
        self._modules.append(visitee)

    def leave(self, visitee):
        pass

    def modules(self):
        """Return the list of gathered objects (the internal list, not a copy)."""
        return self._modules
00127 
class CloneSequenceVisitor(object):
    """Visitor that travels within a cms.Sequence and builds a cloned version of it.

    Every module found is cloned and attached to the process under its label
    with 'removePostfix' stripped (when given) and 'postfix' appended. The
    clones are added to a new top-level cms.Sequence, registered on the
    process under the sequence label renamed the same way. Sub-sequences are
    not cloned themselves; only the modules they contain are. A module that
    is visited more than once is cloned only once and its clone is reused.
    """
    def __init__(self, process, label, postfix, removePostfix=""):
        self._process = process
        self._postfix = postfix
        self._removePostfix = removePostfix
        self._moduleLabels = []  # labels of the original modules cloned so far
        self._clonedSequence = cms.Sequence()
        setattr(process, self._newLabel(label), self._clonedSequence)

    def enter(self, visitee):
        if not isinstance(visitee, cms._Module):
            return
        label = visitee.label()
        if label in self._moduleLabels:
            # Already cloned earlier in the traversal: reuse that clone.
            newModule = getattr(self._process, self._newLabel(label))
        else:
            self._moduleLabels.append(label)
            newModule = visitee.clone()
            setattr(self._process, self._newLabel(label), newModule)
        self.__appendToTopSequence(newModule)

    def leave(self, visitee):
        pass

    def clonedSequence(self):
        """Return the cloned sequence.

        On the first call, InputTags inside the cloned sequence whose module
        label is one of the cloned modules are re-pointed to the clones. The
        list of cloned labels is then cleared so that later calls do not
        repeat the replacement.
        """
        for label in self._moduleLabels:
            massSearchReplaceAnyInputTag(self._clonedSequence, label, self._newLabel(label), moduleLabelOnly=True, verbose=False)
        self._moduleLabels = []  # prevent the InputTag replacement next time 'clonedSequence' is called.
        return self._clonedSequence

    def _newLabel(self, label):
        """Return 'label' with 'removePostfix' stripped (if set) and 'postfix'
        appended. Raises ValueError if 'removePostfix' is set but 'label' does
        not end with it."""
        if self._removePostfix != "":
            if not label.endswith(self._removePostfix):
                # ValueError derives from StandardError in Python 2, so existing
                # 'except StandardError' handlers still catch it; StandardError
                # itself no longer exists in Python 3.
                raise ValueError("Tried to remove postfix %s from label %s, but it wasn't there" % (self._removePostfix, label))
            label = label[:-len(self._removePostfix)]
        return label + self._postfix

    def __appendToTopSequence(self, visitee):
        self._clonedSequence += visitee
00170         
class MassSearchParamVisitor(object):
    """Sequence visitor that collects, in visiting order, every node whose
    attribute 'paramName' equals 'paramSearch'."""
    def __init__(self, paramName, paramSearch):
        self._paramName = paramName
        self._paramSearch = paramSearch
        self._modules = []

    def enter(self, visitee):
        attrName = self._paramName
        matches = hasattr(visitee, attrName) and getattr(visitee, attrName) == self._paramSearch
        if matches:
            self._modules.append(visitee)

    def leave(self, visitee):
        pass

    def modules(self):
        """Return the list of matching nodes (the internal list, not a copy)."""
        return self._modules
00185     
00186     
def massSearchReplaceParam(sequence, paramName, paramOldValue, paramValue, verbose=False):
    """Set 'paramName' to 'paramValue' on every node of 'sequence' whose
    'paramName' currently equals 'paramOldValue' (top-level attributes only)."""
    replacer = MassSearchReplaceParamVisitor(paramName, paramOldValue, paramValue, verbose)
    sequence.visit(replacer)
00189 
def listModules(sequence):
    """Return the cms._Module instances found while visiting 'sequence', in
    visiting order."""
    gatherer = GatherAllModulesVisitor(gatheredInstance=cms._Module)
    sequence.visit(gatherer)
    found = gatherer.modules()
    return found
00194 
def listSequences(sequence):
    """Return the cms.Sequence instances found while visiting 'sequence', in
    visiting order."""
    gatherer = GatherAllModulesVisitor(gatheredInstance=cms.Sequence)
    sequence.visit(gatherer)
    found = gatherer.modules()
    return found
00199 
def massSearchReplaceAnyInputTag(sequence, oldInputTag, newInputTag, verbose=False, moduleLabelOnly=False):
    """Replace InputTag oldInputTag with newInputTag, at any level of nesting within PSets, VPSets, VInputTags..."""
    replacer = MassSearchReplaceAnyInputTagVisitor(oldInputTag, newInputTag,
                                                   verbose=verbose,
                                                   moduleLabelOnly=moduleLabelOnly)
    sequence.visit(replacer)
00203     
def jetCollectionString(prefix='', algo='', type=''):
    """
    ------------------------------------------------------------------
    return the label of the pat jet collection module built from the
    input values: 'patJets<algo><type>' when no prefix is given, and
    '<prefix>PatJets<algo><type>' otherwise. With all defaults the
    result is 'patJets'. Examples:
      jetCollectionString(algo='AK5', type='PF')  -> 'patJetsAK5PF'
      jetCollectionString('selected', 'AK5', 'PF') -> 'selectedPatJetsAK5PF'

    algo   : indicating the algorithm type of the jet [expected are
             'AK5', 'IC5', 'SC7', ...]
    type   : indicating the type of constituents of the jet [expec-
             ted are 'Calo', 'PFlow', 'JPT', ...]
    prefix : prefix indicating the type of pat collection module (ex-
             pected are '', 'selected', 'clean'); None is treated
             like ''.
    ------------------------------------------------------------------
    """
    # 'type' shadows the builtin but is part of the public keyword interface.
    if not prefix:
        head = 'pat'
    else:
        head = prefix + 'Pat'
    return head + 'Jets' + algo + type
00227 
def contains(sequence, moduleName):
    """
    ------------------------------------------------------------------
    return True if the string 'moduleName' occurs anywhere in the
    string representation of 'sequence', and False otherwise. This is
    a plain substring test, so it also returns True when moduleName is
    only part of the name of a contained module.

    sequence   : sequence [e.g. process.patDefaultSequence]
    moduleName : module name as a string
    ------------------------------------------------------------------
    """
    rendered = sequence.__str__()
    return moduleName in rendered
00241 
00242 
00243 
def cloneProcessingSnippet(process, sequence, postfix, removePostfix=""):
    """
    ------------------------------------------------------------------
    copy a sequence plus the modules therein; the copies are renamed
    by appending 'postfix' (after stripping 'removePostfix', if given)
    and their input tags are automatically adjusted to point to the
    copies. With an empty postfix the original sequence is returned
    unchanged.
    ------------------------------------------------------------------
    """
    if postfix == "":
        # Nothing to rename: hand back the original sequence untouched.
        return sequence
    cloner = CloneSequenceVisitor(process, sequence.label(), postfix, removePostfix)
    sequence.visit(cloner)
    return cloner.clonedSequence()
00258 
if __name__ == "__main__":
    import unittest

    class TestModuleCommand(unittest.TestCase):
        def setUp(self):
            """Nothing to do """
            pass

        def testCloning(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("b", src=cms.InputTag("a"))
            p.c = cms.EDProducer("c", src=cms.InputTag("b","instance"))
            p.s = cms.Sequence(p.a*p.b*p.c *p.a)
            cloneProcessingSnippet(p, p.s, "New")
            self.assertEqual(p.dumpPython(),'import FWCore.ParameterSet.Config as cms\n\nprocess = cms.Process("test")\n\nprocess.a = cms.EDProducer("a",\n    src = cms.InputTag("gen")\n)\n\n\nprocess.c = cms.EDProducer("c",\n    src = cms.InputTag("b","instance")\n)\n\n\nprocess.cNew = cms.EDProducer("c",\n    src = cms.InputTag("bNew","instance")\n)\n\n\nprocess.bNew = cms.EDProducer("b",\n    src = cms.InputTag("aNew")\n)\n\n\nprocess.aNew = cms.EDProducer("a",\n    src = cms.InputTag("gen")\n)\n\n\nprocess.b = cms.EDProducer("b",\n    src = cms.InputTag("a")\n)\n\n\nprocess.s = cms.Sequence(process.a*process.b*process.c*process.a)\n\n\nprocess.sNew = cms.Sequence(process.aNew+process.bNew+process.cNew)\n\n\n')

        def testContains(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s1 = cms.Sequence(p.a*p.b*p.c)
            p.s2 = cms.Sequence(p.b*p.c)
            # assertTrue/assertFalse replace the deprecated assert_ alias.
            self.assertTrue( contains(p.s1, "a") )
            self.assertFalse( contains(p.s2, "a") )

        def testJetCollectionString(self):
            # jetCollectionString builds 'pat' + 'Jets' + algo + type
            # (or prefix + 'Pat' + 'Jets' + algo + type).
            self.assertEqual(jetCollectionString(algo = 'Foo', type = 'Bar'), 'patJetsFooBar')
            self.assertEqual(jetCollectionString(prefix = 'prefix', algo = 'Foo', type = 'Bar'), 'prefixPatJetsFooBar')

        def testListModules(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s = cms.Sequence(p.a*p.b*p.c)
            self.assertEqual([p.a,p.b,p.c], listModules(p.s))

        def testMassSearchReplaceParam(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("c"))
                                )
            p.s = cms.Sequence(p.a*p.b*p.c)
            massSearchReplaceParam(p.s,"src",cms.InputTag("b"),"a")
            self.assertEqual(cms.InputTag("a"),p.c.src)
            self.assertNotEqual(cms.InputTag("a"),p.c.nested.src)

    unittest.main()
00300            massSearchReplaceParam(p.s,"src",cms.InputTag("b"),"a")
00301            self.assertEqual(cms.InputTag("a"),p.c.src)
00302            self.assertNotEqual(cms.InputTag("a"),p.c.nested.src)
00303            
00304    unittest.main()