CMS 3D CMS Logo

HiHelperTools.py
Go to the documentation of this file.
1 from __future__ import print_function
2 import FWCore.ParameterSet.Config as cms
3 
4 ## Helpers to perform some technically boring tasks, like looking for all modules with a given parameter
5 ## and replacing it with a given value
6 
7 # Next two lines are for backward compatibility, the imported functions and
8 # classes used to be defined in this file.
9 from FWCore.ParameterSet.MassReplace import massSearchReplaceAnyInputTag, MassSearchReplaceAnyInputTagVisitor
10 from FWCore.ParameterSet.MassReplace import massSearchReplaceParam, MassSearchParamVisitor, MassSearchReplaceParamVisitor
11 
def applyPostfix(process, label, postfix):
    """Return the postfixed clone of 'label' if there is one, else the plain object.

    The labels contained in patHeavyIonDefaultSequence+postfix are collected
    first; if the process has an attribute 'patPF2PATSequence', the labels of
    patPF2PATSequence+postfix are used instead. The lookup of the postfixed
    sequence fails if it has not been cloned with 'postfix' beforehand.

    Returns the attribute label+postfix when 'label' is among the collected
    labels and that attribute exists; otherwise prints a warning and returns
    the attribute 'label' if it exists; otherwise returns None.
    """
    knownLabels = __labelsInSequence(process, "patHeavyIonDefaultSequence", postfix)
    if hasattr(process, "patPF2PATSequence"):
        # the PF2PAT sequence, when present, takes precedence over the default one
        knownLabels = __labelsInSequence(process, "patPF2PATSequence", postfix)
    clonedName = label+postfix
    if label in knownLabels and hasattr(process, clonedName):
        return getattr(process, clonedName)
    if hasattr(process, label):
        print("WARNING: called applyPostfix for module/sequence %s which is not in patHeavyIonDefaultSequence%s!"%(label,postfix))
        return getattr(process, label)
    return None
25 
def removeIfInSequence(process, target, sequenceLabel, postfix=""):
    """Remove target+postfix from the sequence sequenceLabel+postfix, if it is in there.

    process       : the process holding the sequence and the module
    target        : label of the module/sequence to remove, WITHOUT postfix
    sequenceLabel : label of the sequence to remove from, WITHOUT postfix
    postfix       : postfix that was appended when the sequence was cloned

    Nothing happens if the target is not found in the sequence.
    """
    labels = __labelsInSequence(process, sequenceLabel, postfix)
    # __labelsInSequence returns the labels with the postfix already cut off,
    # so the bare target has to be compared (the former test on target+postfix
    # could never match for a non-empty postfix).
    # The hasattr guard is needed because the postfix is cut off blindly, so a
    # label that does not end in the postfix could match by accident.
    if target in labels and hasattr(process, target+postfix):
        # NOTE(review): relies on Sequence.remove being harmless for an object
        # that is not part of the sequence -- confirm.
        getattr(process, sequenceLabel+postfix).remove(
            getattr(process, target+postfix)
            )
32 
def __labelsInSequence(process, sequenceLabel, postfix=""):
    """Return the labels of all modules and sequences inside sequenceLabel+postfix.

    The labels of the modules come first, followed by the labels of the
    sequences. For a non-empty postfix the last len(postfix) characters are
    cut off from every label, so the returned labels are the un-postfixed ones.

    Fails with an AttributeError if the process has no attribute
    sequenceLabel+postfix.
    """
    sequence = getattr(process, sequenceLabel+postfix)
    # traverse the sequence only once, whatever the postfix is
    labels = [ m.label() for m in listModules(sequence) ]
    labels.extend( s.label() for s in listSequences(sequence) )
    if postfix == "":
        # label[:-0] would yield an empty string, so nothing must be cut here
        return labels
    # NOTE(review): the cut is unconditional; a label that does not end in
    # 'postfix' is truncated as well (unchanged behaviour) -- confirm intended.
    cut = len(postfix)
    return [ label[:-cut] for label in labels ]
40 
#FIXME name is not generic enough now
class GatherAllModulesVisitor(object):
    """Visitor that travels within a cms.Sequence, and returns a list of the visited
    objects that are instances of gatheredInstance (e.g. modules)"""
    def __init__(self, gatheredInstance=cms._Module):
        # objects collected so far, in the order in which they were entered
        self._modules = []
        # type (or tuple of types) handed to isinstance() in enter()
        self._gatheredInstance = gatheredInstance
    def enter(self,visitee):
        """Record the visitee if it is an instance of the requested type."""
        if isinstance(visitee,self._gatheredInstance):
            self._modules.append(visitee)
    def leave(self,visitee):
        """Nothing to do when leaving a node."""
        pass
    def modules(self):
        """Return the list of collected objects (the internal list, not a copy)."""
        return self._modules
