41 header_files: list[str],
45 output_path: str |
None =
None,
46 template: str =
"$CMSSW_BASE/src/PhysicsTools/TensorFlowAOT/templates/wrapper.h.in",
50 for path
in header_files:
52 header_data[data.batch_size] = data
55 batch_sizes = sorted(data.batch_size
for data
in header_data.values())
59 "cmssw_version": os.environ[
"CMSSW_VERSION"],
60 "scram_arch": os.environ[
"SCRAM_ARCH"],
61 "model_path": model_path,
62 "batch_sizes": batch_sizes,
63 "subsystem": subsystem,
66 for key
in common_header_data:
67 values = set(getattr(d, key)
for d
in header_data.values())
69 raise ValueError(f
"found more than one possible {key} values: {', '.join(values)}")
70 variables[key] = values.pop()
73 def substituter(variables):
76 for key, value
in variables.items():
78 variables_[key] =
str(value)
79 if isinstance(value, str)
and not key.endswith(
"_UC"):
80 variables_[f
"{key}_UC"] = value.upper()
81 elif isinstance(value, (list, tuple))
and not key.endswith(
"_CSV"):
82 variables_[f
"{key}_CSV"] =
", ".
join(
map(str, value))
86 if key
not in variables_:
87 raise KeyError(f
"template contains unknown variable {key}")
88 return variables_[key]
90 return lambda line: re.sub(
r"\$\{([A-Z0-9_]+)\}", repl, line)
93 common_sub = substituter(variables)
95 batch_size : substituter({
97 **dict(
zip(HeaderData._fields, header_data[batch_size])),
99 for batch_size
in batch_sizes
103 template = os.path.expandvars(os.path.expanduser(
str(template)))
104 with open(template,
"r") as f: input_lines = [line.rstrip() for line
in f.readlines()]
109 line = input_lines.pop(0)
112 m = re.match(
r"^\/\/\s+foreach=([^\s]+)\s+lines=(\d+)$", line.strip())
115 n_lines =
int(m.group(2))
119 batch_lines, input_lines = input_lines[:n_lines], input_lines[n_lines:]
120 for batch_size
in batch_sizes:
121 for line
in batch_lines:
122 output_lines.append(model_subs[batch_size](line))
124 raise ValueError(f
"unknown loop target '{loop}'")
129 output_lines.append(common_sub(line))
133 output_path = f
"$CMSSW_BASE/src/{subsystem}/{package}/tfaot_dev/{variables['prefix']}.h" 134 output_path = os.path.expandvars(os.path.expanduser(
str(output_path)))
135 output_dir = os.path.dirname(output_path)
136 if not os.path.exists(output_dir):
137 os.makedirs(output_dir)
140 with open(output_path,
"w")
as f:
141 f.writelines(
"\n".
join(
map(str, output_lines)) +
"\n")
144
ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE constexpr float zip(ConstView const &tracks, int32_t i)
static std::string join(char **cmd)