CMS 3D CMS Logo

TritonClient.cc
Go to the documentation of this file.
9 
10 #include "grpc_client.h"
11 #include "grpc_service.pb.h"
12 
13 #include <string>
14 #include <cmath>
15 #include <exception>
16 #include <sstream>
17 #include <utility>
18 #include <tuple>
19 
20 namespace tc = triton::client;
21 
22 namespace {
23  grpc_compression_algorithm getCompressionAlgo(const std::string& name) {
24  if (name.empty() or name.compare("none") == 0)
25  return grpc_compression_algorithm::GRPC_COMPRESS_NONE;
26  else if (name.compare("deflate") == 0)
27  return grpc_compression_algorithm::GRPC_COMPRESS_DEFLATE;
28  else if (name.compare("gzip") == 0)
29  return grpc_compression_algorithm::GRPC_COMPRESS_GZIP;
30  else
31  throw cms::Exception("GrpcCompression")
32  << "Unknown compression algorithm requested: " << name << " (choices: none, deflate, gzip)";
33  }
34 
35  std::vector<std::shared_ptr<tc::InferResult>> convertToShared(const std::vector<tc::InferResult*>& tmp) {
36  std::vector<std::shared_ptr<tc::InferResult>> results;
37  results.reserve(tmp.size());
38  std::transform(tmp.begin(), tmp.end(), std::back_inserter(results), [](tc::InferResult* ptr) {
39  return std::shared_ptr<tc::InferResult>(ptr);
40  });
41  return results;
42  }
43 } // namespace
44 
45 //based on https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/examples/simple_grpc_async_infer_client.cc
46 //and https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/perf_client/perf_client.cc
47 
49  : SonicClient(params, debugName, "TritonClient"),
50  batchMode_(TritonBatchMode::Rectangular),
51  manualBatchMode_(false),
52  verbose_(params.getUntrackedParameter<bool>("verbose")),
53  useSharedMemory_(params.getUntrackedParameter<bool>("useSharedMemory")),
54  compressionAlgo_(getCompressionAlgo(params.getUntrackedParameter<std::string>("compression"))) {
55  options_.emplace_back(params.getParameter<std::string>("modelName"));
56  //get appropriate server for this model
58  const auto& server =
59  ts->serverInfo(options_[0].model_name_, params.getUntrackedParameter<std::string>("preferredServer"));
60  serverType_ = server.type;
61  if (verbose_)
62  edm::LogInfo(fullDebugName_) << "Using server: " << server.url;
63  //enforce sync mode for fallback CPU server to avoid contention
64  //todo: could enforce async mode otherwise (unless mode was specified by user?)
67 
68  //connect to the server
70  tc::InferenceServerGrpcClient::Create(&client_, server.url, false, server.useSsl, server.sslOptions),
71  "TritonClient(): unable to create inference context");
72 
73  //set options
74  options_[0].model_version_ = params.getParameter<std::string>("modelVersion");
75  //convert seconds to microseconds
76  options_[0].client_timeout_ = params.getUntrackedParameter<unsigned>("timeout") * 1e6;
77 
78  //config needed for batch size
79  inference::ModelConfigResponse modelConfigResponse;
80  TRITON_THROW_IF_ERROR(client_->ModelConfig(&modelConfigResponse, options_[0].model_name_, options_[0].model_version_),
81  "TritonClient(): unable to get model config");
82  inference::ModelConfig modelConfig(modelConfigResponse.config());
83 
84  //check batch size limitations (after i/o setup)
85  //triton uses max batch size = 0 to denote a model that does not support native batching (using the outer dimension)
86  //but for models that do support batching (native or otherwise), a given event may set batch size 0 to indicate no valid input is present
87  //so set the local max to 1 and keep track of "no outer dim" case
88  maxOuterDim_ = modelConfig.max_batch_size();
91  //propagate batch size
92  setBatchSize(1);
93 
94  //get model info
95  inference::ModelMetadataResponse modelMetadata;
96  TRITON_THROW_IF_ERROR(client_->ModelMetadata(&modelMetadata, options_[0].model_name_, options_[0].model_version_),
97  "TritonClient(): unable to get model metadata");
98 
99  //get input and output (which know their sizes)
100  const auto& nicInputs = modelMetadata.inputs();
101  const auto& nicOutputs = modelMetadata.outputs();
102 
103  //report all model errors at once
104  std::stringstream msg;
105  std::string msg_str;
106 
107  //currently no use case is foreseen for a model with zero inputs or outputs
108  if (nicInputs.empty())
109  msg << "Model on server appears malformed (zero inputs)\n";
110 
111  if (nicOutputs.empty())
112  msg << "Model on server appears malformed (zero outputs)\n";
113 
114  //stop if errors
115  msg_str = msg.str();
116  if (!msg_str.empty())
117  throw cms::Exception("ModelErrors") << msg_str;
118 
119  //setup input map
120  std::stringstream io_msg;
121  if (verbose_)
122  io_msg << "Model inputs: "
123  << "\n";
124  for (const auto& nicInput : nicInputs) {
125  const auto& iname = nicInput.name();
126  auto [curr_itr, success] = input_.emplace(std::piecewise_construct,
127  std::forward_as_tuple(iname),
128  std::forward_as_tuple(iname, nicInput, this, ts->pid()));
129  auto& curr_input = curr_itr->second;
130  if (verbose_) {
131  io_msg << " " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize()
132  << " b) : " << triton_utils::printColl(curr_input.shape()) << "\n";
133  }
134  }
135 
136  //allow selecting only some outputs from server
137  const auto& v_outputs = params.getUntrackedParameter<std::vector<std::string>>("outputs");
138  std::unordered_set s_outputs(v_outputs.begin(), v_outputs.end());
139 
140  //setup output map
141  if (verbose_)
142  io_msg << "Model outputs: "
143  << "\n";
144  for (const auto& nicOutput : nicOutputs) {
145  const auto& oname = nicOutput.name();
146  if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end())
147  continue;
148  auto [curr_itr, success] = output_.emplace(std::piecewise_construct,
149  std::forward_as_tuple(oname),
150  std::forward_as_tuple(oname, nicOutput, this, ts->pid()));
151  auto& curr_output = curr_itr->second;
152  if (verbose_) {
153  io_msg << " " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize()
154  << " b) : " << triton_utils::printColl(curr_output.shape()) << "\n";
155  }
156  if (!s_outputs.empty())
157  s_outputs.erase(oname);
158  }
159 
160  //check if any requested outputs were not available
161  if (!s_outputs.empty())
162  throw cms::Exception("MissingOutput")
163  << "Some requested outputs were not available on the server: " << triton_utils::printColl(s_outputs);
164 
165  //print model info
166  std::stringstream model_msg;
167  if (verbose_) {
168  model_msg << "Model name: " << options_[0].model_name_ << "\n"
169  << "Model version: " << options_[0].model_version_ << "\n"
170  << "Model max outer dim: " << (noOuterDim_ ? 0 : maxOuterDim_) << "\n";
171  edm::LogInfo(fullDebugName_) << model_msg.str() << io_msg.str();
172  }
173 }
174 
176  //by default: members of this class destroyed before members of base class
177  //in shared memory case, TritonMemResource (member of TritonData) unregisters from client_ in its destructor
178  //but input/output objects are member of base class, so destroyed after client_ (member of this class)
179  //therefore, clear the maps here
180  input_.clear();
181  output_.clear();
182 }
183 
185  unsigned oldBatchSize = batchSize();
187  manualBatchMode_ = true;
188  //this allows calling setBatchSize() and setBatchMode() in either order consistently to change back and forth
189  //includes handling of change from ragged to rectangular if multiple entries already created
190  setBatchSize(oldBatchSize);
191 }
192 
195  manualBatchMode_ = false;
196 }
197 
198 unsigned TritonClient::nEntries() const { return !input_.empty() ? input_.begin()->second.entries_.size() : 0; }
199 
201 
202 bool TritonClient::setBatchSize(unsigned bsize) {
204  if (bsize > maxOuterDim_) {
205  edm::LogWarning(fullDebugName_) << "Requested batch size " << bsize << " exceeds server-specified max batch size "
206  << maxOuterDim_ << ". Batch size will remain as " << outerDim_;
207  return false;
208  } else {
209  outerDim_ = bsize;
210  //take min to allow resizing to 0
212  return true;
213  }
214  } else {
215  resizeEntries(bsize);
216  outerDim_ = 1;
217  return true;
218  }
219 }
220 
222  if (entry > nEntries())
223  //addEntry(entry) extends the vector to size entry+1
224  addEntry(entry - 1);
225  else if (entry < nEntries()) {
226  for (auto& element : input_) {
227  element.second.entries_.resize(entry);
228  }
229  for (auto& element : output_) {
230  element.second.entries_.resize(entry);
231  }
232  }
233 }
234 
236  for (auto& element : input_) {
237  element.second.addEntryImpl(entry);
238  }
239  for (auto& element : output_) {
240  element.second.addEntryImpl(entry);
241  }
242  if (entry > 0) {
244  outerDim_ = 1;
245  }
246 }
247 
249  if (!manualBatchMode_)
251  for (auto& element : input_) {
252  element.second.reset();
253  }
254  for (auto& element : output_) {
255  element.second.reset();
256  }
257 }
258 
259 template <typename F>
261  //caught exceptions will be propagated to edm::WaitingTaskWithArenaHolder
262  CMS_SA_ALLOW try {
263  call();
264  return true;
265  }
266  //TritonExceptions are intended/expected to be recoverable, i.e. retries should be allowed
267  catch (TritonException& e) {
268  e.convertToWarning();
269  finish(false);
270  return false;
271  }
272  //other exceptions are not: execution should stop if they are encountered
273  catch (...) {
274  finish(false, std::current_exception());
275  return false;
276  }
277 }
278 
279 void TritonClient::getResults(const std::vector<std::shared_ptr<tc::InferResult>>& results) {
280  for (unsigned i = 0; i < results.size(); ++i) {
281  const auto& result = results[i];
282  for (auto& [oname, output] : output_) {
283  //set shape here before output becomes const
284  if (output.variableDims()) {
285  std::vector<int64_t> tmp_shape;
286  TRITON_THROW_IF_ERROR(result->Shape(oname, &tmp_shape),
287  "getResults(): unable to get output shape for " + oname);
288  if (!noOuterDim_)
289  tmp_shape.erase(tmp_shape.begin());
290  output.setShape(tmp_shape, i);
291  }
292  //extend lifetime
293  output.setResult(result, i);
294  //compute size after getting all result entries
295  if (i == results.size() - 1)
296  output.computeSizes();
297  }
298  }
299 }
300 
301 //default case for sync and pseudo async
303  //in case there is nothing to process
304  if (batchSize() == 0) {
305  //call getResults on an empty vector
306  std::vector<std::shared_ptr<tc::InferResult>> empty_results;
307  getResults(empty_results);
308  finish(true);
309  return;
310  }
311 
312  //set up input pointers for triton (generalized for multi-request ragged batching case)
313  //one vector<InferInput*> per request
314  unsigned nEntriesVal = nEntries();
315  std::vector<std::vector<triton::client::InferInput*>> inputsTriton(nEntriesVal);
316  for (auto& inputTriton : inputsTriton) {
317  inputTriton.reserve(input_.size());
318  }
319  for (auto& [iname, input] : input_) {
320  for (unsigned i = 0; i < nEntriesVal; ++i) {
321  inputsTriton[i].push_back(input.data(i));
322  }
323  }
324 
325  //set up output pointers similarly
326  std::vector<std::vector<const triton::client::InferRequestedOutput*>> outputsTriton(nEntriesVal);
327  for (auto& outputTriton : outputsTriton) {
328  outputTriton.reserve(output_.size());
329  }
330  for (auto& [oname, output] : output_) {
331  for (unsigned i = 0; i < nEntriesVal; ++i) {
332  outputsTriton[i].push_back(output.data(i));
333  }
334  }
335 
336  //set up shared memory for output
337  auto success = handle_exception([&]() {
338  for (auto& element : output_) {
339  element.second.prepare();
340  }
341  });
342  if (!success)
343  return;
344 
345  // Get the status of the server prior to the request being made.
346  inference::ModelStatistics start_status;
347  success = handle_exception([&]() {
348  if (verbose())
349  start_status = getServerSideStatus();
350  });
351  if (!success)
352  return;
353 
354  if (mode_ == SonicMode::Async) {
355  //non-blocking call
356  success = handle_exception([&]() {
358  client_->AsyncInferMulti(
359  [start_status, this](std::vector<tc::InferResult*> resultsTmp) {
360  //immediately convert to shared_ptr
361  const auto& results = convertToShared(resultsTmp);
362  //check results
363  for (auto ptr : results) {
364  auto success = handle_exception(
365  [&]() { TRITON_THROW_IF_ERROR(ptr->RequestStatus(), "evaluate(): unable to get result(s)"); });
366  if (!success)
367  return;
368  }
369 
370  if (verbose()) {
371  inference::ModelStatistics end_status;
372  auto success = handle_exception([&]() { end_status = getServerSideStatus(); });
373  if (!success)
374  return;
375 
376  const auto& stats = summarizeServerStats(start_status, end_status);
378  }
379 
380  //check result
381  auto success = handle_exception([&]() { getResults(results); });
382  if (!success)
383  return;
384 
385  //finish
386  finish(true);
387  },
388  options_,
389  inputsTriton,
390  outputsTriton,
391  headers_,
393  "evaluate(): unable to launch async run");
394  });
395  if (!success)
396  return;
397  } else {
398  //blocking call
399  std::vector<tc::InferResult*> resultsTmp;
400  success = handle_exception([&]() {
402  client_->InferMulti(&resultsTmp, options_, inputsTriton, outputsTriton, headers_, compressionAlgo_),
403  "evaluate(): unable to run and/or get result");
404  });
405  //immediately convert to shared_ptr
406  const auto& results = convertToShared(resultsTmp);
407  if (!success)
408  return;
409 
410  if (verbose()) {
411  inference::ModelStatistics end_status;
412  success = handle_exception([&]() { end_status = getServerSideStatus(); });
413  if (!success)
414  return;
415 
416  const auto& stats = summarizeServerStats(start_status, end_status);
417  reportServerSideStats(stats);
418  }
419 
420  success = handle_exception([&]() { getResults(results); });
421  if (!success)
422  return;
423 
424  finish(true);
425  }
426 }
427 
429  std::stringstream msg;
430 
431  // https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/perf_client/inference_profiler.cc
432  const uint64_t count = stats.success_count_;
433  msg << " Inference count: " << stats.inference_count_ << "\n";
434  msg << " Execution count: " << stats.execution_count_ << "\n";
435  msg << " Successful request count: " << count << "\n";
436 
437  if (count > 0) {
438  auto get_avg_us = [count](uint64_t tval) {
439  constexpr uint64_t us_to_ns = 1000;
440  return tval / us_to_ns / count;
441  };
442 
443  const uint64_t cumm_avg_us = get_avg_us(stats.cumm_time_ns_);
444  const uint64_t queue_avg_us = get_avg_us(stats.queue_time_ns_);
445  const uint64_t compute_input_avg_us = get_avg_us(stats.compute_input_time_ns_);
446  const uint64_t compute_infer_avg_us = get_avg_us(stats.compute_infer_time_ns_);
447  const uint64_t compute_output_avg_us = get_avg_us(stats.compute_output_time_ns_);
448  const uint64_t compute_avg_us = compute_input_avg_us + compute_infer_avg_us + compute_output_avg_us;
449  const uint64_t overhead =
450  (cumm_avg_us > queue_avg_us + compute_avg_us) ? (cumm_avg_us - queue_avg_us - compute_avg_us) : 0;
451 
452  msg << " Avg request latency: " << cumm_avg_us << " usec"
453  << "\n"
454  << " (overhead " << overhead << " usec + "
455  << "queue " << queue_avg_us << " usec + "
456  << "compute input " << compute_input_avg_us << " usec + "
457  << "compute infer " << compute_infer_avg_us << " usec + "
458  << "compute output " << compute_output_avg_us << " usec)" << std::endl;
459  }
460 
461  if (!debugName_.empty())
462  edm::LogInfo(fullDebugName_) << msg.str();
463 }
464 
465 TritonClient::ServerSideStats TritonClient::summarizeServerStats(const inference::ModelStatistics& start_status,
466  const inference::ModelStatistics& end_status) const {
467  TritonClient::ServerSideStats server_stats;
468 
469  server_stats.inference_count_ = end_status.inference_count() - start_status.inference_count();
470  server_stats.execution_count_ = end_status.execution_count() - start_status.execution_count();
471  server_stats.success_count_ =
472  end_status.inference_stats().success().count() - start_status.inference_stats().success().count();
473  server_stats.cumm_time_ns_ =
474  end_status.inference_stats().success().ns() - start_status.inference_stats().success().ns();
475  server_stats.queue_time_ns_ = end_status.inference_stats().queue().ns() - start_status.inference_stats().queue().ns();
476  server_stats.compute_input_time_ns_ =
477  end_status.inference_stats().compute_input().ns() - start_status.inference_stats().compute_input().ns();
478  server_stats.compute_infer_time_ns_ =
479  end_status.inference_stats().compute_infer().ns() - start_status.inference_stats().compute_infer().ns();
480  server_stats.compute_output_time_ns_ =
481  end_status.inference_stats().compute_output().ns() - start_status.inference_stats().compute_output().ns();
482 
483  return server_stats;
484 }
485 
486 inference::ModelStatistics TritonClient::getServerSideStatus() const {
487  if (verbose_) {
488  inference::ModelStatisticsResponse resp;
489  TRITON_THROW_IF_ERROR(client_->ModelInferenceStatistics(&resp, options_[0].model_name_, options_[0].model_version_),
490  "getServerSideStatus(): unable to get model statistics");
491  return *(resp.model_stats().begin());
492  }
493  return inference::ModelStatistics{};
494 }
495 
496 //for fillDescriptions
498  edm::ParameterSetDescription descClient;
499  fillBasePSetDescription(descClient);
500  descClient.add<std::string>("modelName");
501  descClient.add<std::string>("modelVersion", "");
502  descClient.add<edm::FileInPath>("modelConfigPath");
503  //server parameters should not affect the physics results
504  descClient.addUntracked<std::string>("preferredServer", "");
505  descClient.addUntracked<unsigned>("timeout");
506  descClient.addUntracked<bool>("useSharedMemory", true);
507  descClient.addUntracked<std::string>("compression", "");
508  descClient.addUntracked<std::vector<std::string>>("outputs", {});
509  iDesc.add<edm::ParameterSetDescription>("Client", descClient);
510 }
bool verbose() const
Definition: TritonClient.h:43
void getResults(const std::vector< std::shared_ptr< triton::client::InferResult >> &results)
const std::string & pid() const
#define CMS_SA_ALLOW
unsigned maxOuterDim_
Definition: TritonClient.h:72
bool setBatchSize(unsigned bsize)
bool noOuterDim_
Definition: TritonClient.h:74
ParameterDescriptionBase * addUntracked(U const &iLabel, T const &value)
~TritonClient() override
void addEntry(unsigned entry)
#define TRITON_THROW_IF_ERROR(X, MSG)
Definition: triton_utils.h:75
bool verbose
bool manualBatchMode_
Definition: TritonClient.h:77
TritonBatchMode batchMode() const
Definition: TritonClient.h:42
void setMode(SonicMode mode)
std::unique_ptr< triton::client::InferenceServerGrpcClient > client_
Definition: TritonClient.h:84
TritonClient(const edm::ParameterSet &params, const std::string &debugName)
Definition: TritonClient.cc:48
std::string debugName_
void finish(bool success, std::exception_ptr eptr=std::exception_ptr{})
ServerSideStats summarizeServerStats(const inference::ModelStatistics &start_status, const inference::ModelStatistics &end_status) const
TritonBatchMode
Definition: TritonClient.h:19
void resetBatchMode()
TritonServerType serverType_
Definition: TritonClient.h:80
The Signals That Services Can Subscribe To This is based on ActivityRegistry and is current per Services can connect to the signals distributed by the ActivityRegistry in order to monitor the activity of the application Each possible callback has some defined which we here list in angle e< void, edm::EventID const &, edm::Timestamp const & > We also list in braces which AR_WATCH_USING_METHOD_ is used for those or
Definition: Activities.doc:12
bool handle_exception(F &&call)
grpc_compression_algorithm compressionAlgo_
Definition: TritonClient.h:81
ParameterDescriptionBase * add(U const &iLabel, T const &value)
static void fillBasePSetDescription(edm::ParameterSetDescription &desc, bool allowRetry=true)
void resizeEntries(unsigned entry)
inference::ModelStatistics getServerSideStatus() const
Log< level::Info, false > LogInfo
triton::client::Headers headers_
Definition: TritonClient.h:82
unsigned nEntries() const
unsigned long long uint64_t
Definition: Time.h:13
tuple msg
Definition: mps_check.py:286
unsigned outerDim_
Definition: TritonClient.h:73
unsigned batchSize() const
void evaluate() override
void setBatchMode(TritonBatchMode batchMode)
Server serverInfo(const std::string &model, const std::string &preferred="") const
void reportServerSideStats(const ServerSideStats &stats) const
void reset() override
std::string fullDebugName_
results
Definition: mysort.py:8
static void fillPSetDescription(edm::ParameterSetDescription &iDesc)
Log< level::Warning, false > LogWarning
static uInt32 F(BLOWFISH_CTX *ctx, uInt32 x)
Definition: blowfish.cc:163
TritonBatchMode batchMode_
Definition: TritonClient.h:76
std::string printColl(const C &coll, const std::string &delim=", ")
Definition: triton_utils.cc:9
tmp
align.sh
Definition: createJobs.py:716
std::vector< triton::client::InferOptions > options_
Definition: TritonClient.h:86
unsigned transform(const HcalDetId &id, unsigned transformCode)