CMS 3D CMS Logo

Functions | Variables
create_wrapper Namespace Reference

Functions

def create_wrapper
 
def main ()
 
def parse_header
 

Variables

 common_header_data
 
 HeaderData
 

Function Documentation

◆ create_wrapper()

def create_wrapper.create_wrapper (   header_files)

Definition at line 41 of file create_wrapper.py.

References createfilelist.int, join(), genParticles_cff.map, parse_header(), str, and reco.zip().

def create_wrapper(
    header_files: list[str],
    model_path: str,
    subsystem: str,
    package: str,
    output_path: str | None = None,
    template: str = "$CMSSW_BASE/src/PhysicsTools/TensorFlowAOT/templates/wrapper.h.in",
) -> None:
    """
    Render the wrapper header *template* for the models described by *header_files*.

    Each header is parsed with ``parse_header`` and stored under its batch size. Template
    placeholders of the form ``${NAME}`` are replaced by upper-cased variable names; string values
    additionally provide a ``${NAME_UC}`` (upper-case) variant and list / tuple values a
    ``${NAME_CSV}`` (comma-separated) variant. A template line of the form
    ``// foreach=MODEL lines=N`` causes the following N lines to be repeated once per batch size
    (in ascending order), with the fields of the corresponding header data available as
    additional variables.

    Args:
        header_files: Paths of the AOT header files to wrap.
        model_path: Stored as the ``model_path`` template variable.
        subsystem: Stored as the ``subsystem`` template variable and used in the default output
            path.
        package: Stored as the ``package`` template variable and used in the default output path.
        output_path: Destination of the rendered file. When empty, defaults to
            ``$CMSSW_BASE/src/<subsystem>/<package>/tfaot_dev/<prefix>.h``.
        template: Path of the template file. Environment variables and ``~`` are expanded.

    Raises:
        ValueError: If no header was given, if the headers disagree on one of the
            ``common_header_data`` fields, or if the template uses an unknown loop target.
        KeyError: If the template references an unknown variable, or if a required environment
            variable (``CMSSW_VERSION``, ``SCRAM_ARCH``) is not set.
    """
    # read header data, keyed by batch size
    header_data = {}
    for path in header_files:
        data = parse_header(path)
        header_data[data.batch_size] = data
    if not header_data:
        # fail early with a clear message instead of a KeyError from set.pop() below
        raise ValueError("at least one header file is required")

    # sorted batch sizes
    batch_sizes = sorted(header_data)

    # set common variables
    variables = {
        "cmssw_version": os.environ["CMSSW_VERSION"],
        "scram_arch": os.environ["SCRAM_ARCH"],
        "model_path": model_path,
        "batch_sizes": batch_sizes,
        "subsystem": subsystem,
        "package": package,
    }
    for key in common_header_data:
        values = set(getattr(d, key) for d in header_data.values())
        if len(values) > 1:
            # values are not necessarily strings, so convert them before joining
            found = ", ".join(sorted(map(str, values)))
            raise ValueError(f"found more than one possible {key} values: {found}")
        variables[key] = values.pop()

    # helper for variable replacement
    def substituter(variables):
        # insert upper-case variants of strings, csv variants of lists
        variables_ = {}
        for key, value in variables.items():
            key = key.upper()
            variables_[key] = str(value)
            if isinstance(value, str) and not key.endswith("_UC"):
                variables_[f"{key}_UC"] = value.upper()
            elif isinstance(value, (list, tuple)) and not key.endswith("_CSV"):
                variables_[f"{key}_CSV"] = ", ".join(map(str, value))

        def repl(m):
            key = m.group(1)
            if key not in variables_:
                raise KeyError(f"template contains unknown variable {key}")
            return variables_[key]

        return lambda line: re.sub(r"\$\{([A-Z0-9_]+)\}", repl, line)

    # substituter for common variables and per-model variables
    common_sub = substituter(variables)
    model_subs = {
        batch_size: substituter({
            **variables,
            **dict(zip(HeaderData._fields, header_data[batch_size])),
        })
        for batch_size in batch_sizes
    }

    # read template lines
    template = os.path.expandvars(os.path.expanduser(str(template)))
    with open(template, "r") as f:
        input_lines = [line.rstrip() for line in f]

    # go through lines and define new ones, using a cursor rather than repeatedly popping the
    # front of the list
    output_lines = []
    pos = 0
    while pos < len(input_lines):
        line = input_lines[pos]
        pos += 1

        # loop statement?
        m = re.match(r"^\/\/\s+foreach=([^\s]+)\s+lines=(\d+)$", line.strip())
        if not m:
            # just make common substitutions
            output_lines.append(common_sub(line))
            continue

        loop = m.group(1)
        n_lines = int(m.group(2))
        if loop != "MODEL":
            raise ValueError(f"unknown loop target '{loop}'")

        # repeat the next n lines for each batch size and replace model variables
        batch_lines = input_lines[pos:pos + n_lines]
        pos += n_lines
        for batch_size in batch_sizes:
            model_sub = model_subs[batch_size]
            output_lines.extend(model_sub(batch_line) for batch_line in batch_lines)

    # prepare the output
    if not output_path:
        # NOTE: relies on "prefix" being one of the common_header_data fields
        output_path = f"$CMSSW_BASE/src/{subsystem}/{package}/tfaot_dev/{variables['prefix']}.h"
    output_path = os.path.expandvars(os.path.expanduser(str(output_path)))
    output_dir = os.path.dirname(output_path)
    if output_dir:
        # an output path without directory part refers to the current directory, nothing to create
        os.makedirs(output_dir, exist_ok=True)

    # write lines
    with open(output_path, "w") as f:
        f.write("\n".join(map(str, output_lines)) + "\n")
