35 model_name: str |
None =
None,
36 model_version: str =
"1.0.0",
37 batch_sizes: tuple[int] = (1,),
38 output_path: str |
None =
None,
43 cmssw_base = os.getenv(
"CMSSW_BASE")
44 if not cmssw_base
or not os.path.isdir(cmssw_base):
45 raise Exception(
"CMSSW_BASE is not set or points to a non-existing directory")
47 output_path = os.path.join(
"$CMSSW_BASE",
"src", subsystem, package,
"tfaot_dev")
48 output_path = os.path.expandvars(os.path.expanduser(output_path))
51 model_path = os.path.expandvars(os.path.expanduser(model_path))
52 model_path = os.path.normpath(os.path.abspath(model_path))
53 if not os.path.exists(model_path):
54 raise Exception(f
"model_path '{model_path}' does not exist")
58 model_name = os.path.splitext(os.path.basename(model_path))[0]
61 lib_dir = os.path.join(output_path,
"lib")
62 if not os.path.exists(lib_dir):
64 inc_dir = os.path.join(output_path,
"include")
65 if not os.path.exists(inc_dir):
69 from cmsml.scripts.compile_tf_graph
import compile_tf_graph
70 with tempfile.TemporaryDirectory()
as tmp_dir:
72 model_path=model_path,
74 batch_sizes=batch_sizes,
75 compile_prefix=f
"{model_name}_bs{{}}",
76 compile_class=f
"{subsystem}_{package}::{model_name}_bs{{}}",
81 for bs
in batch_sizes:
82 header_name = f
"{model_name}_bs{bs}.h" 83 shutil.copy2(os.path.join(tmp_dir,
"aot", header_name), inc_dir)
84 shutil.copy2(os.path.join(tmp_dir,
"aot", f
"{model_name}_bs{bs}.o"), lib_dir)
85 header_files.append(os.path.join(inc_dir, header_name))
88 from create_wrapper
import create_wrapper
90 header_files=header_files,
91 model_path=model_path,
94 output_path=os.path.join(inc_dir, f
"{model_name}.h"),
99 "subsystem": subsystem,
100 "subsystem_uc": subsystem.upper(),
102 "package_uc": package.upper(),
103 "model_name": model_name,
104 "model_name_uc": model_name.upper(),
105 "model_version": model_version,
106 "lib_dir_name": os.path.basename(lib_dir),
107 "inc_dir_name": os.path.basename(inc_dir),
108 "tool_name": tool_name_template.format(
109 subsystem=subsystem.lower(),
110 package=package.lower(),
111 model_name=model_name.lower(),
113 "ld_flags":
"\n ".
join([
114 ld_flag_template.format(model_name=model_name, bs=bs)
115 for bs
in batch_sizes
118 tool_path = os.path.join(output_path, f
"{tool_vars['tool_name']}.xml")
119 with open(tool_path,
"w")
as f:
120 f.write(tool_file_template.format(**tool_vars))
123 tool_path_repr = os.path.relpath(tool_path)
124 if tool_path_repr.startswith(
".."):
125 tool_path_repr = tool_path
126 inc_path = f
"{output_path}/include/{model_name}.h" 127 if "CMSSW_BASE" in os.environ
and os.path.exists(os.environ[
"CMSSW_BASE"]):
128 inc_path_rel = os.path.relpath(inc_path, os.path.join(os.environ[
"CMSSW_BASE"],
"src"))
129 if not inc_path_rel.startswith(
".."):
130 inc_path = inc_path_rel
132 print(
"\n" + 80 *
"-" +
"\n")
133 print(f
"created custom tool file for AOT compiled model '{model_name}'")
134 print(
"to register it to scram, run")
135 print(f
"\n> scram setup {tool_path_repr}\n")
136 print(
"and use the following to include it in your code")
137 print(f
"\n#include \"{inc_path}\"\n")
void print(TMatrixD &m, const char *label=nullptr, bool mathematicaFormat=false)
static std::string join(char **cmd)