54 
class CloneSequenceVisitor(object):
    """Visitor that travels within a cms.Sequence, and returns a cloned version of the Sequence.
    All modules and sequences are cloned and a postfix is added"""
    def __init__(self, process, label, postfix):
        """
        process : process to which the clones get attached
        label   : label of the sequence that is going to be visited
        postfix : string appended to the label of every clone
        """
        self._process = process
        self._postfix = postfix
        # Holds cms.Sequence objects, or plain labels as placeholders for
        # sequences that could not be created yet because they are still empty.
        self._sequenceStack = [label]
        self._moduleLabels = []   # labels of the modules cloned so far
        self._sequenceLabels = [] # labels of the sequences cloned so far
        self._waitForSequenceToClose = None # modules will only be cloned or added if this is None

    def enter(self,visitee):
        """Clone the visited module, or open a new level for the visited sequence."""
        if self._waitForSequenceToClose is not None:
            return #we are in an already cloned sequence
        if isinstance(visitee,cms._Module):
            label = visitee.label()
            if label in self._moduleLabels:
                # cloned before: reuse the existing clone
                newModule = getattr(self._process, label+self._postfix)
            else:
                self._moduleLabels.append(label)
                newModule = visitee.clone()
                setattr(self._process, label+self._postfix, newModule)
            # NOTE(review): the indentation of this call was lost in the listing.
            # It is placed outside the else branch because the lookup of the
            # existing clone above would be dead code otherwise; the bundled
            # testCloning expects a repeated module only once -- confirm.
            self.__appendToTopSequence(newModule)

        if isinstance(visitee,cms.Sequence):
            if visitee.label() in self._sequenceLabels: # is the sequence already cloned?
                self._waitForSequenceToClose = visitee.label()
                self._sequenceStack.append( getattr(self._process, visitee.label()+self._postfix) )
            else:
                self._sequenceStack.append(visitee.label())#save desired label as placeholder until we have a module to create the sequence

    def leave(self,visitee):
        """Close the level of the visited sequence and add its clone to the parent."""
        if not isinstance(visitee,cms.Sequence):
            return
        if self._waitForSequenceToClose is not None:
            if self._waitForSequenceToClose != visitee.label():
                # A sequence nested inside an already cloned sequence: enter()
                # has not pushed anything for it, so nothing may be popped.
                return
            self._waitForSequenceToClose = None
        if not isinstance(self._sequenceStack[-1], cms.Sequence):
            raise Exception("empty Sequence encountered during cloneing. sequnece stack: %s"%self._sequenceStack)
        self.__appendToTopSequence( self._sequenceStack.pop() )

    def clonedSequence(self):
        """Return the cloned sequence, with the InputTags adjusted to the cloned modules.

        Raises an Exception unless exactly one entry is left on the sequence stack."""
        if not len(self._sequenceStack) == 1:
            raise Exception("someting went wrong, the sequence stack looks like: %s"%self._sequenceStack)
        for label in self._moduleLabels:
            massSearchReplaceAnyInputTag(self._sequenceStack[-1], label, label+self._postfix, moduleLabelOnly=True, verbose=False)
        self._moduleLabels = [] #prevent the InputTag replacement next time this is called.
        return self._sequenceStack[-1]

    def __appendToTopSequence(self, visitee):#this is darn ugly because empty cms.Sequences are not supported
        """Add visitee to the sequence on top of the stack, creating that sequence if needed.

        Raises an Exception if creating the sequence would overwrite an
        existing attribute of the process."""
        top = self._sequenceStack[-1]
        if isinstance(top, str):#we have the name of an empty sequence on the stack. create it!
            newSequenceLabel = top + self._postfix
            if hasattr(self._process, newSequenceLabel):
                # Report the label that would be overwritten. The message used
                # to ask the new, not yet attached sequence for its label;
                # presumably that sequence has no label at this point.
                raise Exception("Cloning the sequence "+newSequenceLabel+" would overwrite existing object." )
            newSequence = cms.Sequence(visitee)
            setattr(self._process, newSequenceLabel, newSequence)
            self._sequenceStack[-1] = newSequence
            self._sequenceLabels.append(top)
        else:
            self._sequenceStack[-1] += visitee
115 
def listModules(sequence):
    """Return the list of all cms._Module instances encountered when visiting 'sequence'."""
    collector = GatherAllModulesVisitor(gatheredInstance=cms._Module)
    sequence.visit(collector)
    return collector.modules()
120 
def listSequences(sequence):
    """Return the list of all cms.Sequence instances encountered when visiting 'sequence'."""
    collector = GatherAllModulesVisitor(gatheredInstance=cms.Sequence)
    sequence.visit(collector)
    return collector.modules()