143 
144 
ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE constexpr float zip(ConstView const &tracks, int32_t i)
Definition: TracksSoA.h:90
static std::string join(char **cmd)
Definition: RemoteFile.cc:19
#define str(s)

◆ main()

def create_wrapper.main (   None)

Definition at line 234 of file create_wrapper.py.

def main() -> None:
    """
    Command-line entry point: parse the arguments and forward them to ``create_wrapper``.
    """
    from argparse import ArgumentParser

    parser = ArgumentParser(
        # the module docstring can be None (e.g. when running with -OO)
        description=(__doc__ or "").strip(),
    )
    parser.add_argument(
        "--subsystem",
        "-s",
        required=True,
        help="the CMSSW subsystem that the plugin belongs to",
    )
    parser.add_argument(
        "--package",
        "-p",
        required=True,
        help="the CMSSW package that the plugin belongs to",
    )
    parser.add_argument(
        "--model-path",
        "-m",
        required=True,
        help="path of the initial model file for provenance purposes",
    )
    parser.add_argument(
        "--header-files",
        "-f",
        required=True,
        # nargs="+" collects one or more separate command-line values
        nargs="+",
        help="space-separated list of AOT header files that define the models to wrap",
    )
    parser.add_argument(
        "--output-path",
        "-o",
        help="path where the created header file should be saved; default: "
        "$CMSSW_BASE/src/SUBSYSTEM/PACKAGE/tfaot_dev/PREFIX.h",
    )
    args = parser.parse_args()

    create_wrapper(
        header_files=args.header_files,
        model_path=args.model_path,
        subsystem=args.subsystem,
        package=args.package,
        output_path=args.output_path,
    )
280 
281 

◆ parse_header()

def create_wrapper.parse_header (   path)

Definition at line 145 of file create_wrapper.py.

References cmsswConfigtrace.flatten(), HeaderData, createfilelist.int, FastTimerService_cff.range, str, and mkLumiAveragedPlots.tuple.

Referenced by create_wrapper().

