
TritonService.cc
15 #include "grpc_client.h"
16 #include "grpc_service.pb.h"
17 
18 #include <algorithm>
19 #include <cctype>
20 #include <cstdio>
21 #include <cstdlib>
22 #include <filesystem>
23 #include <fstream>
24 #include <utility>
25 #include <tuple>
26 #include <unistd.h>
27 
28 namespace tc = triton::client;
29 
32 const std::string TritonService::Server::siteconfName{"SONIC_LOCAL_BALANCER"};
33 
34 namespace {
35  std::pair<std::string, int> execSys(const std::string& cmd) {
36  //redirect stderr to stdout
37  auto pipe = popen((cmd + " 2>&1").c_str(), "r");
38  int thisErrno = errno;
39  if (!pipe)
40  throw cms::Exception("SystemError")
41  << "TritonService: popen() failed with errno " << thisErrno << " for command: " << cmd;
42 
43  //extract output
44  constexpr static unsigned buffSize = 128;
45  std::array<char, buffSize> buffer;
46  std::string result;
47  while (!feof(pipe)) {
48  if (fgets(buffer.data(), buffSize, pipe))
49  result += buffer.data();
50  else {
51  thisErrno = ferror(pipe);
52  if (thisErrno)
53  throw cms::Exception("SystemError")
54  << "TritonService: failed reading command output with errno " << thisErrno;
55  }
56  }
57 
58  int rv = pclose(pipe);
59  return std::make_pair(result, rv);
60  }
61 
62  //extract specific info from log
63  std::string extractFromLog(const std::string& output, const std::string& indicator) {
64  //find last instance in log (in case of multiple)
65  auto pos = output.rfind(indicator);
66  if (pos != std::string::npos) {
67  auto pos2 = pos + indicator.size();
68  auto pos3 = output.find('\n', pos2);
69  return output.substr(pos2, pos3 - pos2);
70  } else
71  return "";
72  }
73 } // namespace
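
These two helpers are the service's only interface to the cmsTriton wrapper script: execSys() captures the script's combined stdout/stderr plus its exit code, and extractFromLog() returns whatever follows the last occurrence of a marker string, up to the end of that line. A minimal standalone sketch of the extraction step (the sample log text below is made up; the marker strings are the ones used later in this file):

  #include <iostream>
  #include <string>

  //same logic as extractFromLog() above: take whatever follows the last
  //occurrence of the indicator, up to the end of that line
  std::string extractValue(const std::string& output, const std::string& indicator) {
    auto pos = output.rfind(indicator);
    if (pos == std::string::npos)
      return "";
    auto start = pos + indicator.size();
    return output.substr(start, output.find('\n', start) - start);
  }

  int main() {
    //made-up stand-in for the captured cmsTriton output returned by execSys()
    const std::string log =
        "starting fallback server...\n"
        "CMS_TRITON_CHOSEN_DEVICE: cpu\n"
        "CMS_TRITON_GRPC_PORT: 8001\n";
    std::cout << extractValue(log, "CMS_TRITON_GRPC_PORT: ") << "\n";      //prints 8001
    std::cout << extractValue(log, "CMS_TRITON_CHOSEN_DEVICE: ") << "\n";  //prints cpu
  }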
74 
75 TritonService::TritonService(const edm::ParameterSet& pset, edm::ActivityRegistry& areg)
76  : verbose_(pset.getUntrackedParameter<bool>("verbose")),
77  fallbackOpts_(pset.getParameterSet("fallback")),
78  currentModuleId_(0),
79  allowAddModel_(false),
80  startedFallback_(false),
81  callFails_(0),
82  pid_(std::to_string(::getpid())) {
83  //module construction is assumed to be serial (correct at the time this code was written)
90  //fallback server will be launched (if needed) before beginJob
94  //check for server specified in SITECONF
95  //(temporary solution, to be replaced with entry in site-local-config.xml or similar)
98  if (!siteconf_address.empty() and !siteconf_port.empty()) {
99  servers_.emplace(
100  std::piecewise_construct,
101  std::forward_as_tuple(Server::siteconfName),
102  std::forward_as_tuple(Server::siteconfName, siteconf_address + ":" + siteconf_port, TritonServerType::Remote));
103  if (verbose_)
104  edm::LogInfo("TritonDiscovery") << "Obtained server from SITECONF: "
105  << servers_.find(Server::siteconfName)->second.url;
106  } else if (siteconf_address.empty() != siteconf_port.empty()) { //xor
107  edm::LogWarning("TritonDiscovery") << "Incomplete server information from SITECONF: HOST = " << siteconf_address
108  << ", PORT = " << siteconf_port;
109  } else
110  edm::LogWarning("TritonDiscovery") << "No server information from SITECONF";
111 
112  //finally, populate list of servers from config input
113  for (const auto& serverPset : pset.getUntrackedParameterSetVector("servers")) {
114  const std::string& serverName(serverPset.getUntrackedParameter<std::string>("name"));
115  //ensure uniqueness
116  auto [sit, unique] = servers_.emplace(serverName, serverPset);
117  if (!unique)
118  throw cms::Exception("DuplicateServer")
119  << "TritonService: Not allowed to specify more than one server with same name (" << serverName << ")";
120  }
121 
122  //loop over all servers: check which models they have
123  std::string msg;
124  if (verbose_)
125  msg = "List of models for each server:\n";
126  for (auto& [serverName, server] : servers_) {
127  std::unique_ptr<tc::InferenceServerGrpcClient> client;
128  TRITON_THROW_IF_ERROR(
129  tc::InferenceServerGrpcClient::Create(&client, server.url, false, server.useSsl, server.sslOptions),
130  "TritonService(): unable to create inference context for " + serverName + " (" + server.url + ")",
131  false);
132 
133  if (verbose_) {
134  inference::ServerMetadataResponse serverMetaResponse;
135  auto err = client->ServerMetadata(&serverMetaResponse);
136  if (err.IsOk())
137  edm::LogInfo("TritonService") << "Server " << serverName << ": url = " << server.url
138  << ", version = " << serverMetaResponse.version();
139  else
140  edm::LogInfo("TritonService") << "unable to get metadata for " + serverName + " (" + server.url + ")";
141  }
142 
143  //if this query fails, it indicates that the server is nonresponsive or saturated
144  //in which case it should just be skipped
145  inference::RepositoryIndexResponse repoIndexResponse;
146  auto err = client->ModelRepositoryIndex(&repoIndexResponse);
147 
148  //servers keep track of models and vice versa
149  if (verbose_)
150  msg += serverName + ": ";
151  if (err.IsOk()) {
152  for (const auto& modelIndex : repoIndexResponse.models()) {
153  const auto& modelName = modelIndex.name();
154  auto mit = models_.find(modelName);
155  if (mit == models_.end())
156  mit = models_.emplace(modelName, "").first;
157  auto& modelInfo(mit->second);
158  modelInfo.servers.insert(serverName);
159  server.models.insert(modelName);
160  if (verbose_)
161  msg += modelName + ", ";
162  }
163  } else {
164  if (verbose_)
165  msg += "unable to get repository index";
166  else
167  edm::LogWarning("TritonFailure") << "TritonService(): unable to get repository index for " + serverName + " (" +
168  server.url + ")";
169  }
170  if (verbose_)
171  msg += "\n";
172  }
173  if (verbose_)
174  edm::LogInfo("TritonDiscovery") << msg;
175 }
176 
177 void TritonService::preallocate(edm::service::SystemBounds const& bounds) {
178  numberOfThreads_ = bounds.maxNumberOfThreads();
179 }
180 
181 void TritonService::preModuleConstruction(edm::ModuleDescription const& desc) {
182  currentModuleId_ = desc.id();
183  allowAddModel_ = true;
184 }
185 
186 void TritonService::addModel(const std::string& modelName, const std::string& path) {
187  //should only be called in module constructors
188  if (!allowAddModel_)
189  throw cms::Exception("DisallowedAddModel")
190  << "TritonService: Attempt to call addModel() outside of module constructors";
191  //if model is not in the list, then no specified server provides it
192  auto mit = models_.find(modelName);
193  if (mit == models_.end()) {
194  auto& modelInfo(unservedModels_.emplace(modelName, path).first->second);
195  modelInfo.modules.insert(currentModuleId_);
196  //only keep track of modules that need unserved models
197  modules_.emplace(currentModuleId_, modelName);
198  }
199 }
200 
202 void TritonService::postModuleConstruction(edm::ModuleDescription const&) { allowAddModel_ = false; }
203 void TritonService::preModuleDestruction(edm::ModuleDescription const& desc) {
204  //remove destructed modules from unserved list
205  if (unservedModels_.empty())
206  return;
207  auto id = desc.id();
208  auto oit = modules_.find(id);
209  if (oit != modules_.end()) {
210  const auto& moduleInfo(oit->second);
211  auto mit = unservedModels_.find(moduleInfo.model);
212  if (mit != unservedModels_.end()) {
213  auto& modelInfo(mit->second);
214  modelInfo.modules.erase(id);
215  //remove a model if it is no longer needed by any modules
216  if (modelInfo.modules.empty())
217  unservedModels_.erase(mit);
218  }
219  modules_.erase(oit);
220  }
221 }
222 
223 //second return value is only true if fallback CPU server is being used
224 TritonService::Server TritonService::serverInfo(const std::string& model, const std::string& preferred) const {
225  auto mit = models_.find(model);
226  if (mit == models_.end())
227  throw cms::Exception("MissingModel") << "TritonService: There are no servers that provide model " << model;
228  const auto& modelInfo(mit->second);
229  const auto& modelServers = modelInfo.servers;
230 
231  auto msit = modelServers.end();
232  if (!preferred.empty()) {
233  msit = modelServers.find(preferred);
234  //todo: add a "strict" parameter to stop execution if preferred server isn't found?
235  if (msit == modelServers.end())
236  edm::LogWarning("PreferredServer") << "Preferred server " << preferred << " for model " << model
237  << " not available, will choose another server";
238  }
239  const auto& serverName(msit == modelServers.end() ? *modelServers.begin() : preferred);
240 
241  //todo: use some algorithm to select server rather than just picking arbitrarily
242  const auto& server(servers_.find(serverName)->second);
243  return server;
244 }
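
serverInfo() resolves a model name to a concrete server: the preferred server is used only if it actually serves the model; otherwise the first entry of the model's server set is taken (no load balancing, per the todo above). A small standalone sketch of that selection rule, using std::set<std::string> as a stand-in for the real server collection (the container type and server names are assumptions made for illustration):

  #include <iostream>
  #include <set>
  #include <string>

  //mirrors the choice in serverInfo(): use the preferred server when it serves
  //the model, otherwise fall back to the first entry of the set
  std::string chooseServer(const std::set<std::string>& modelServers, const std::string& preferred) {
    auto msit = modelServers.end();
    if (!preferred.empty())
      msit = modelServers.find(preferred);
    return msit == modelServers.end() ? *modelServers.begin() : preferred;
  }

  int main() {
    std::set<std::string> servers{"serverA", "serverB"};
    std::cout << chooseServer(servers, "serverB") << "\n";  //preferred server is available
    std::cout << chooseServer(servers, "serverC") << "\n";  //not available: first entry is used
  }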
245 
246 void TritonService::preBeginJob(edm::PathsAndConsumesOfModulesBase const&, edm::ProcessContext const&) {
247  //only need fallback if there are unserved models
248  if (!fallbackOpts_.enable or unservedModels_.empty())
249  return;
250 
251  //include fallback server in set
252  auto serverType = TritonServerType::LocalCPU;
253  if (fallbackOpts_.device == "gpu")
254  serverType = TritonServerType::LocalGPU;
255  servers_.emplace(std::piecewise_construct,
256  std::forward_as_tuple(Server::fallbackName),
257  std::forward_as_tuple(Server::fallbackName, Server::fallbackAddress, serverType));
258 
259  std::string msg;
260  if (verbose_)
261  msg = "List of models for fallback server: ";
262  //all unserved models are provided by fallback server
263  auto& server(servers_.find(Server::fallbackName)->second);
264  for (const auto& [modelName, model] : unservedModels_) {
265  auto& modelInfo(models_.emplace(modelName, model).first->second);
266  modelInfo.servers.insert(Server::fallbackName);
267  server.models.insert(modelName);
268  if (verbose_)
269  msg += modelName + ", ";
270  }
271  if (verbose_)
272  edm::LogInfo("TritonDiscovery") << msg;
273 
274  //assemble server start command
275  fallbackOpts_.command = "cmsTriton -P -1 -p " + pid_;
278  if (fallbackOpts_.debug)
279  fallbackOpts_.command += " -c";
280  if (fallbackOpts_.verbose)
281  fallbackOpts_.command += " -v";
282  if (!fallbackOpts_.instanceName.empty())
284  if (fallbackOpts_.retries >= 0)
286  if (fallbackOpts_.wait >= 0)
288  for (const auto& [modelName, model] : unservedModels_) {
289  fallbackOpts_.command += " -m " + model.path;
290  }
291  std::string thread_string = " -I " + std::to_string(numberOfThreads_);
292  fallbackOpts_.command += thread_string;
293  if (!fallbackOpts_.imageName.empty())
295  if (!fallbackOpts_.sandboxName.empty())
297  //don't need this anymore
298  unservedModels_.clear();
299 
300  //get a random temporary directory if none specified
301  if (fallbackOpts_.tempDir.empty()) {
302  auto tmp_dir_path{std::filesystem::temp_directory_path() /= edm::createGlobalIdentifier()};
303  fallbackOpts_.tempDir = tmp_dir_path.string();
304  }
305  //special case ".": use script default (temp dir = .$instanceName)
306  if (fallbackOpts_.tempDir != ".")
308 
310  std::string command = fallbackOpts_.command + " start";
311  if (fallbackOpts_.debug)
312  edm::LogInfo("TritonService") << "Fallback server temporary directory: " << fallbackOpts_.tempDir;
313  if (verbose_)
314  edm::LogInfo("TritonService") << command;
315 
316  //mark as started before executing in case of ctrl+c while command is running
317  startedFallback_ = true;
318  const auto& [output, rv] = execSys(command);
319  if (rv != 0) {
320  edm::LogError("TritonService") << output;
321  printFallbackServerLog<edm::LogError>();
322  throw cms::Exception("FallbackFailed")
323  << "TritonService: Starting the fallback server failed with exit code " << rv;
324  } else if (verbose_)
325  edm::LogInfo("TritonService") << output;
326 
327  //get the chosen device
328  std::string chosenDevice(fallbackOpts_.device);
329  if (chosenDevice == "auto") {
330  chosenDevice = extractFromLog(output, "CMS_TRITON_CHOSEN_DEVICE: ");
331  if (!chosenDevice.empty()) {
332  if (chosenDevice == "cpu")
333  server.type = TritonServerType::LocalCPU;
334  else if (chosenDevice == "gpu")
335  server.type = TritonServerType::LocalGPU;
336  else
337  throw cms::Exception("FallbackFailed")
338  << "TritonService: unsupported device choice " << chosenDevice << " for fallback server, log follows:\n"
339  << output;
340  } else
341  throw cms::Exception("FallbackFailed")
342  << "TritonService: unknown device choice for fallback server, log follows:\n"
343  << output;
344  }
345  //print server info
346  std::transform(chosenDevice.begin(), chosenDevice.end(), chosenDevice.begin(), toupper);
347  if (verbose_)
348  edm::LogInfo("TritonDiscovery") << "Fallback server started: " << chosenDevice;
349 
350  //get the port
351  const auto& portNum = extractFromLog(output, "CMS_TRITON_GRPC_PORT: ");
352  if (!portNum.empty())
353  server.url += ":" + portNum;
354  else
355  throw cms::Exception("FallbackFailed")
356  << "TritonService: Unknown port for fallback server, log follows:\n"
357  << output;
358 }
359 
360 void TritonService::notifyCallStatus(bool status) const {
361  if (status)
362  --callFails_;
363  else
364  ++callFails_;
365 }
366 
367 void TritonService::postEndJob() {
368  if (!startedFallback_)
369  return;
370 
371  std::string command = fallbackOpts_.command;
372  //prevent log cleanup during server stop
373  if (callFails_ > 0)
374  command += " -c";
375  command += " stop";
376  if (verbose_)
377  edm::LogInfo("TritonService") << command;
378 
379  const auto& [output, rv] = execSys(command);
380  if (rv != 0 or callFails_ > 0) {
381  //print logs if cmsRun is currently exiting because of a TritonException
382  edm::LogError("TritonService") << output;
383  printFallbackServerLog<edm::LogError>();
384  if (rv != 0) {
385  std::string stopCat("FallbackFailed");
386  std::string stopMsg = fmt::format("TritonService: Stopping the fallback server failed with exit code {}", rv);
387  //avoid throwing if the stack is already unwinding
388  if (callFails_ > 0)
389  edm::LogWarning(stopCat) << stopMsg;
390  else
391  throw cms::Exception(stopCat) << stopMsg;
392  }
393  } else if (verbose_) {
394  edm::LogInfo("TritonService") << output;
395  printFallbackServerLog<edm::LogInfo>();
396  }
397 }
398 
399 template <typename LOG>
400 void TritonService::printFallbackServerLog() const {
401  std::vector<std::string> logNames{"log_" + fallbackOpts_.instanceName + ".log"};
402  //cmsTriton script moves log from temp to current dir in verbose mode or in some cases when auto_stop is called
403  // -> check both places
404  logNames.push_back(fallbackOpts_.tempDir + "/" + logNames[0]);
405  bool foundLog = false;
406  for (const auto& logName : logNames) {
407  std::ifstream infile(logName);
408  if (infile.is_open()) {
409  LOG("TritonService") << "TritonService: server log " << logName << "\n" << infile.rdbuf();
410  foundLog = true;
411  break;
412  }
413  }
414  if (!foundLog)
415  LOG("TritonService") << "TritonService: could not find server log " << logNames[0] << " in current directory or "
416  << fallbackOpts_.tempDir;
417 }
418 
419 void TritonService::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
420  edm::ParameterSetDescription desc;
421  desc.addUntracked<bool>("verbose", false);
422 
423  edm::ParameterSetDescription validator;
424  validator.addUntracked<std::string>("name");
425  validator.addUntracked<std::string>("address");
426  validator.addUntracked<unsigned>("port");
427  validator.addUntracked<bool>("useSsl", false);
428  validator.addUntracked<std::string>("rootCertificates", "");
429  validator.addUntracked<std::string>("privateKey", "");
430  validator.addUntracked<std::string>("certificateChain", "");
431 
432  desc.addVPSetUntracked("servers", validator, {});
433 
434  edm::ParameterSetDescription fallbackDesc;
435  fallbackDesc.addUntracked<bool>("enable", false);
436  fallbackDesc.addUntracked<bool>("debug", false);
437  fallbackDesc.addUntracked<bool>("verbose", false);
438  fallbackDesc.ifValue(edm::ParameterDescription<std::string>("container", "apptainer", false),
439  edm::allowedValues<std::string>("apptainer", "docker", "podman"));
440  fallbackDesc.ifValue(edm::ParameterDescription<std::string>("device", "auto", false),
441  edm::allowedValues<std::string>("auto", "cpu", "gpu"));
442  fallbackDesc.addUntracked<int>("retries", -1);
443  fallbackDesc.addUntracked<int>("wait", -1);
444  fallbackDesc.addUntracked<std::string>("instanceBaseName", "triton_server_instance");
445  fallbackDesc.addUntracked<std::string>("instanceName", "");
446  fallbackDesc.addUntracked<std::string>("tempDir", "");
447  fallbackDesc.addUntracked<std::string>("imageName", "");
448  fallbackDesc.addUntracked<std::string>("sandboxName", "");
449  desc.add<edm::ParameterSetDescription>("fallback", fallbackDesc);
450 
451  descriptions.addWithDefaultLabel(desc);
452 }
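
For orientation, here is a rough sketch of how a module could interact with this service, based only on the addModel() and serverInfo() signatures above; the class, model name, and repository path are hypothetical, and in practice this bookkeeping is done by the SONIC Triton client classes rather than by user code:

  #include "FWCore/ServiceRegistry/interface/Service.h"
  #include "HeterogeneousCore/SonicTriton/interface/TritonService.h"

  class MyTritonProducer {  //stand-in for an EDProducer; illustrative only
  public:
    MyTritonProducer() {
      //addModel() is only allowed while the module is being constructed (see allowAddModel_)
      edm::Service<TritonService> ts;
      ts->addModel("my_model", "/path/to/model_repository/my_model");
    }
    void beginJob() {
      //by this point preBeginJob() has already started the fallback server if it was needed
      edm::Service<TritonService> ts;
      const auto server = ts->serverInfo("my_model");
      //server.url identifies the gRPC endpoint to hand to a triton::client::InferenceServerGrpcClient
    }
  };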