TritonService.cc

//note: the original include list is partially elided in this listing;
//the following headers are reconstructed from the symbols used below
#include "HeterogeneousCore/SonicTriton/interface/TritonService.h"
#include "HeterogeneousCore/SonicTriton/interface/triton_utils.h"

#include "DataFormats/Provenance/interface/ModuleDescription.h"
#include "FWCore/MessageLogger/interface/MessageLogger.h"
#include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
#include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
#include "FWCore/ServiceRegistry/interface/ActivityRegistry.h"
#include "FWCore/ServiceRegistry/interface/ProcessContext.h"
#include "FWCore/Utilities/interface/Exception.h"
#include "FWCore/Utilities/interface/GlobalIdentifier.h"

#include "grpc_client.h"
#include "grpc_service.pb.h"

#include <array>
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <filesystem>
#include <utility>
#include <tuple>
#include <unistd.h>

namespace tc = triton::client;
//fallback server identifiers (static members declared in TritonService.h;
//the values below are reconstructed, as these lines are elided in this listing)
const std::string TritonService::Server::fallbackName{"fallback"};
const std::string TritonService::Server::fallbackAddress{"0.0.0.0"};

namespace {
  std::pair<std::string, int> execSys(const std::string& cmd) {
    //redirect stderr to stdout
    auto pipe = popen((cmd + " 2>&1").c_str(), "r");
    int thisErrno = errno;
    if (!pipe)
      throw cms::Exception("SystemError") << "popen() failed with errno " << thisErrno << " for command: " << cmd;

    //extract output
    constexpr static unsigned buffSize = 128;
    std::array<char, buffSize> buffer;
    std::string result;
    while (!feof(pipe)) {
      if (fgets(buffer.data(), buffSize, pipe))
        result += buffer.data();
      else {
        thisErrno = ferror(pipe);
        if (thisErrno)
          throw cms::Exception("SystemError") << "failed reading command output with errno " << thisErrno;
      }
    }

    int rv = pclose(pipe);
    return std::make_pair(result, rv);
  }
}  // namespace

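//The constructor registers the ActivityRegistry callbacks used below, adds the local fallback
//server to the server list if it is enabled, and then contacts each configured remote server
//over gRPC to record which models each server provides.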
TritonService::TritonService(const edm::ParameterSet& pset, edm::ActivityRegistry& areg)
    : verbose_(pset.getUntrackedParameter<bool>("verbose")),
      fallbackOpts_(pset.getParameterSet("fallback")),
      currentModuleId_(0),
      allowAddModel_(false),
      startedFallback_(false),
      pid_(std::to_string(::getpid())) {
  //module construction is assumed to be serial (correct at the time this code was written)
  areg.watchPreModuleConstruction(this, &TritonService::preModuleConstruction);
  areg.watchPostModuleConstruction(this, &TritonService::postModuleConstruction);
  areg.watchPreModuleDestruction(this, &TritonService::preModuleDestruction);
  //fallback server will be launched (if needed) before beginJob
  areg.watchPreBeginJob(this, &TritonService::preBeginJob);

  //include fallback server in set if enabled
  if (fallbackOpts_.enable) {
    auto serverType = TritonServerType::Remote;
    if (!fallbackOpts_.useGPU)
      serverType = TritonServerType::LocalCPU;
#ifdef TRITON_ENABLE_GPU
    else
      serverType = TritonServerType::LocalGPU;
#endif

    servers_.emplace(std::piecewise_construct,
                     std::forward_as_tuple(Server::fallbackName),
                     std::forward_as_tuple(Server::fallbackName, Server::fallbackAddress, serverType));
  }

  //loop over input servers: check which models they have
  std::string msg;
  if (verbose_)
    msg = "List of models for each server:\n";
  for (const auto& serverPset : pset.getUntrackedParameterSetVector("servers")) {
    const std::string& serverName(serverPset.getUntrackedParameter<std::string>("name"));
    //ensure uniqueness
    auto [sit, unique] = servers_.emplace(serverName, serverPset);
    if (!unique)
      throw cms::Exception("DuplicateServer")
          << "Not allowed to specify more than one server with same name (" << serverName << ")";
    auto& server(sit->second);

    std::unique_ptr<tc::InferenceServerGrpcClient> client;
    triton_utils::throwIfError(
        tc::InferenceServerGrpcClient::Create(&client, server.url, false, server.useSsl, server.sslOptions),
        "TritonService(): unable to create inference context for " + serverName + " (" + server.url + ")");

    if (verbose_) {
      inference::ServerMetadataResponse serverMetaResponse;
      triton_utils::throwIfError(client->ServerMetadata(&serverMetaResponse),
                                 "TritonService(): unable to get metadata for " + serverName + " (" + server.url + ")");
      edm::LogInfo("TritonService") << "Server " << serverName << ": url = " << server.url
                                    << ", version = " << serverMetaResponse.version();
    }

    inference::RepositoryIndexResponse repoIndexResponse;
    triton_utils::throwIfError(
        client->ModelRepositoryIndex(&repoIndexResponse),
        "TritonService(): unable to get repository index for " + serverName + " (" + server.url + ")");

    //servers keep track of models and vice versa
    if (verbose_)
      msg += serverName + ": ";
    for (const auto& modelIndex : repoIndexResponse.models()) {
      const auto& modelName = modelIndex.name();
      auto mit = models_.find(modelName);
      if (mit == models_.end())
        mit = models_.emplace(modelName, "").first;
      auto& modelInfo(mit->second);
      modelInfo.servers.insert(serverName);
      server.models.insert(modelName);
      if (verbose_)
        msg += modelName + ", ";
    }
    if (verbose_)
      msg += "\n";
  }
  if (verbose_)
    edm::LogInfo("TritonService") << msg;
}

void TritonService::preModuleConstruction(edm::ModuleDescription const& desc) {
  currentModuleId_ = desc.id();
  allowAddModel_ = true;
}

void TritonService::addModel(const std::string& modelName, const std::string& path) {
  //should only be called in module constructors
  if (!allowAddModel_)
    throw cms::Exception("DisallowedAddModel") << "Attempt to call addModel() outside of module constructors";
  //if model is not in the list, then no specified server provides it
  auto mit = models_.find(modelName);
  if (mit == models_.end()) {
    auto& modelInfo(unservedModels_.emplace(modelName, path).first->second);
    modelInfo.modules.insert(currentModuleId_);
    //only keep track of modules that need unserved models
    modules_.emplace(currentModuleId_, modelName);
  }
}

void TritonService::postModuleConstruction(edm::ModuleDescription const& desc) { allowAddModel_ = false; }

void TritonService::preModuleDestruction(edm::ModuleDescription const& desc) {
  //remove destructed modules from unserved list
  if (unservedModels_.empty())
    return;
  auto id = desc.id();
  auto oit = modules_.find(id);
  if (oit != modules_.end()) {
    const auto& moduleInfo(oit->second);
    auto mit = unservedModels_.find(moduleInfo.model);
    if (mit != unservedModels_.end()) {
      auto& modelInfo(mit->second);
      modelInfo.modules.erase(id);
      //remove a model if it is no longer needed by any modules
      if (modelInfo.modules.empty())
        unservedModels_.erase(mit);
    }
    modules_.erase(oit);
  }
}

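//serverInfo() is how clients pick a server for a given model: the preferred server is used if it
//actually serves the model; otherwise the first server known to provide the model is chosen.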
//second return value is only true if fallback CPU server is being used
TritonService::Server TritonService::serverInfo(const std::string& model, const std::string& preferred) const {
  auto mit = models_.find(model);
  if (mit == models_.end())
    throw cms::Exception("MissingModel") << "There are no servers that provide model " << model;
  const auto& modelInfo(mit->second);
  const auto& modelServers = modelInfo.servers;

  auto msit = modelServers.end();
  if (!preferred.empty()) {
    msit = modelServers.find(preferred);
    //todo: add a "strict" parameter to stop execution if preferred server isn't found?
    if (msit == modelServers.end())
      edm::LogWarning("PreferredServer") << "Preferred server " << preferred << " for model " << model
                                         << " not available, will choose another server";
  }
  const auto& serverName(msit == modelServers.end() ? *modelServers.begin() : preferred);

  //todo: use some algorithm to select server rather than just picking arbitrarily
  const auto& server(servers_.find(serverName)->second);
  return server;
}

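//preBeginJob() launches the local fallback server (via the cmsTriton script) for any models that
//no configured server provides: it builds the command line from the fallback options, runs it with
//execSys(), and parses the chosen gRPC port out of the script output.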
void TritonService::preBeginJob(edm::PathsAndConsumesOfModulesBase const&, edm::ProcessContext const&) {
  //only need fallback if there are unserved models
  if (!fallbackOpts_.enable or unservedModels_.empty())
    return;

  std::string msg;
  if (verbose_)
    msg = "List of models for fallback server: ";
  //all unserved models are provided by fallback server
  auto& server(servers_.find(Server::fallbackName)->second);
  for (const auto& [modelName, model] : unservedModels_) {
    auto& modelInfo(models_.emplace(modelName, model).first->second);
    modelInfo.servers.insert(Server::fallbackName);
    server.models.insert(modelName);
    if (verbose_)
      msg += modelName + ", ";
  }
  if (verbose_)
    edm::LogInfo("TritonService") << msg;

  //assemble server start command
  std::string command("cmsTriton -P -1 -p " + pid_);
  if (fallbackOpts_.debug)
    command += " -c";
  if (fallbackOpts_.verbose)
    command += " -v";
  if (fallbackOpts_.useDocker)
    command += " -d";
  if (fallbackOpts_.useGPU)
    command += " -g";
  if (!fallbackOpts_.instanceName.empty())
    command += " -n " + fallbackOpts_.instanceName;
  if (fallbackOpts_.retries >= 0)
    command += " -r " + std::to_string(fallbackOpts_.retries);
  if (fallbackOpts_.wait >= 0)
    command += " -w " + std::to_string(fallbackOpts_.wait);
  for (const auto& [modelName, model] : unservedModels_) {
    command += " -m " + model.path;
  }
  if (!fallbackOpts_.imageName.empty())
    command += " -i " + fallbackOpts_.imageName;
  if (!fallbackOpts_.sandboxName.empty())
    command += " -s " + fallbackOpts_.sandboxName;
  //don't need this anymore
  unservedModels_.clear();

  //get a random temporary directory if none specified
  if (fallbackOpts_.tempDir.empty()) {
    auto tmp_dir_path{std::filesystem::temp_directory_path() /= edm::createGlobalIdentifier()};
    fallbackOpts_.tempDir = tmp_dir_path.string();
  }
  //special case ".": use script default (temp dir = .$instanceName)
  if (fallbackOpts_.tempDir != ".")
    command += " -t " + fallbackOpts_.tempDir;

  command += " start";

  if (fallbackOpts_.debug)
    edm::LogInfo("TritonService") << "Fallback server temporary directory: " << fallbackOpts_.tempDir;
  if (verbose_)
    edm::LogInfo("TritonService") << command;

  //mark as started before executing in case of ctrl+c while command is running
  startedFallback_ = true;
  const auto& [output, rv] = execSys(command);
  if (verbose_ or rv != 0)
    edm::LogInfo("TritonService") << output;
  if (rv != 0)
    throw cms::Exception("FallbackFailed") << "Starting the fallback server failed with exit code " << rv;

  //get the port
  const std::string& portIndicator("CMS_TRITON_GRPC_PORT: ");
  //find last instance in log in case multiple ports were tried
  auto pos = output.rfind(portIndicator);
  if (pos != std::string::npos) {
    auto pos2 = pos + portIndicator.size();
    auto pos3 = output.find('\n', pos2);
    const auto& portNum = output.substr(pos2, pos3 - pos2);
    server.url += ":" + portNum;
  } else
    throw cms::Exception("FallbackFailed") << "Unknown port for fallback server, log follows:\n" << output;
}

void TritonService::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
  edm::ParameterSetDescription desc;
  desc.addUntracked<bool>("verbose", false);

  edm::ParameterSetDescription validator;
  validator.addUntracked<std::string>("name");
  validator.addUntracked<std::string>("address");
  validator.addUntracked<unsigned>("port");
  validator.addUntracked<bool>("useSsl", false);
  validator.addUntracked<std::string>("rootCertificates", "");
  validator.addUntracked<std::string>("privateKey", "");
  validator.addUntracked<std::string>("certificateChain", "");

  desc.addVPSetUntracked("servers", validator, {});

  edm::ParameterSetDescription fallbackDesc;
  fallbackDesc.addUntracked<bool>("enable", false);
  fallbackDesc.addUntracked<bool>("debug", false);
  fallbackDesc.addUntracked<bool>("verbose", false);
  fallbackDesc.addUntracked<bool>("useDocker", false);
  fallbackDesc.addUntracked<bool>("useGPU", false);
  fallbackDesc.addUntracked<int>("retries", -1);
  fallbackDesc.addUntracked<int>("wait", -1);
  fallbackDesc.addUntracked<std::string>("instanceBaseName", "triton_server_instance");
  fallbackDesc.addUntracked<std::string>("instanceName", "");
  fallbackDesc.addUntracked<std::string>("tempDir", "");
  fallbackDesc.addUntracked<std::string>("imageName", "");
  fallbackDesc.addUntracked<std::string>("sandboxName", "");
  desc.add<edm::ParameterSetDescription>("fallback", fallbackDesc);

  descriptions.addWithDefaultLabel(desc);
}
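
Example usage (not part of TritonService.cc): a minimal sketch of how a client module could interact with this service, registering its model during construction and asking for a server later. The module name, member names, and the "modelName"/"modelPath" parameters are hypothetical; in real workflows this interaction is normally handled by the SONIC TritonClient classes.

#include "FWCore/Framework/interface/Event.h"
#include "FWCore/Framework/interface/EventSetup.h"
#include "FWCore/Framework/interface/stream/EDProducer.h"
#include "FWCore/ParameterSet/interface/ParameterSet.h"
#include "FWCore/ServiceRegistry/interface/Service.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonService.h"

#include <string>

//hypothetical module illustrating the expected call sequence
class MyTritonModule : public edm::stream::EDProducer<> {
public:
  MyTritonModule(const edm::ParameterSet& cfg) : modelName_(cfg.getParameter<std::string>("modelName")) {
    //addModel() is only allowed during module construction (see DisallowedAddModel above)
    edm::Service<TritonService> ts;
    ts->addModel(modelName_, cfg.getParameter<std::string>("modelPath"));
  }

  void produce(edm::Event&, const edm::EventSetup&) override {
    //ask the service which server provides the model (empty preferred server = no preference)
    edm::Service<TritonService> ts;
    const auto server = ts->serverInfo(modelName_);
    //...create a Triton gRPC client for server.url and run inference...
  }

private:
  std::string modelName_;
};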