CMS 3D CMS Logo

TritonClient.cc
Go to the documentation of this file.
9 
10 #include "grpc_client.h"
11 #include "grpc_service.pb.h"
12 
13 #include <string>
14 #include <cmath>
15 #include <exception>
16 #include <sstream>
17 #include <utility>
18 #include <tuple>
19 
20 namespace tc = triton::client;
21 
22 namespace {
23  grpc_compression_algorithm getCompressionAlgo(const std::string& name) {
24  if (name.empty() or name.compare("none") == 0)
25  return grpc_compression_algorithm::GRPC_COMPRESS_NONE;
26  else if (name.compare("deflate") == 0)
27  return grpc_compression_algorithm::GRPC_COMPRESS_DEFLATE;
28  else if (name.compare("gzip") == 0)
29  return grpc_compression_algorithm::GRPC_COMPRESS_GZIP;
30  else
31  throw cms::Exception("GrpcCompression")
32  << "Unknown compression algorithm requested: " << name << " (choices: none, deflate, gzip)";
33  }
34 } // namespace
35 
36 //based on https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/examples/simple_grpc_async_infer_client.cc
37 //and https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/perf_client/perf_client.cc
38 
//TritonClient constructor: member initializer list and body
//NOTE(review): this listing jumps from source line 38 to 40, so the line holding the constructor
//signature is not shown; source lines 46, 54-55, 58 and 79 are likewise absent below (see notes)
//summary of what the visible code does:
// - reads the "verbose", "useSharedMemory", "compression" (untracked) and "modelName" parameters
// - asks ts for the server to use for this model and creates the gRPC client for its url
// - fetches model config and metadata, then fills input_/output_ and inputsTriton_/outputsTriton_
// - throws cms::Exception("ModelErrors") if the model reports zero inputs or zero outputs
// - throws cms::Exception("MissingOutput") if a requested output name is not provided by the model
40  : SonicClient(params, debugName, "TritonClient"),
41  verbose_(params.getUntrackedParameter<bool>("verbose")),
42  useSharedMemory_(params.getUntrackedParameter<bool>("useSharedMemory")),
43  compressionAlgo_(getCompressionAlgo(params.getUntrackedParameter<std::string>("compression"))),
44  options_(params.getParameter<std::string>("modelName")) {
45  //get appropriate server for this model
//NOTE(review): ts is used below but its declaration (source line 46) is missing from this listing
47  const auto& server =
48  ts->serverInfo(options_.model_name_, params.getUntrackedParameter<std::string>("preferredServer"));
49  serverType_ = server.type;
50  if (verbose_)
51  edm::LogInfo(fullDebugName_) << "Using server: " << server.url;
52  //enforce sync mode for fallback CPU server to avoid contention
53  //todo: could enforce async mode otherwise (unless mode was specified by user?)
//NOTE(review): the statement implementing the comment above (source lines 54-55) is missing here
56 
57  //connect to the server
//NOTE(review): source line 58 is missing; the two lines below look like the arguments of a
//TRITON_THROW_IF_ERROR(...) call, as used elsewhere in this function - confirm against the real file
59  tc::InferenceServerGrpcClient::Create(&client_, server.url, false, server.useSsl, server.sslOptions),
60  "TritonClient(): unable to create inference context");
61 
62  //set options
63  options_.model_version_ = params.getParameter<std::string>("modelVersion");
64  //convert seconds to microseconds
//the product with 1e6 is computed in double before being stored in client_timeout_
65  options_.client_timeout_ = params.getUntrackedParameter<unsigned>("timeout") * 1e6;
66 
67  //config needed for batch size
68  inference::ModelConfigResponse modelConfigResponse;
69  TRITON_THROW_IF_ERROR(client_->ModelConfig(&modelConfigResponse, options_.model_name_, options_.model_version_),
70  "TritonClient(): unable to get model config");
71  inference::ModelConfig modelConfig(modelConfigResponse.config());
72 
73  //check batch size limitations (after i/o setup)
74  //triton uses max batch size = 0 to denote a model that does not support batching
75  //but for models that do support batching, a given event may set batch size 0 to indicate no valid input is present
76  //so set the local max to 1 and keep track of "no batch" case
77  maxBatchSize_ = modelConfig.max_batch_size();
78  noBatch_ = maxBatchSize_ == 0;
//NOTE(review): source line 79 is missing; presumably it raises maxBatchSize_ to at least 1 as the
//comment above describes - confirm against the real file
80 
81  //get model info
82  inference::ModelMetadataResponse modelMetadata;
83  TRITON_THROW_IF_ERROR(client_->ModelMetadata(&modelMetadata, options_.model_name_, options_.model_version_),
84  "TritonClient(): unable to get model metadata");
85 
86  //get input and output (which know their sizes)
87  const auto& nicInputs = modelMetadata.inputs();
88  const auto& nicOutputs = modelMetadata.outputs();
89 
90  //report all model errors at once
91  std::stringstream msg;
92  std::string msg_str;
93 
94  //currently no use case is foreseen for a model with zero inputs or outputs
95  if (nicInputs.empty())
96  msg << "Model on server appears malformed (zero inputs)\n";
97 
98  if (nicOutputs.empty())
99  msg << "Model on server appears malformed (zero outputs)\n";
100 
101  //stop if errors
102  msg_str = msg.str();
103  if (!msg_str.empty())
104  throw cms::Exception("ModelErrors") << msg_str;
105 
106  //setup input map
//io_msg collects the per-input/per-output description that is printed at the end when verbose_
107  std::stringstream io_msg;
108  if (verbose_)
109  io_msg << "Model inputs: "
110  << "\n";
111  inputsTriton_.reserve(nicInputs.size());
112  for (const auto& nicInput : nicInputs) {
113  const auto& iname = nicInput.name();
//construct the entry in place, keyed by name; the success flag of emplace is not checked
114  auto [curr_itr, success] = input_.emplace(std::piecewise_construct,
115  std::forward_as_tuple(iname),
116  std::forward_as_tuple(iname, nicInput, this, ts->pid()));
117  auto& curr_input = curr_itr->second;
//inputsTriton_ holds what data() returns for each input; it is later passed to the inference calls
118  inputsTriton_.push_back(curr_input.data());
119  if (verbose_) {
120  io_msg << " " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize()
121  << " b) : " << triton_utils::printColl(curr_input.shape()) << "\n";
122  }
123  }
124 
125  //allow selecting only some outputs from server
//an empty "outputs" list means every output reported by the model is kept
126  const auto& v_outputs = params.getUntrackedParameter<std::vector<std::string>>("outputs");
127  std::unordered_set s_outputs(v_outputs.begin(), v_outputs.end());
128 
129  //setup output map
130  if (verbose_)
131  io_msg << "Model outputs: "
132  << "\n";
133  outputsTriton_.reserve(nicOutputs.size());
134  for (const auto& nicOutput : nicOutputs) {
135  const auto& oname = nicOutput.name();
//skip outputs that were not requested (only when a selection was given)
136  if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end())
137  continue;
138  auto [curr_itr, success] = output_.emplace(std::piecewise_construct,
139  std::forward_as_tuple(oname),
140  std::forward_as_tuple(oname, nicOutput, this, ts->pid()));
141  auto& curr_output = curr_itr->second;
142  outputsTriton_.push_back(curr_output.data());
143  if (verbose_) {
144  io_msg << " " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize()
145  << " b) : " << triton_utils::printColl(curr_output.shape()) << "\n";
146  }
//tick off each requested name that was found; whatever remains afterwards was not available
147  if (!s_outputs.empty())
148  s_outputs.erase(oname);
149  }
150 
151  //check if any requested outputs were not available
152  if (!s_outputs.empty())
153  throw cms::Exception("MissingOutput")
154  << "Some requested outputs were not available on the server: " << triton_utils::printColl(s_outputs);
155 
156  //propagate batch size to inputs and outputs
157  setBatchSize(1);
158 
159  //print model info
160  std::stringstream model_msg;
161  if (verbose_) {
//a model without batching support is reported with max batch size 0
162  model_msg << "Model name: " << options_.model_name_ << "\n"
163  << "Model version: " << options_.model_version_ << "\n"
164  << "Model max batch size: " << (noBatch_ ? 0 : maxBatchSize_) << "\n";
165  edm::LogInfo(fullDebugName_) << model_msg.str() << io_msg.str();
166  }
167 }
168 
170  //by default: members of this class destroyed before members of base class
171  //in shared memory case, TritonMemResource (member of TritonData) unregisters from client_ in its destructor
172  //but input/output objects are member of base class, so destroyed after client_ (member of this class)
173  //therefore, clear the maps here
174  input_.clear();
175  output_.clear();
176 }
177 
178 bool TritonClient::setBatchSize(unsigned bsize) {
179  if (bsize > maxBatchSize_) {
180  edm::LogWarning(fullDebugName_) << "Requested batch size " << bsize << " exceeds server-specified max batch size "
181  << maxBatchSize_ << ". Batch size will remain as" << batchSize_;
182  return false;
183  } else {
184  batchSize_ = bsize;
185  //set for input and output
186  for (auto& element : input_) {
187  element.second.setBatchSize(bsize);
188  }
189  for (auto& element : output_) {
190  element.second.setBatchSize(bsize);
191  }
192  return true;
193  }
194 }
195 
197  for (auto& element : input_) {
198  element.second.reset();
199  }
200  for (auto& element : output_) {
201  element.second.reset();
202  }
203 }
204 
205 template <typename F>
207  //caught exceptions will be propagated to edm::WaitingTaskWithArenaHolder
208  CMS_SA_ALLOW try {
209  call();
210  return true;
211  }
212  //TritonExceptions are intended/expected to be recoverable, i.e. retries should be allowed
213  catch (TritonException& e) {
214  e.convertToWarning();
215  finish(false);
216  return false;
217  }
218  //other exceptions are not: execution should stop if they are encountered
219  catch (...) {
220  finish(false, std::current_exception());
221  return false;
222  }
223 }
224 
225 void TritonClient::getResults(std::shared_ptr<tc::InferResult> results) {
226  for (auto& [oname, output] : output_) {
227  //set shape here before output becomes const
228  if (output.variableDims()) {
229  std::vector<int64_t> tmp_shape;
230  TRITON_THROW_IF_ERROR(results->Shape(oname, &tmp_shape), "getResults(): unable to get output shape for " + oname);
231  if (!noBatch_)
232  tmp_shape.erase(tmp_shape.begin());
233  output.setShape(tmp_shape);
234  output.computeSizes();
235  }
236  //extend lifetime
237  output.setResult(results);
238  }
239 }
240 
241 //default case for sync and pseudo async
243  //in case there is nothing to process
244  if (batchSize_ == 0) {
245  finish(true);
246  return;
247  }
248 
249  //set up shared memory for output
250  auto success = handle_exception([&]() {
251  for (auto& element : output_) {
252  element.second.prepare();
253  }
254  });
255  if (!success)
256  return;
257 
258  // Get the status of the server prior to the request being made.
259  inference::ModelStatistics start_status;
260  success = handle_exception([&]() {
261  if (verbose())
262  start_status = getServerSideStatus();
263  });
264  if (!success)
265  return;
266 
267  if (mode_ == SonicMode::Async) {
268  //non-blocking call
269  success = handle_exception([&]() {
271  client_->AsyncInfer(
272  [start_status, this](tc::InferResult* results) {
273  //get results
274  std::shared_ptr<tc::InferResult> results_ptr(results);
275  auto success = handle_exception(
276  [&]() { TRITON_THROW_IF_ERROR(results_ptr->RequestStatus(), "evaluate(): unable to get result"); });
277  if (!success)
278  return;
279 
280  if (verbose()) {
281  inference::ModelStatistics end_status;
282  success = handle_exception([&]() { end_status = getServerSideStatus(); });
283  if (!success)
284  return;
285 
286  const auto& stats = summarizeServerStats(start_status, end_status);
288  }
289 
290  //check result
291  success = handle_exception([&]() { getResults(results_ptr); });
292  if (!success)
293  return;
294 
295  //finish
296  finish(true);
297  },
298  options_,
301  headers_,
303  "evaluate(): unable to launch async run");
304  });
305  if (!success)
306  return;
307  } else {
308  //blocking call
309  tc::InferResult* results;
310  success = handle_exception([&]() {
312  client_->Infer(&results, options_, inputsTriton_, outputsTriton_, headers_, compressionAlgo_),
313  "evaluate(): unable to run and/or get result");
314  });
315  if (!success)
316  return;
317 
318  if (verbose()) {
319  inference::ModelStatistics end_status;
320  success = handle_exception([&]() { end_status = getServerSideStatus(); });
321  if (!success)
322  return;
323 
324  const auto& stats = summarizeServerStats(start_status, end_status);
325  reportServerSideStats(stats);
326  }
327 
328  std::shared_ptr<tc::InferResult> results_ptr(results);
329  success = handle_exception([&]() { getResults(results_ptr); });
330  if (!success)
331  return;
332 
333  finish(true);
334  }
335 }
336 
338  std::stringstream msg;
339 
340  // https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/perf_client/inference_profiler.cc
341  const uint64_t count = stats.success_count_;
342  msg << " Inference count: " << stats.inference_count_ << "\n";
343  msg << " Execution count: " << stats.execution_count_ << "\n";
344  msg << " Successful request count: " << count << "\n";
345 
346  if (count > 0) {
347  auto get_avg_us = [count](uint64_t tval) {
348  constexpr uint64_t us_to_ns = 1000;
349  return tval / us_to_ns / count;
350  };
351 
352  const uint64_t cumm_avg_us = get_avg_us(stats.cumm_time_ns_);
353  const uint64_t queue_avg_us = get_avg_us(stats.queue_time_ns_);
354  const uint64_t compute_input_avg_us = get_avg_us(stats.compute_input_time_ns_);
355  const uint64_t compute_infer_avg_us = get_avg_us(stats.compute_infer_time_ns_);
356  const uint64_t compute_output_avg_us = get_avg_us(stats.compute_output_time_ns_);
357  const uint64_t compute_avg_us = compute_input_avg_us + compute_infer_avg_us + compute_output_avg_us;
358  const uint64_t overhead =
359  (cumm_avg_us > queue_avg_us + compute_avg_us) ? (cumm_avg_us - queue_avg_us - compute_avg_us) : 0;
360 
361  msg << " Avg request latency: " << cumm_avg_us << " usec"
362  << "\n"
363  << " (overhead " << overhead << " usec + "
364  << "queue " << queue_avg_us << " usec + "
365  << "compute input " << compute_input_avg_us << " usec + "
366  << "compute infer " << compute_infer_avg_us << " usec + "
367  << "compute output " << compute_output_avg_us << " usec)" << std::endl;
368  }
369 
370  if (!debugName_.empty())
371  edm::LogInfo(fullDebugName_) << msg.str();
372 }
373 
374 TritonClient::ServerSideStats TritonClient::summarizeServerStats(const inference::ModelStatistics& start_status,
375  const inference::ModelStatistics& end_status) const {
376  TritonClient::ServerSideStats server_stats;
377 
378  server_stats.inference_count_ = end_status.inference_count() - start_status.inference_count();
379  server_stats.execution_count_ = end_status.execution_count() - start_status.execution_count();
380  server_stats.success_count_ =
381  end_status.inference_stats().success().count() - start_status.inference_stats().success().count();
382  server_stats.cumm_time_ns_ =
383  end_status.inference_stats().success().ns() - start_status.inference_stats().success().ns();
384  server_stats.queue_time_ns_ = end_status.inference_stats().queue().ns() - start_status.inference_stats().queue().ns();
385  server_stats.compute_input_time_ns_ =
386  end_status.inference_stats().compute_input().ns() - start_status.inference_stats().compute_input().ns();
387  server_stats.compute_infer_time_ns_ =
388  end_status.inference_stats().compute_infer().ns() - start_status.inference_stats().compute_infer().ns();
389  server_stats.compute_output_time_ns_ =
390  end_status.inference_stats().compute_output().ns() - start_status.inference_stats().compute_output().ns();
391 
392  return server_stats;
393 }
394 
395 inference::ModelStatistics TritonClient::getServerSideStatus() const {
396  if (verbose_) {
397  inference::ModelStatisticsResponse resp;
398  TRITON_THROW_IF_ERROR(client_->ModelInferenceStatistics(&resp, options_.model_name_, options_.model_version_),
399  "getServerSideStatus(): unable to get model statistics");
400  return *(resp.model_stats().begin());
401  }
402  return inference::ModelStatistics{};
403 }
404 
405 //for fillDescriptions
407  edm::ParameterSetDescription descClient;
408  fillBasePSetDescription(descClient);
409  descClient.add<std::string>("modelName");
410  descClient.add<std::string>("modelVersion", "");
411  descClient.add<edm::FileInPath>("modelConfigPath");
412  //server parameters should not affect the physics results
413  descClient.addUntracked<std::string>("preferredServer", "");
414  descClient.addUntracked<unsigned>("timeout");
415  descClient.addUntracked<bool>("useSharedMemory", true);
416  descClient.addUntracked<std::string>("compression", "");
417  descClient.addUntracked<std::vector<std::string>>("outputs", {});
418  iDesc.add<edm::ParameterSetDescription>("Client", descClient);
419 }
bool verbose() const
Definition: TritonClient.h:40
const std::string & pid() const
#define CMS_SA_ALLOW
bool setBatchSize(unsigned bsize)
ParameterDescriptionBase * addUntracked(U const &iLabel, T const &value)
~TritonClient() override
#define TRITON_THROW_IF_ERROR(X, MSG)
Definition: triton_utils.h:75
bool verbose
std::vector< const triton::client::InferRequestedOutput * > outputsTriton_
Definition: TritonClient.h:76
void setMode(SonicMode mode)
std::unique_ptr< triton::client::InferenceServerGrpcClient > client_
Definition: TritonClient.h:78
unsigned maxBatchSize_
Definition: TritonClient.h:65
TritonClient(const edm::ParameterSet &params, const std::string &debugName)
Definition: TritonClient.cc:39
std::string debugName_
void finish(bool success, std::exception_ptr eptr=std::exception_ptr{})
ServerSideStats summarizeServerStats(const inference::ModelStatistics &start_status, const inference::ModelStatistics &end_status) const
TritonServerType serverType_
Definition: TritonClient.h:70
The Signals That Services Can Subscribe To This is based on ActivityRegistry and is current per Services can connect to the signals distributed by the ActivityRegistry in order to monitor the activity of the application Each possible callback has some defined which we here list in angle e< void, edm::EventID const &, edm::Timestamp const & > We also list in braces which AR_WATCH_USING_METHOD_ is used for those or
Definition: Activities.doc:12
bool handle_exception(F &&call)
unsigned batchSize_
Definition: TritonClient.h:66
grpc_compression_algorithm compressionAlgo_
Definition: TritonClient.h:71
ParameterDescriptionBase * add(U const &iLabel, T const &value)
static void fillBasePSetDescription(edm::ParameterSetDescription &desc, bool allowRetry=true)
inference::ModelStatistics getServerSideStatus() const
Log< level::Info, false > LogInfo
triton::client::Headers headers_
Definition: TritonClient.h:72
unsigned long long uint64_t
Definition: Time.h:13
std::vector< triton::client::InferInput * > inputsTriton_
Definition: TritonClient.h:75
tuple msg
Definition: mps_check.py:286
triton::client::InferOptions options_
Definition: TritonClient.h:80
void evaluate() override
Server serverInfo(const std::string &model, const std::string &preferred="") const
void getResults(std::shared_ptr< triton::client::InferResult > results)
void reportServerSideStats(const ServerSideStats &stats) const
void reset() override
std::string fullDebugName_
static void fillPSetDescription(edm::ParameterSetDescription &iDesc)
Log< level::Warning, false > LogWarning
static uInt32 F(BLOWFISH_CTX *ctx, uInt32 x)
Definition: blowfish.cc:163
std::string printColl(const C &coll, const std::string &delim=", ")
Definition: triton_utils.cc:9