def parse_header(path: str) -> HeaderData:
    """
    Parse the AOT header file at *path* and return the extracted information as ``HeaderData``.

    The following is read from the non-empty, stripped lines of the file:

    - the namespace, from a ``namespace <name> {`` line,
    - the class name and batch size, from a ``class <name>_bs<N> final : public
      tensorflow::XlaCompiledCpuFunction ...`` line,
    - argument and result counts, from ``int arg<i>_count()`` / ``int result<i>_count()`` lines
      whose following line is expected to be ``return <count>;``.

    The prefix is the file's base name without the trailing ``_bs<N>.h``. Counts are stored both
    as parsed and divided by the batch size.

    Args:
        path: Path of the header file. Environment variables and ``~`` are expanded.

    Raises:
        Exception: If a count declaration is not followed by a ``return <count>;`` line.
        ValueError: If no class declaration with a batch size was found, if the file name does
            not end with ``_bs<N>.h``, if count indices are not contiguous starting at 0, or if
            a count is not divisible by the batch size.
    """
    # read all non-empty lines
    path = os.path.expandvars(os.path.expanduser(str(path)))
    with open(path, "r") as f:
        lines = [line for line in (raw.strip() for raw in f) if line]

    # prepare HeaderData with all fields unset
    data = HeaderData(*([None] * len(HeaderData._fields)))

    # extract data; a shared iterator lets the count branch consume the following line
    arg_counts = {}
    res_counts = {}
    line_iter = iter(lines)
    for line in line_iter:
        # read the namespace
        m = re.match(r"^namespace\s+([^\s]+)\s*\{$", line)
        if m:
            data = data._replace(namespace=m.group(1))
            continue

        # read the class name and batch size
        m = re.match(r"^class\s+([^\s]+)_bs(\d+)\s+final\s+\:\s+public\stensorflow\:\:XlaCompiledCpuFunction\s+.*$", line)  # noqa
        if m:
            data = data._replace(class_name=m.group(1), batch_size=int(m.group(2)))
            continue

        # read argument and result counts
        m = re.match(r"^int\s+(arg|result)(\d+)_count\(\).+$", line)
        if m:
            # get kind and index
            kind = m.group(1)
            index = int(m.group(2))

            # parse the next line; a missing line (truncated file) fails the match as well
            m = re.match(r"^return\s+(\d+)\s*\;.*$", next(line_iter, ""))
            if not m:
                raise Exception(f"corrupted header file {path}")
            count = int(m.group(1))

            # store the count
            (arg_counts if kind == "arg" else res_counts)[index] = count

    # without a class declaration the batch size is unknown and nothing below can be derived
    if data.batch_size is None:
        raise ValueError(f"no class declaration with batch size found in header '{path}'")

    # helper to flatten counts to lists
    def flatten(counts: dict[int, int], name: str) -> list[int]:
        if set(counts) != set(range(len(counts))):
            raise ValueError(
                f"non-contiguous indices in {name} counts: {', '.join(map(str, counts))}",
            )
        return [counts[index] for index in sorted(counts)]

    # helper to enforce integer division by batch size
    def no_batch(count: int, index: int, name: str) -> int:
        if count % data.batch_size != 0:
            raise ValueError(
                f"{name} count of {count} at index {index} is not dividable by batch size "
                f"{data.batch_size}",
            )
        return count // data.batch_size

    # store the prefix
    base = os.path.basename(path)
    postfix = f"_bs{data.batch_size}.h"
    if not base.endswith(postfix):
        raise ValueError(f"header '{path}' does not end with expected postfix '{postfix}'")
    data = data._replace(prefix=base[:-len(postfix)])

    # set counts
    flat_args = flatten(arg_counts, "argument")
    flat_res = flatten(res_counts, "result")
    data = data._replace(
        n_args=len(arg_counts),
        n_res=len(res_counts),
        arg_counts=flat_args,
        res_counts=flat_res,
        arg_counts_no_batch=tuple(
            no_batch(c, i, "argument")
            for i, c in enumerate(flat_args)
        ),
        res_counts_no_batch=tuple(
            no_batch(c, i, "result")
            for i, c in enumerate(flat_res)
        ),
    )

    return data
232 
233 
#define str(s)

Variable Documentation

◆ common_header_data

create_wrapper.common_header_data

Definition at line 29 of file create_wrapper.py.

◆ HeaderData

create_wrapper.HeaderData

Definition at line 16 of file create_wrapper.py.

Referenced by parse_header().