CMS 3D CMS Logo

HiHelperTools.py
Go to the documentation of this file.
1 import FWCore.ParameterSet.Config as cms
2 
3 ## Helpers to perform some technically boring tasks like looking for all modules with a given parameter
4 ## and replacing that to a given value
5 
6 # Next two lines are for backward compatibility, the imported functions and
7 # classes used to be defined in this file.
8 from FWCore.ParameterSet.MassReplace import massSearchReplaceAnyInputTag, MassSearchReplaceAnyInputTagVisitor
9 from FWCore.ParameterSet.MassReplace import massSearchReplaceParam, MassSearchParamVisitor, MassSearchReplaceParamVisitor
10 
def applyPostfix(process, label, postfix):
    """Return the postfixed clone of a module/sequence when it belongs to the default sequence.

    process : object holding the modules and sequences as attributes
    label   : un-postfixed label of the module or sequence to look up
    postfix : postfix that was used when the default sequence was cloned

    The labels are taken from 'patPF2PATSequence'+postfix when the process
    has a 'patPF2PATSequence' attribute, otherwise from
    'patHeavyIonDefaultSequence'+postfix. The lookup raises AttributeError
    if the selected sequence has not been cloned with 'postfix' beforehand.

    Returns the attribute label+postfix if label is part of that sequence
    and the attribute exists; otherwise the attribute label (after printing
    a warning) if it exists; otherwise None.
    """
    # Only the sequence that is actually consulted is traversed.
    if hasattr(process, "patPF2PATSequence"):
        sequenceLabel = "patPF2PATSequence"
    else:
        sequenceLabel = "patHeavyIonDefaultSequence"
    defaultLabels = __labelsInSequence(process, sequenceLabel, postfix)

    if label in defaultLabels and hasattr(process, label + postfix):
        return getattr(process, label + postfix)
    if hasattr(process, label):
        # Parenthesised single argument: valid as a Python 2 statement and as a Python 3 call.
        print("WARNING: called applyPostfix for module/sequence %s which is not in patHeavyIonDefaultSequence%s!" % (label, postfix))
        return getattr(process, label)
    return None
24 
def removeIfInSequence(process, target, sequenceLabel, postfix=""):
    """Remove target+postfix from the sequence sequenceLabel+postfix if it is part of it.

    process       : object holding the modules and sequences as attributes
    target        : un-postfixed label of the module or sequence to remove
    sequenceLabel : un-postfixed label of the sequence to remove it from
    postfix       : postfix appended to both labels (default: none)

    Nothing happens when the postfixed target is not found in the sequence.
    """
    sequence = getattr(process, sequenceLabel + postfix)
    fullLabel = target + postfix
    # Compare complete labels: matching the postfixed name against labels
    # that had the postfix cut off can never succeed for a non-empty postfix.
    present = set(m.label() for m in listModules(sequence))
    present.update(s.label() for s in listSequences(sequence))
    if fullLabel in present:
        sequence.remove(getattr(process, fullLabel))
31 
def __labelsInSequence(process, sequenceLabel, postfix=""):
    """List the labels of all modules and sequences inside sequenceLabel+postfix.

    process       : object holding the sequence as an attribute
    sequenceLabel : un-postfixed label of the sequence to inspect
    postfix       : postfix of the sequence; the last len(postfix)
                    characters are cut off every returned label

    Returns a list with the module labels first, followed by the sequence labels.
    """
    sequence = getattr(process, sequenceLabel + postfix)
    # A slice ending at -0 would give an empty string, so an empty postfix
    # uses None as the slice end, which keeps the whole label.
    end = -len(postfix) if postfix else None
    result = [m.label()[:end] for m in listModules(sequence)]
    result.extend(s.label()[:end] for s in listSequences(sequence))
    return result