125 
def jetCollectionString(prefix='', algo='', type=''):
    """
    ------------------------------------------------------------------
    return the string of the jet collection module depending on the
    input values. The result is built as

        'pat'          + 'Jets' + algo + type    (empty prefix)
        prefix + 'Pat' + 'Jets' + algo + type    (non-empty prefix)

    so the default return value is 'patJets', and e.g.
    jetCollectionString('selected', 'AK5', 'Calo') returns
    'selectedPatJetsAK5Calo'.

    algo   : indicating the algorithm type of the jet [expected are
             'AK5', 'IC5', 'SC7', ...]
    type   : indicating the type of constituents of the jet [expec-
             ted are 'Calo', 'PFlow', 'JPT', ...]
    prefix : prefix indicating the type of pat collection module (ex-
             pected are '', 'selected', 'clean').
    ------------------------------------------------------------------
    """
    # 'type' hides the builtin of the same name inside this function, but it
    # is part of the keyword interface used by callers and has to stay.
    if prefix == '':
        stem = 'pat'
    else:
        stem = prefix + 'Pat'
    return stem + 'Jets' + algo + type
149 
def contains(sequence, moduleName):
    """
    ------------------------------------------------------------------
    return True if 'moduleName' occurs in the string representation
    of 'sequence' and False otherwise. This is a plain substring
    search, so it also returns True for any substring of the name of
    a contained module.

    sequence   : sequence [e.g. process.patHeavyIonDefaultSequence]
    moduleName : module name as a string
    ------------------------------------------------------------------
    """
    return moduleName in str(sequence)
163 
164 
165 
def cloneProcessingSnippet(process, sequence, postfix):
    """
    ------------------------------------------------------------------
    clone a sequence together with the modules and sequences in it;
    every clone is attached to the process under its original label
    plus the postfix, and the input tags are adjusted accordingly.

    For an empty postfix nothing is cloned and the given sequence
    itself is returned.
    ------------------------------------------------------------------
    """
    if postfix == "":
        return sequence
    cloner = CloneSequenceVisitor(process,sequence.label(),postfix)
    sequence.visit(cloner)
    return cloner.clonedSequence()
180 
if __name__=="__main__":
    # Self-test of the helpers; runs only when this file is executed directly.
    import unittest
    class TestModuleCommand(unittest.TestCase):
        def setUp(self):
            """Nothing to do """
            pass
        def testCloning(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("b", src=cms.InputTag("a"))
            p.c = cms.EDProducer("c", src=cms.InputTag("b","instance"))
            p.s = cms.Sequence(p.a*p.b*p.c *p.a)
            cloneProcessingSnippet(p, p.s, "New")
            self.assertEqual(p.dumpPython(),'import FWCore.ParameterSet.Config as cms\n\nprocess = cms.Process("test")\n\nprocess.a = cms.EDProducer("a",\n src = cms.InputTag("gen")\n)\n\n\nprocess.c = cms.EDProducer("c",\n src = cms.InputTag("b","instance")\n)\n\n\nprocess.cNew = cms.EDProducer("c",\n src = cms.InputTag("bNew","instance")\n)\n\n\nprocess.bNew = cms.EDProducer("b",\n src = cms.InputTag("aNew")\n)\n\n\nprocess.aNew = cms.EDProducer("a",\n src = cms.InputTag("gen")\n)\n\n\nprocess.b = cms.EDProducer("b",\n src = cms.InputTag("a")\n)\n\n\nprocess.s = cms.Sequence(process.a*process.b*process.c*process.a)\n\n\nprocess.sNew = cms.Sequence(process.aNew+process.bNew+process.cNew)\n\n\n')
        def testContains(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s1 = cms.Sequence(p.a*p.b*p.c)
            p.s2 = cms.Sequence(p.b*p.c)
            # assertTrue replaces the deprecated assert_ alias
            self.assertTrue( contains(p.s1, "a") )
            self.assertTrue( not contains(p.s2, "a") )
        def testJetCollectionString(self):
            # jetCollectionString puts 'Jets' in front of algo and type
            self.assertEqual(jetCollectionString(algo = 'Foo', type = 'Bar'), 'patJetsFooBar')
            self.assertEqual(jetCollectionString(prefix = 'prefix', algo = 'Foo', type = 'Bar'), 'prefixPatJetsFooBar')
        def testListModules(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s = cms.Sequence(p.a*p.b*p.c)
            self.assertEqual([p.a,p.b,p.c], listModules(p.s))

    unittest.main()
def cloneProcessingSnippet(process, sequence, postfix)
def contains(sequence, moduleName)
def listModules(sequence)
def applyPostfix(process, label, postfix)
def massSearchReplaceAnyInputTag(sequence, oldInputTag, newInputTag, verbose=False, moduleLabelOnly=False, skipLabelTest=False)
Definition: MassReplace.py:73
def listSequences(sequence)
S & print(S &os, JobReport::InputFile const &f)
Definition: JobReport.cc:65
void find(edm::Handle< EcalRecHitCollection > &hits, DetId thisDet, std::vector< EcalRecHitCollection::const_iterator > &hit, bool debug=false)
Definition: FindCaloHit.cc:20
def __init__(self, gatheredInstance=cms._Module)
def jetCollectionString(prefix='', algo='', type='')
def removeIfInSequence(process, target, sequenceLabel, postfix="")
def remove(d, key, TELL=False)
Definition: MatrixUtil.py:212
def __init__(self, process, label, postfix)
def __labelsInSequence(process, sequenceLabel, postfix="")
def __appendToTopSequence(self, visitee)