TritonClient.cc

#include "grpc_client.h"
#include "grpc_service.pb.h"
#include "model_config.pb.h"

#include "google/protobuf/text_format.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"

#include <algorithm>
#include <cmath>
#include <exception>
#include <experimental/iterator>
#include <fcntl.h>
#include <sstream>
#include <string>
#include <utility>
#include <tuple>

namespace tc = triton::client;

namespace {
  grpc_compression_algorithm getCompressionAlgo(const std::string& name) {
    if (name.empty() or name.compare("none") == 0)
      return grpc_compression_algorithm::GRPC_COMPRESS_NONE;
    else if (name.compare("deflate") == 0)
      return grpc_compression_algorithm::GRPC_COMPRESS_DEFLATE;
    else if (name.compare("gzip") == 0)
      return grpc_compression_algorithm::GRPC_COMPRESS_GZIP;
    else
      throw cms::Exception("GrpcCompression")
          << "Unknown compression algorithm requested: " << name << " (choices: none, deflate, gzip)";
  }

  std::vector<std::shared_ptr<tc::InferResult>> convertToShared(const std::vector<tc::InferResult*>& tmp) {
    std::vector<std::shared_ptr<tc::InferResult>> results;
    results.reserve(tmp.size());
    std::transform(tmp.begin(), tmp.end(), std::back_inserter(results), [](tc::InferResult* ptr) {
      return std::shared_ptr<tc::InferResult>(ptr);
    });
    return results;
  }
}  // namespace
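
// Illustrative sketch (not part of the original flow): the two helpers above would be composed
// roughly as follows, where "rawResults" is a hypothetical vector of raw pointers returned by
// the Triton client library:
//   std::vector<tc::InferResult*> rawResults = ...;               //owned raw pointers from the client
//   auto sharedResults = convertToShared(rawResults);             //ownership transferred to shared_ptr
//   grpc_compression_algorithm algo = getCompressionAlgo("gzip"); //throws cms::Exception for unknown names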

//based on https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/examples/simple_grpc_async_infer_client.cc
//and https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/perf_client/perf_client.cc

TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& debugName)
    : SonicClient(params, debugName, "TritonClient"),
      batchMode_(TritonBatchMode::Rectangular),
      manualBatchMode_(false),
      verbose_(params.getUntrackedParameter<bool>("verbose")),
      useSharedMemory_(params.getUntrackedParameter<bool>("useSharedMemory")),
      compressionAlgo_(getCompressionAlgo(params.getUntrackedParameter<std::string>("compression"))) {
  options_.emplace_back(params.getParameter<std::string>("modelName"));
  //get appropriate server for this model
  edm::Service<TritonService> ts;
  const auto& server =
      ts->serverInfo(options_[0].model_name_, params.getUntrackedParameter<std::string>("preferredServer"));
  serverType_ = server.type;
  if (verbose_)
    edm::LogInfo(fullDebugName_) << "Using server: " << server.url;
  //enforce sync mode for fallback CPU server to avoid contention
  //todo: could enforce async mode otherwise (unless mode was specified by user?)
  if (serverType_ == TritonServerType::LocalCPU)
    setMode(SonicMode::Sync);

  //connect to the server
  TRITON_THROW_IF_ERROR(
      tc::InferenceServerGrpcClient::Create(&client_, server.url, false, server.useSsl, server.sslOptions),
      "TritonClient(): unable to create inference context");

  //set options
  options_[0].model_version_ = params.getParameter<std::string>("modelVersion");
  //convert seconds to microseconds
  options_[0].client_timeout_ = params.getUntrackedParameter<unsigned>("timeout") * 1e6;

  //get fixed parameters from local config
  inference::ModelConfig localModelConfig;
  {
    const std::string& localModelConfigPath(params.getParameter<edm::FileInPath>("modelConfigPath").fullPath());
    int fileDescriptor = open(localModelConfigPath.c_str(), O_RDONLY);
    if (fileDescriptor < 0)
      throw TritonException("LocalFailure")
          << "TritonClient(): unable to open local model config: " << localModelConfigPath;
    google::protobuf::io::FileInputStream localModelConfigInput(fileDescriptor);
    localModelConfigInput.SetCloseOnDelete(true);
    if (!google::protobuf::TextFormat::Parse(&localModelConfigInput, &localModelConfig))
      throw TritonException("LocalFailure")
          << "TritonClient(): unable to parse local model config: " << localModelConfigPath;
  }

  //check batch size limitations (after i/o setup)
  //triton uses max batch size = 0 to denote a model that does not support native batching (using the outer dimension)
  //but for models that do support batching (native or otherwise), a given event may set batch size 0 to indicate no valid input is present
  //so set the local max to 1 and keep track of "no outer dim" case
  maxOuterDim_ = localModelConfig.max_batch_size();
  noOuterDim_ = maxOuterDim_ == 0;
  maxOuterDim_ = std::max(1u, maxOuterDim_);
  //propagate batch size
  setBatchSize(1);
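  // Worked example (hypothetical values, for illustration only):
  //   config.pbtxt with max_batch_size: 8 -> maxOuterDim_ = 8, noOuterDim_ = false
  //   config.pbtxt with max_batch_size: 0 -> noOuterDim_ = true, and the local max is raised to 1
  //                                          so that a single entry can still be sent
  // In both cases the client starts out with a rectangular batch of size 1.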

  //compare model checksums to remote config to enforce versioning
  inference::ModelConfigResponse modelConfigResponse;
  TRITON_THROW_IF_ERROR(client_->ModelConfig(&modelConfigResponse, options_[0].model_name_, options_[0].model_version_),
                        "TritonClient(): unable to get model config");
  inference::ModelConfig remoteModelConfig(modelConfigResponse.config());

  std::map<std::string, std::array<std::string, 2>> checksums;
  size_t fileCounter = 0;
  for (const auto& modelConfig : {localModelConfig, remoteModelConfig}) {
    const auto& agents = modelConfig.model_repository_agents().agents();
    auto agent = std::find_if(agents.begin(), agents.end(), [](auto const& a) { return a.name() == "checksum"; });
    if (agent != agents.end()) {
      const auto& params = agent->parameters();
      for (const auto& [key, val] : params) {
        // only check the requested version
        if (key.compare(0, options_[0].model_version_.size() + 1, options_[0].model_version_ + "/") == 0)
          checksums[key][fileCounter] = val;
      }
    }
    ++fileCounter;
  }
  std::vector<std::string> incorrect;
  for (const auto& [key, val] : checksums) {
    if (checksums[key][0] != checksums[key][1])
      incorrect.push_back(key);
  }
  if (!incorrect.empty())
    throw TritonException("ModelVersioning") << "The following files have incorrect checksums on the remote server: "
                                             << triton_utils::printColl(incorrect, ", ");
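
  // For reference, a hedged sketch of the kind of config.pbtxt fragment the check above relies on
  // (file name and hash value are hypothetical; the agent name "checksum" and the
  // "<version>/<file>" key convention are what the loop expects):
  //   model_repository_agents {
  //     agents {
  //       name: "checksum"
  //       parameters { key: "1/model.onnx" value: "md5:<hash of the local file>" }
  //     }
  //   }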

  //get model info
  inference::ModelMetadataResponse modelMetadata;
  TRITON_THROW_IF_ERROR(client_->ModelMetadata(&modelMetadata, options_[0].model_name_, options_[0].model_version_),
                        "TritonClient(): unable to get model metadata");

  //get input and output (which know their sizes)
  const auto& nicInputs = modelMetadata.inputs();
  const auto& nicOutputs = modelMetadata.outputs();

  //report all model errors at once
  std::stringstream msg;
  std::string msg_str;

  //currently no use case is foreseen for a model with zero inputs or outputs
  if (nicInputs.empty())
    msg << "Model on server appears malformed (zero inputs)\n";

  if (nicOutputs.empty())
    msg << "Model on server appears malformed (zero outputs)\n";

  //stop if errors
  msg_str = msg.str();
  if (!msg_str.empty())
    throw cms::Exception("ModelErrors") << msg_str;

  //setup input map
  std::stringstream io_msg;
  if (verbose_)
    io_msg << "Model inputs: "
           << "\n";
  for (const auto& nicInput : nicInputs) {
    const auto& iname = nicInput.name();
    auto [curr_itr, success] = input_.emplace(std::piecewise_construct,
                                              std::forward_as_tuple(iname),
                                              std::forward_as_tuple(iname, nicInput, this, ts->pid()));
    auto& curr_input = curr_itr->second;
    if (verbose_) {
      io_msg << "  " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize()
             << " b) : " << triton_utils::printColl(curr_input.shape()) << "\n";
    }
  }

  //allow selecting only some outputs from server
  const auto& v_outputs = params.getUntrackedParameter<std::vector<std::string>>("outputs");
  std::unordered_set s_outputs(v_outputs.begin(), v_outputs.end());

  //setup output map
  if (verbose_)
    io_msg << "Model outputs: "
           << "\n";
  for (const auto& nicOutput : nicOutputs) {
    const auto& oname = nicOutput.name();
    if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end())
      continue;
    auto [curr_itr, success] = output_.emplace(std::piecewise_construct,
                                               std::forward_as_tuple(oname),
                                               std::forward_as_tuple(oname, nicOutput, this, ts->pid()));
    auto& curr_output = curr_itr->second;
    if (verbose_) {
      io_msg << "  " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize()
             << " b) : " << triton_utils::printColl(curr_output.shape()) << "\n";
    }
    if (!s_outputs.empty())
      s_outputs.erase(oname);
  }

  //check if any requested outputs were not available
  if (!s_outputs.empty())
    throw cms::Exception("MissingOutput")
        << "Some requested outputs were not available on the server: " << triton_utils::printColl(s_outputs);

  //print model info
  std::stringstream model_msg;
  if (verbose_) {
    model_msg << "Model name: " << options_[0].model_name_ << "\n"
              << "Model version: " << options_[0].model_version_ << "\n"
              << "Model max outer dim: " << (noOuterDim_ ? 0 : maxOuterDim_) << "\n";
    edm::LogInfo(fullDebugName_) << model_msg.str() << io_msg.str();
  }
}

TritonClient::~TritonClient() {
  //by default: members of this class destroyed before members of base class
  //in shared memory case, TritonMemResource (member of TritonData) unregisters from client_ in its destructor
  //but input/output objects are member of base class, so destroyed after client_ (member of this class)
  //therefore, clear the maps here
  input_.clear();
  output_.clear();
}

void TritonClient::setBatchMode(TritonBatchMode batchMode) {
  unsigned oldBatchSize = batchSize();
  batchMode_ = batchMode;
  manualBatchMode_ = true;
  //this allows calling setBatchSize() and setBatchMode() in either order consistently to change back and forth
  //includes handling of change from ragged to rectangular if multiple entries already created
  setBatchSize(oldBatchSize);
}

void TritonClient::resetBatchMode() {
  batchMode_ = TritonBatchMode::Rectangular;
  manualBatchMode_ = false;
}

unsigned TritonClient::nEntries() const { return !input_.empty() ? input_.begin()->second.entries_.size() : 0; }

unsigned TritonClient::batchSize() const { return batchMode_ == TritonBatchMode::Rectangular ? outerDim_ : nEntries(); }

bool TritonClient::setBatchSize(unsigned bsize) {
  if (batchMode_ == TritonBatchMode::Rectangular) {
    if (bsize > maxOuterDim_) {
      edm::LogWarning(fullDebugName_) << "Requested batch size " << bsize << " exceeds server-specified max batch size "
                                      << maxOuterDim_ << ". Batch size will remain as " << outerDim_;
      return false;
    } else {
      outerDim_ = bsize;
      //take min to allow resizing to 0
      resizeEntries(std::min(1u, bsize));
      return true;
    }
  } else {
    resizeEntries(bsize);
    outerDim_ = 1;
    return true;
  }
}
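
// Hedged usage sketch (illustrative only, from the perspective of a producer that owns a
// TritonClient "client"): in the default Rectangular mode all entries share one outer
// dimension, while Ragged mode sends one request per entry with independent shapes.
//   client.setBatchSize(5);                        //rectangular: one request with outer dim 5
//   client.setBatchMode(TritonBatchMode::Ragged);  //ragged: nEntries() requests, each with outer dim 1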

void TritonClient::resizeEntries(unsigned entry) {
  if (entry > nEntries())
    //addEntry(entry) extends the vector to size entry+1
    addEntry(entry - 1);
  else if (entry < nEntries()) {
    for (auto& element : input_) {
      element.second.entries_.resize(entry);
    }
    for (auto& element : output_) {
      element.second.entries_.resize(entry);
    }
  }
}

void TritonClient::addEntry(unsigned entry) {
  for (auto& element : input_) {
    element.second.addEntryImpl(entry);
  }
  for (auto& element : output_) {
    element.second.addEntryImpl(entry);
  }
  if (entry > 0) {
    batchMode_ = TritonBatchMode::Ragged;
    outerDim_ = 1;
  }
}

void TritonClient::reset() {
  if (!manualBatchMode_)
    resetBatchMode();
  for (auto& element : input_) {
    element.second.reset();
  }
  for (auto& element : output_) {
    element.second.reset();
  }
}

template <typename F>
bool TritonClient::handle_exception(F&& call) {
  //caught exceptions will be propagated to edm::WaitingTaskWithArenaHolder
  CMS_SA_ALLOW try {
    call();
    return true;
  }
  //TritonExceptions are intended/expected to be recoverable, i.e. retries should be allowed
  catch (TritonException& e) {
    e.convertToWarning();
    finish(false);
    return false;
  }
  //other exceptions are not: execution should stop if they are encountered
  catch (...) {
    finish(false, std::current_exception());
    return false;
  }
}

void TritonClient::getResults(const std::vector<std::shared_ptr<tc::InferResult>>& results) {
  for (unsigned i = 0; i < results.size(); ++i) {
    const auto& result = results[i];
    for (auto& [oname, output] : output_) {
      //set shape here before output becomes const
      if (output.variableDims()) {
        std::vector<int64_t> tmp_shape;
        TRITON_THROW_IF_ERROR(result->Shape(oname, &tmp_shape),
                              "getResults(): unable to get output shape for " + oname);
        if (!noOuterDim_)
          tmp_shape.erase(tmp_shape.begin());
        output.setShape(tmp_shape, i);
      }
      //extend lifetime
      output.setResult(result, i);
      //compute size after getting all result entries
      if (i == results.size() - 1)
        output.computeSizes();
    }
  }
}

//default case for sync and pseudo async
void TritonClient::evaluate() {
  //in case there is nothing to process
  if (batchSize() == 0) {
    //call getResults on an empty vector
    std::vector<std::shared_ptr<tc::InferResult>> empty_results;
    getResults(empty_results);
    finish(true);
    return;
  }

  //set up input pointers for triton (generalized for multi-request ragged batching case)
  //one vector<InferInput*> per request
  unsigned nEntriesVal = nEntries();
  std::vector<std::vector<triton::client::InferInput*>> inputsTriton(nEntriesVal);
  for (auto& inputTriton : inputsTriton) {
    inputTriton.reserve(input_.size());
  }
  for (auto& [iname, input] : input_) {
    for (unsigned i = 0; i < nEntriesVal; ++i) {
      inputsTriton[i].push_back(input.data(i));
    }
  }

  //set up output pointers similarly
  std::vector<std::vector<const triton::client::InferRequestedOutput*>> outputsTriton(nEntriesVal);
  for (auto& outputTriton : outputsTriton) {
    outputTriton.reserve(output_.size());
  }
  for (auto& [oname, output] : output_) {
    for (unsigned i = 0; i < nEntriesVal; ++i) {
      outputsTriton[i].push_back(output.data(i));
    }
  }
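
  // Illustrative note (hypothetical numbers): with 2 inputs, 1 output, and nEntriesVal = 3
  // (ragged mode), the loops above produce inputsTriton = 3 requests x 2 InferInput* each and
  // outputsTriton = 3 requests x 1 InferRequestedOutput* each; in rectangular mode nEntriesVal
  // is 1 and a single request carries the whole batch.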

  //set up shared memory for output
  auto success = handle_exception([&]() {
    for (auto& element : output_) {
      element.second.prepare();
    }
  });
  if (!success)
    return;

  // Get the status of the server prior to the request being made.
  inference::ModelStatistics start_status;
  success = handle_exception([&]() {
    if (verbose())
      start_status = getServerSideStatus();
  });
  if (!success)
    return;

  if (mode_ == SonicMode::Async) {
    //non-blocking call
    success = handle_exception([&]() {
      TRITON_THROW_IF_ERROR(
          client_->AsyncInferMulti(
              [start_status, this](std::vector<tc::InferResult*> resultsTmp) {
                //immediately convert to shared_ptr
                const auto& results = convertToShared(resultsTmp);
                //check results
                for (auto ptr : results) {
                  auto success = handle_exception(
                      [&]() { TRITON_THROW_IF_ERROR(ptr->RequestStatus(), "evaluate(): unable to get result(s)"); });
                  if (!success)
                    return;
                }

                if (verbose()) {
                  inference::ModelStatistics end_status;
                  auto success = handle_exception([&]() { end_status = getServerSideStatus(); });
                  if (!success)
                    return;

                  const auto& stats = summarizeServerStats(start_status, end_status);
                  reportServerSideStats(stats);
                }

                //check result
                auto success = handle_exception([&]() { getResults(results); });
                if (!success)
                  return;

                //finish
                finish(true);
              },
              options_,
              inputsTriton,
              outputsTriton,
              headers_,
              compressionAlgo_),
          "evaluate(): unable to launch async run");
    });
    if (!success)
      return;
  } else {
    //blocking call
    std::vector<tc::InferResult*> resultsTmp;
    success = handle_exception([&]() {
      TRITON_THROW_IF_ERROR(
          client_->InferMulti(&resultsTmp, options_, inputsTriton, outputsTriton, headers_, compressionAlgo_),
          "evaluate(): unable to run and/or get result");
    });
    //immediately convert to shared_ptr
    const auto& results = convertToShared(resultsTmp);
    if (!success)
      return;

    if (verbose()) {
      inference::ModelStatistics end_status;
      success = handle_exception([&]() { end_status = getServerSideStatus(); });
      if (!success)
        return;

      const auto& stats = summarizeServerStats(start_status, end_status);
      reportServerSideStats(stats);
    }

    success = handle_exception([&]() { getResults(results); });
    if (!success)
      return;

    finish(true);
  }
}

void TritonClient::reportServerSideStats(const TritonClient::ServerSideStats& stats) const {
  std::stringstream msg;

  // https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/perf_client/inference_profiler.cc
  const uint64_t count = stats.success_count_;
  msg << "  Inference count: " << stats.inference_count_ << "\n";
  msg << "  Execution count: " << stats.execution_count_ << "\n";
  msg << "  Successful request count: " << count << "\n";

  if (count > 0) {
    auto get_avg_us = [count](uint64_t tval) {
      constexpr uint64_t us_to_ns = 1000;
      return tval / us_to_ns / count;
    };

    const uint64_t cumm_avg_us = get_avg_us(stats.cumm_time_ns_);
    const uint64_t queue_avg_us = get_avg_us(stats.queue_time_ns_);
    const uint64_t compute_input_avg_us = get_avg_us(stats.compute_input_time_ns_);
    const uint64_t compute_infer_avg_us = get_avg_us(stats.compute_infer_time_ns_);
    const uint64_t compute_output_avg_us = get_avg_us(stats.compute_output_time_ns_);
    const uint64_t compute_avg_us = compute_input_avg_us + compute_infer_avg_us + compute_output_avg_us;
    const uint64_t overhead =
        (cumm_avg_us > queue_avg_us + compute_avg_us) ? (cumm_avg_us - queue_avg_us - compute_avg_us) : 0;

    msg << "  Avg request latency: " << cumm_avg_us << " usec"
        << "\n"
        << "  (overhead " << overhead << " usec + "
        << "queue " << queue_avg_us << " usec + "
        << "compute input " << compute_input_avg_us << " usec + "
        << "compute infer " << compute_infer_avg_us << " usec + "
        << "compute output " << compute_output_avg_us << " usec)" << std::endl;
  }

  if (!debugName_.empty())
    edm::LogInfo(fullDebugName_) << msg.str();
}
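
// Worked example of the averaging above (hypothetical numbers): with success_count_ = 10 and
// cumm_time_ns_ = 5,000,000 ns, get_avg_us gives 5,000,000 / 1000 / 10 = 500 usec average
// request latency; the overhead term is whatever remains after subtracting the queue and
// compute averages, clamped at zero.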

TritonClient::ServerSideStats TritonClient::summarizeServerStats(const inference::ModelStatistics& start_status,
                                                                 const inference::ModelStatistics& end_status) const {
  TritonClient::ServerSideStats server_stats;

  server_stats.inference_count_ = end_status.inference_count() - start_status.inference_count();
  server_stats.execution_count_ = end_status.execution_count() - start_status.execution_count();
  server_stats.success_count_ =
      end_status.inference_stats().success().count() - start_status.inference_stats().success().count();
  server_stats.cumm_time_ns_ =
      end_status.inference_stats().success().ns() - start_status.inference_stats().success().ns();
  server_stats.queue_time_ns_ = end_status.inference_stats().queue().ns() - start_status.inference_stats().queue().ns();
  server_stats.compute_input_time_ns_ =
      end_status.inference_stats().compute_input().ns() - start_status.inference_stats().compute_input().ns();
  server_stats.compute_infer_time_ns_ =
      end_status.inference_stats().compute_infer().ns() - start_status.inference_stats().compute_infer().ns();
  server_stats.compute_output_time_ns_ =
      end_status.inference_stats().compute_output().ns() - start_status.inference_stats().compute_output().ns();

  return server_stats;
}

inference::ModelStatistics TritonClient::getServerSideStatus() const {
  if (verbose_) {
    inference::ModelStatisticsResponse resp;
    TRITON_THROW_IF_ERROR(client_->ModelInferenceStatistics(&resp, options_[0].model_name_, options_[0].model_version_),
                          "getServerSideStatus(): unable to get model statistics");
    return *(resp.model_stats().begin());
  }
  return inference::ModelStatistics{};
}

//for fillDescriptions
void TritonClient::fillPSetDescription(edm::ParameterSetDescription& iDesc) {
  edm::ParameterSetDescription descClient;
  fillBasePSetDescription(descClient);
  descClient.add<std::string>("modelName");
  descClient.add<std::string>("modelVersion", "");
  descClient.add<edm::FileInPath>("modelConfigPath");
  //server parameters should not affect the physics results
  descClient.addUntracked<std::string>("preferredServer", "");
  descClient.addUntracked<unsigned>("timeout");
  descClient.addUntracked<bool>("useSharedMemory", true);
  descClient.addUntracked<std::string>("compression", "");
  descClient.addUntracked<std::vector<std::string>>("outputs", {});
  iDesc.add<edm::ParameterSetDescription>("Client", descClient);
}
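
// Hedged configuration sketch: a cmsRun python fragment that would satisfy the description above.
// The producer label "tritonProducer", the model name, and the file path are hypothetical, and
// "verbose" comes from the base-class description filled by fillBasePSetDescription:
//   process.tritonProducer.Client = cms.PSet(
//       modelName = cms.string("resnet50_netdef"),
//       modelVersion = cms.string(""),
//       modelConfigPath = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/resnet50_netdef/config.pbtxt"),
//       preferredServer = cms.untracked.string(""),
//       timeout = cms.untracked.uint32(300),
//       useSharedMemory = cms.untracked.bool(True),
//       compression = cms.untracked.string(""),
//       outputs = cms.untracked.vstring(),
//       verbose = cms.untracked.bool(False),
//   )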