39 
40 #FIXME name is not generic enough now
42  """Visitor that travels within a cms.Sequence and returns a list of the objects of type gatheredInstance (e.g. modules) that it contains"""
def __init__(self, gatheredInstance=cms._Module):
    """Start with an empty collection.

    gatheredInstance : type that visited objects are tested against with
                       isinstance(); defaults to cms._Module
    """
    self._gatheredInstance = gatheredInstance
    self._modules = []
def enter(self, visitee):
    """Collect visitee if it is an instance of the requested type."""
    wanted = self._gatheredInstance
    if not isinstance(visitee, wanted):
        return
    self._modules.append(visitee)
def leave(self, visitee):
    """Do nothing; objects are collected when they are entered."""
    return None
def modules(self):
    """Return the list of collected objects, in the order they were entered."""
    collected = self._modules
    return collected
53 
55  """Visitor that travels within a cms.Sequence, and returns a cloned version of the Sequence.
56  All modules and sequences are cloned and a postfix is added"""
def __init__(self, process, label, postfix):
    """Prepare the bookkeeping for cloning one sequence.

    process : object that receives the cloned modules and sequences as attributes
    label   : label of the sequence that is going to be cloned
    postfix : string appended to the label of every clone
    """
    self._process = process
    self._postfix = postfix
    # Labels of the modules and sequences that have been cloned so far.
    self._moduleLabels = []
    self._sequenceLabels = []
    # The bottom entry is the label of the top-level sequence; it stays a
    # plain string until the first element is appended to it.
    self._sequenceStack = [label]
    # Modules are only cloned or added while this is None.
    self._waitForSequenceToClose = None
64 
def enter(self, visitee):
    """Clone a module, or open a (possibly already cloned) sequence.

    Nothing is done while an already cloned sequence is being traversed.
    A module is cloned the first time its label is seen and stored on the
    process as label+postfix; later occurrences reuse that attribute. In
    both cases the clone is appended to the sequence on top of the stack.
    A sequence pushes either its existing clone (if it was cloned before)
    or its own label as a placeholder onto the stack.
    """
    if self._waitForSequenceToClose is not None:
        return  # inside a sequence that has already been cloned

    if isinstance(visitee, cms._Module):
        name = visitee.label()
        if name in self._moduleLabels:
            clone = getattr(self._process, name + self._postfix)
        else:
            self._moduleLabels.append(name)
            clone = visitee.clone()
            setattr(self._process, name + self._postfix, clone)
        self.__appendToTopSequence(clone)

    if isinstance(visitee, cms.Sequence):
        name = visitee.label()
        if name in self._sequenceLabels:
            # Already cloned: reuse the clone and ignore its content until it closes.
            self._waitForSequenceToClose = name
            self._sequenceStack.append(getattr(self._process, name + self._postfix))
        else:
            # Keep the label as a placeholder until there is a first
            # element to create the new sequence from.
            self._sequenceStack.append(name)
86 
def leave(self, visitee):
    """Close a sequence: pop its clone and append it to the enclosing sequence.

    Objects that are not a cms.Sequence are ignored. Leaving the sequence
    that enter() marked as already cloned re-enables cloning.

    Raises Exception if the top of the stack is not a cms.Sequence, i.e.
    nothing was ever appended to the sequence being closed.
    """
    # NOTE(review): unlike enter(), this does not skip sequences nested
    # inside an already cloned one - confirm that is intended.
    if not isinstance(visitee, cms.Sequence):
        return
    if self._waitForSequenceToClose == visitee.label():
        self._waitForSequenceToClose = None
    finished = self._sequenceStack[-1]
    if not isinstance(finished, cms.Sequence):
        raise Exception("empty Sequence encountered during cloneing. sequnece stack: %s"%self._sequenceStack)
    self.__appendToTopSequence(self._sequenceStack.pop())
94 
def clonedSequence(self):
    """Return the cloned top-level sequence with its input tags redirected.

    For every cloned module label, massSearchReplaceAnyInputTag is run on
    the cloned sequence to replace the label by label+postfix.

    Raises Exception unless exactly one entry is left on the sequence stack.
    """
    stack = self._sequenceStack
    if len(stack) != 1:
        raise Exception("someting went wrong, the sequence stack looks like: %s"%self._sequenceStack)
    top = stack[-1]
    for label in self._moduleLabels:
        massSearchReplaceAnyInputTag(top, label, label+self._postfix, moduleLabelOnly=True, verbose=False)
    # Forget the labels so that a second call does not replace the tags again.
    self._moduleLabels = []
    return top
102 
def __appendToTopSequence(self, visitee):
    """Append visitee to the sequence on top of the stack, creating that sequence if needed.

    Empty cms.Sequences are not supported, so a sequence that has no
    content yet is kept on the stack as its plain label (a str). The first
    appended element turns that placeholder into a real cms.Sequence,
    which is stored on the process as label+postfix.

    Raises Exception if the process already has an attribute named
    label+postfix; in that case neither the stack nor the process is modified.
    """
    top = self._sequenceStack[-1]
    if not isinstance(top, str):
        self._sequenceStack[-1] += visitee
        return

    # The top entry is still a placeholder label: create the sequence now.
    newSequenceLabel = top + self._postfix
    # Check before touching any state, and name the label that clashes; a
    # freshly built sequence has not been attached to the process yet, so
    # it cannot be asked for its label here.
    if hasattr(self._process, newSequenceLabel):
        raise Exception("Cloning the sequence "+newSequenceLabel+" would overwrite existing object." )
    newSequence = cms.Sequence(visitee)
    setattr(self._process, newSequenceLabel, newSequence)
    self._sequenceStack[-1] = newSequence
    self._sequenceLabels.append(top)
114 
def listModules(sequence):
    """Return the list of all cms._Module instances met while visiting sequence."""
    collector = GatherAllModulesVisitor(gatheredInstance=cms._Module)
    sequence.visit(collector)
    found = collector.modules()
    return found
119 
def listSequences(sequence):
    """Return the list of all cms.Sequence instances met while visiting sequence."""
    collector = GatherAllModulesVisitor(gatheredInstance=cms.Sequence)
    sequence.visit(collector)
    found = collector.modules()
    return found
124 
def jetCollectionString(prefix='', algo='', type=''):
    """
    ------------------------------------------------------------------
    return the label of the jet collection module for the given input
    values, built as <head> + 'Jets' + algo + type, where <head> is
    'pat' for an empty prefix and prefix + 'Pat' otherwise. The
    default return value is 'patJets'.

    algo   : indicating the algorithm type of the jet [expected are
             'AK5', 'IC5', 'SC7', ...]
    type   : indicating the type of constituents of the jet [expected
             are 'Calo', 'PFlow', 'JPT', ...]
    prefix : prefix indicating the type of pat collection module
             (expected are '', 'selected', 'clean').
    ------------------------------------------------------------------
    """
    if prefix == '':
        head = 'pat'
    else:
        head = prefix + 'Pat'
    return head + 'Jets' + algo + type
148 
def contains(sequence, moduleName):
    """
    ------------------------------------------------------------------
    return True if 'moduleName' occurs in the string representation
    of 'sequence' and False otherwise. This is a plain substring
    search, so it also returns True when 'moduleName' is only a part
    of the name of a contained module.

    sequence   : sequence [e.g. process.patHeavyIonDefaultSequence]
    moduleName : module name as a string
    ------------------------------------------------------------------
    """
    text = sequence.__str__()
    return moduleName in text
162 
163 
164 
def cloneProcessingSnippet(process, sequence, postfix):
    """
    ------------------------------------------------------------------
    copy a sequence together with the modules and sequences therein;
    every copy is renamed by appending 'postfix' and the input tags
    are adjusted accordingly. With an empty postfix nothing is cloned
    and 'sequence' itself is returned.
    ------------------------------------------------------------------
    """
    if postfix == "":
        return sequence
    cloner = CloneSequenceVisitor(process, sequence.label(), postfix)
    sequence.visit(cloner)
    return cloner.clonedSequence()
179 
if __name__=="__main__":
    import unittest

    class TestModuleCommand(unittest.TestCase):
        """Self-tests for the helpers in this file; run when the file is executed as a script."""

        def setUp(self):
            """Nothing to do """
            pass

        def testCloning(self):
            """Cloning a sequence with a postfix adds postfixed copies and redirects their input tags."""
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("b", src=cms.InputTag("a"))
            p.c = cms.EDProducer("c", src=cms.InputTag("b","instance"))
            p.s = cms.Sequence(p.a*p.b*p.c *p.a)
            cloneProcessingSnippet(p, p.s, "New")
            self.assertEqual(p.dumpPython(),'import FWCore.ParameterSet.Config as cms\n\nprocess = cms.Process("test")\n\nprocess.a = cms.EDProducer("a",\n src = cms.InputTag("gen")\n)\n\n\nprocess.c = cms.EDProducer("c",\n src = cms.InputTag("b","instance")\n)\n\n\nprocess.cNew = cms.EDProducer("c",\n src = cms.InputTag("bNew","instance")\n)\n\n\nprocess.bNew = cms.EDProducer("b",\n src = cms.InputTag("aNew")\n)\n\n\nprocess.aNew = cms.EDProducer("a",\n src = cms.InputTag("gen")\n)\n\n\nprocess.b = cms.EDProducer("b",\n src = cms.InputTag("a")\n)\n\n\nprocess.s = cms.Sequence(process.a*process.b*process.c*process.a)\n\n\nprocess.sNew = cms.Sequence(process.aNew+process.bNew+process.cNew)\n\n\n')

        def testContains(self):
            """contains() finds a label in the sequence it is part of and not in another one."""
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s1 = cms.Sequence(p.a*p.b*p.c)
            p.s2 = cms.Sequence(p.b*p.c)
            # assertTrue replaces the deprecated assert_ alias.
            self.assertTrue( contains(p.s1, "a") )
            self.assertTrue( not contains(p.s2, "a") )

        def testJetCollectionString(self):
            """jetCollectionString() puts 'Jets' in front of the algorithm and type."""
            # Expected values follow jetCollectionString as defined in this
            # file: <prefix part> + 'Jets' + algo + type.
            self.assertEqual(jetCollectionString(algo = 'Foo', type = 'Bar'), 'patJetsFooBar')
            self.assertEqual(jetCollectionString(prefix = 'prefix', algo = 'Foo', type = 'Bar'), 'prefixPatJetsFooBar')

        def testListModules(self):
            """listModules() returns the modules of a sequence in order."""
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s = cms.Sequence(p.a*p.b*p.c)
            self.assertEqual([p.a,p.b,p.c], listModules(p.s))

    unittest.main()
def cloneProcessingSnippet(process, sequence, postfix)
def contains(sequence, moduleName)
def listModules(sequence)
def applyPostfix(process, label, postfix)
def massSearchReplaceAnyInputTag(sequence, oldInputTag, newInputTag, verbose=False, moduleLabelOnly=False, skipLabelTest=False)
Definition: MassReplace.py:72
def listSequences(sequence)
void find(edm::Handle< EcalRecHitCollection > &hits, DetId thisDet, std::vector< EcalRecHitCollection::const_iterator > &hit, bool debug=false)
Definition: FindCaloHit.cc:20
def __init__(self, gatheredInstance=cms._Module)
def jetCollectionString(prefix='', algo='', type='')
def removeIfInSequence(process, target, sequenceLabel, postfix="")
def remove(d, key, TELL=False)
Definition: MatrixUtil.py:211
def __init__(self, process, label, postfix)
def __labelsInSequence(process, sequenceLabel, postfix="")
def __appendToTopSequence(self, visitee)