TritonClient.cc
#include "HeterogeneousCore/SonicTriton/interface/TritonClient.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonException.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonService.h"
#include "HeterogeneousCore/SonicTriton/interface/triton_utils.h"
#include "FWCore/MessageLogger/interface/MessageLogger.h"
#include "FWCore/ParameterSet/interface/FileInPath.h"
#include "FWCore/ServiceRegistry/interface/Service.h"
#include "FWCore/Utilities/interface/Exception.h"

#include "grpc_client.h"
#include "grpc_service.pb.h"

#include <string>
#include <cmath>
#include <exception>
#include <sstream>
#include <utility>
#include <tuple>

namespace tc = triton::client;

namespace {
  grpc_compression_algorithm getCompressionAlgo(const std::string& name) {
    if (name.empty() or name.compare("none") == 0)
      return grpc_compression_algorithm::GRPC_COMPRESS_NONE;
    else if (name.compare("deflate") == 0)
      return grpc_compression_algorithm::GRPC_COMPRESS_DEFLATE;
    else if (name.compare("gzip") == 0)
      return grpc_compression_algorithm::GRPC_COMPRESS_GZIP;
    else
      throw cms::Exception("GrpcCompression")
          << "Unknown compression algorithm requested: " << name << " (choices: none, deflate, gzip)";
  }
}  // namespace
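
//usage sketch (illustration only, not part of the original file): expected mappings of the
//helper above for the "compression" client parameter
//  getCompressionAlgo("")        -> GRPC_COMPRESS_NONE
//  getCompressionAlgo("deflate") -> GRPC_COMPRESS_DEFLATE
//  getCompressionAlgo("gzip")    -> GRPC_COMPRESS_GZIP
//  getCompressionAlgo("lz4")     -> throws cms::Exception("GrpcCompression")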

//based on https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/examples/simple_grpc_async_infer_client.cc
//and https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/perf_client/perf_client.cc

TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& debugName)
    : SonicClient(params, debugName, "TritonClient"),
      verbose_(params.getUntrackedParameter<bool>("verbose")),
      useSharedMemory_(params.getUntrackedParameter<bool>("useSharedMemory")),
      compressionAlgo_(getCompressionAlgo(params.getUntrackedParameter<std::string>("compression"))),
      options_(params.getParameter<std::string>("modelName")) {
  //get appropriate server for this model
  edm::Service<TritonService> ts;
  const auto& server =
      ts->serverInfo(options_.model_name_, params.getUntrackedParameter<std::string>("preferredServer"));
  serverType_ = server.type;
  if (verbose_)
    edm::LogInfo(fullDebugName_) << "Using server: " << server.url;
  //enforce sync mode for fallback CPU server to avoid contention
  //todo: could enforce async mode otherwise (unless mode was specified by user?)
  if (serverType_ == TritonServerType::LocalCPU)
    setMode(SonicMode::Sync);

  //connect to the server
  triton_utils::throwIfError(
      tc::InferenceServerGrpcClient::Create(&client_, server.url, false, server.useSsl, server.sslOptions),
      "TritonClient(): unable to create inference context");

  //set options
  options_.model_version_ = params.getParameter<std::string>("modelVersion");
  //convert seconds to microseconds
  options_.client_timeout_ = params.getUntrackedParameter<unsigned>("timeout") * 1e6;

  //config needed for batch size
  inference::ModelConfigResponse modelConfigResponse;
  triton_utils::throwIfError(client_->ModelConfig(&modelConfigResponse, options_.model_name_, options_.model_version_),
                             "TritonClient(): unable to get model config");
  inference::ModelConfig modelConfig(modelConfigResponse.config());

  //check batch size limitations (after i/o setup)
  //triton uses max batch size = 0 to denote a model that does not support batching
  //but for models that do support batching, a given event may set batch size 0 to indicate no valid input is present
  //so set the local max to 1 and keep track of "no batch" case
  maxBatchSize_ = modelConfig.max_batch_size();
  noBatch_ = maxBatchSize_ == 0;
  maxBatchSize_ = std::max(1u, maxBatchSize_);

  //get model info
  inference::ModelMetadataResponse modelMetadata;
  triton_utils::throwIfError(client_->ModelMetadata(&modelMetadata, options_.model_name_, options_.model_version_),
                             "TritonClient(): unable to get model metadata");

  //get input and output (which know their sizes)
  const auto& nicInputs = modelMetadata.inputs();
  const auto& nicOutputs = modelMetadata.outputs();

  //report all model errors at once
  std::stringstream msg;
  std::string msg_str;

  //currently no use case is foreseen for a model with zero inputs or outputs
  if (nicInputs.empty())
    msg << "Model on server appears malformed (zero inputs)\n";

  if (nicOutputs.empty())
    msg << "Model on server appears malformed (zero outputs)\n";

  //stop if errors
  msg_str = msg.str();
  if (!msg_str.empty())
    throw cms::Exception("ModelErrors") << msg_str;

  //setup input map
  std::stringstream io_msg;
  if (verbose_)
    io_msg << "Model inputs: "
           << "\n";
  inputsTriton_.reserve(nicInputs.size());
  for (const auto& nicInput : nicInputs) {
    const auto& iname = nicInput.name();
    auto [curr_itr, success] = input_.emplace(std::piecewise_construct,
                                              std::forward_as_tuple(iname),
                                              std::forward_as_tuple(iname, nicInput, this, ts->pid()));
    auto& curr_input = curr_itr->second;
    inputsTriton_.push_back(curr_input.data());
    if (verbose_) {
      io_msg << "  " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize()
             << " b) : " << triton_utils::printColl(curr_input.shape()) << "\n";
    }
  }

  //allow selecting only some outputs from server
  const auto& v_outputs = params.getUntrackedParameter<std::vector<std::string>>("outputs");
  std::unordered_set s_outputs(v_outputs.begin(), v_outputs.end());

  //setup output map
  if (verbose_)
    io_msg << "Model outputs: "
           << "\n";
  outputsTriton_.reserve(nicOutputs.size());
  for (const auto& nicOutput : nicOutputs) {
    const auto& oname = nicOutput.name();
    if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end())
      continue;
    auto [curr_itr, success] = output_.emplace(std::piecewise_construct,
                                               std::forward_as_tuple(oname),
                                               std::forward_as_tuple(oname, nicOutput, this, ts->pid()));
    auto& curr_output = curr_itr->second;
    outputsTriton_.push_back(curr_output.data());
    if (verbose_) {
      io_msg << "  " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize()
             << " b) : " << triton_utils::printColl(curr_output.shape()) << "\n";
    }
    if (!s_outputs.empty())
      s_outputs.erase(oname);
  }

  //check if any requested outputs were not available
  if (!s_outputs.empty())
    throw cms::Exception("MissingOutput")
        << "Some requested outputs were not available on the server: " << triton_utils::printColl(s_outputs);

  //propagate batch size to inputs and outputs
  setBatchSize(1);

  //print model info
  std::stringstream model_msg;
  if (verbose_) {
    model_msg << "Model name: " << options_.model_name_ << "\n"
              << "Model version: " << options_.model_version_ << "\n"
              << "Model max batch size: " << (noBatch_ ? 0 : maxBatchSize_) << "\n";
    edm::LogInfo(fullDebugName_) << model_msg.str() << io_msg.str();
  }
}

TritonClient::~TritonClient() {
  //by default: members of this class destroyed before members of base class
  //in shared memory case, TritonMemResource (member of TritonData) unregisters from client_ in its destructor
  //but input/output objects are members of the base class, so destroyed after client_ (member of this class)
  //therefore, clear the maps here
  input_.clear();
  output_.clear();
}

bool TritonClient::setBatchSize(unsigned bsize) {
  if (bsize > maxBatchSize_) {
    edm::LogWarning(fullDebugName_) << "Requested batch size " << bsize << " exceeds server-specified max batch size "
                                    << maxBatchSize_ << ". Batch size will remain as " << batchSize_;
    return false;
  } else {
    batchSize_ = bsize;
    //set for input and output
    for (auto& element : input_) {
      element.second.setBatchSize(bsize);
    }
    for (auto& element : output_) {
      element.second.setBatchSize(bsize);
    }
    return true;
  }
}
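
//usage sketch (illustration only): for a model whose server-specified max batch size is 8,
//  client.setBatchSize(4);   //returns true; the new batch size is propagated to all inputs and outputs
//  client.setBatchSize(16);  //returns false and logs a warning; the batch size is unchanged
//here "client" stands for a hypothetical TritonClient instance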

void TritonClient::reset() {
  for (auto& element : input_) {
    element.second.reset();
  }
  for (auto& element : output_) {
    element.second.reset();
  }
}

template <typename F>
bool TritonClient::handle_exception(F&& call) {
  //caught exceptions will be propagated to edm::WaitingTaskWithArenaHolder
  CMS_SA_ALLOW try {
    call();
    return true;
  }
  //TritonExceptions are intended/expected to be recoverable, i.e. retries should be allowed
  catch (TritonException& e) {
    e.convertToWarning();
    finish(false);
    return false;
  }
  //other exceptions are not: execution should stop if they are encountered
  catch (...) {
    finish(false, std::current_exception());
    return false;
  }
}
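
//usage sketch (illustration only): the typical pattern for wrapping a server call, as used
//throughout evaluate() below; a recoverable TritonException is converted to a warning, while
//any other exception is forwarded to finish() and stops execution
//  bool success = handle_exception([&]() {
//    triton_utils::throwIfError(client_->ModelConfig(...), "example(): unable to get model config");
//  });
//  if (!success)
//    return;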

void TritonClient::getResults(std::shared_ptr<tc::InferResult> results) {
  for (auto& [oname, output] : output_) {
    //set shape here before output becomes const
    if (output.variableDims()) {
      std::vector<int64_t> tmp_shape;
      triton_utils::throwIfError(results->Shape(oname, &tmp_shape),
                                 "getResults(): unable to get output shape for " + oname);
      output.setShape(tmp_shape);
      output.computeSizes();
    }
    //extend lifetime
    output.setResult(results);
  }
}

//default case for sync and pseudo async
void TritonClient::evaluate() {
  //in case there is nothing to process
  if (batchSize_ == 0) {
    finish(true);
    return;
  }

  //set up shared memory for output
  auto success = handle_exception([&]() {
    for (auto& element : output_) {
      element.second.prepare();
    }
  });
  if (!success)
    return;

  // Get the status of the server prior to the request being made.
  inference::ModelStatistics start_status;
  success = handle_exception([&]() {
    if (verbose())
      start_status = getServerSideStatus();
  });
  if (!success)
    return;

  if (mode_ == SonicMode::Async) {
    //non-blocking call
    success = handle_exception([&]() {
      triton_utils::throwIfError(
          client_->AsyncInfer(
              [start_status, this](tc::InferResult* results) {
                //get results
                std::shared_ptr<tc::InferResult> results_ptr(results);
                auto success = handle_exception([&]() {
                  triton_utils::throwIfError(results_ptr->RequestStatus(), "evaluate(): unable to get result");
                });
                if (!success)
                  return;

                if (verbose()) {
                  inference::ModelStatistics end_status;
                  success = handle_exception([&]() { end_status = getServerSideStatus(); });
                  if (!success)
                    return;

                  const auto& stats = summarizeServerStats(start_status, end_status);
                  reportServerSideStats(stats);
                }

                //check result
                success = handle_exception([&]() { getResults(results_ptr); });
                if (!success)
                  return;

                //finish
                finish(true);
              },
              options_,
              inputsTriton_,
              outputsTriton_,
              headers_,
              compressionAlgo_),
          "evaluate(): unable to launch async run");
    });
    if (!success)
      return;
  } else {
    //blocking call
    tc::InferResult* results;
    success = handle_exception([&]() {
      triton_utils::throwIfError(
          client_->Infer(&results, options_, inputsTriton_, outputsTriton_, headers_, compressionAlgo_),
          "evaluate(): unable to run and/or get result");
    });
    if (!success)
      return;

    if (verbose()) {
      inference::ModelStatistics end_status;
      success = handle_exception([&]() { end_status = getServerSideStatus(); });
      if (!success)
        return;

      const auto& stats = summarizeServerStats(start_status, end_status);
      reportServerSideStats(stats);
    }

    std::shared_ptr<tc::InferResult> results_ptr(results);
    success = handle_exception([&]() { getResults(results_ptr); });
    if (!success)
      return;

    finish(true);
  }
}

void TritonClient::reportServerSideStats(const TritonClient::ServerSideStats& stats) const {
  std::stringstream msg;

  // https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/perf_client/inference_profiler.cc
  const uint64_t count = stats.success_count_;
  msg << "  Inference count: " << stats.inference_count_ << "\n";
  msg << "  Execution count: " << stats.execution_count_ << "\n";
  msg << "  Successful request count: " << count << "\n";

  if (count > 0) {
    auto get_avg_us = [count](uint64_t tval) {
      constexpr uint64_t us_to_ns = 1000;
      return tval / us_to_ns / count;
    };

    const uint64_t cumm_avg_us = get_avg_us(stats.cumm_time_ns_);
    const uint64_t queue_avg_us = get_avg_us(stats.queue_time_ns_);
    const uint64_t compute_input_avg_us = get_avg_us(stats.compute_input_time_ns_);
    const uint64_t compute_infer_avg_us = get_avg_us(stats.compute_infer_time_ns_);
    const uint64_t compute_output_avg_us = get_avg_us(stats.compute_output_time_ns_);
    const uint64_t compute_avg_us = compute_input_avg_us + compute_infer_avg_us + compute_output_avg_us;
    const uint64_t overhead =
        (cumm_avg_us > queue_avg_us + compute_avg_us) ? (cumm_avg_us - queue_avg_us - compute_avg_us) : 0;

    msg << "  Avg request latency: " << cumm_avg_us << " usec"
        << "\n"
        << "  (overhead " << overhead << " usec + "
        << "queue " << queue_avg_us << " usec + "
        << "compute input " << compute_input_avg_us << " usec + "
        << "compute infer " << compute_infer_avg_us << " usec + "
        << "compute output " << compute_output_avg_us << " usec)" << std::endl;
  }

  if (!debugName_.empty())
    edm::LogInfo(fullDebugName_) << msg.str();
}
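
//worked example (illustration only): with success_count_ = 2 and cumm_time_ns_ = 4,000,000,
//get_avg_us gives 4,000,000 / 1000 / 2 = 2000 usec average request latency;
//overhead is then cumm_avg_us minus the queue and compute averages (clamped at 0)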

TritonClient::ServerSideStats TritonClient::summarizeServerStats(const inference::ModelStatistics& start_status,
                                                                 const inference::ModelStatistics& end_status) const {
  TritonClient::ServerSideStats server_stats;

  server_stats.inference_count_ = end_status.inference_count() - start_status.inference_count();
  server_stats.execution_count_ = end_status.execution_count() - start_status.execution_count();
  server_stats.success_count_ =
      end_status.inference_stats().success().count() - start_status.inference_stats().success().count();
  server_stats.cumm_time_ns_ =
      end_status.inference_stats().success().ns() - start_status.inference_stats().success().ns();
  server_stats.queue_time_ns_ =
      end_status.inference_stats().queue().ns() - start_status.inference_stats().queue().ns();
  server_stats.compute_input_time_ns_ =
      end_status.inference_stats().compute_input().ns() - start_status.inference_stats().compute_input().ns();
  server_stats.compute_infer_time_ns_ =
      end_status.inference_stats().compute_infer().ns() - start_status.inference_stats().compute_infer().ns();
  server_stats.compute_output_time_ns_ =
      end_status.inference_stats().compute_output().ns() - start_status.inference_stats().compute_output().ns();

  return server_stats;
}

inference::ModelStatistics TritonClient::getServerSideStatus() const {
  if (verbose_) {
    inference::ModelStatisticsResponse resp;
    triton_utils::throwIfError(client_->ModelInferenceStatistics(&resp, options_.model_name_, options_.model_version_),
                               "getServerSideStatus(): unable to get model statistics");
    return *(resp.model_stats().begin());
  }
  return inference::ModelStatistics{};
}

//for fillDescriptions
void TritonClient::fillPSetDescription(edm::ParameterSetDescription& iDesc) {
  edm::ParameterSetDescription descClient;
  fillBasePSetDescription(descClient);
  descClient.add<std::string>("modelName");
  descClient.add<std::string>("modelVersion", "");
  descClient.add<edm::FileInPath>("modelConfigPath");
  //server parameters should not affect the physics results
  descClient.addUntracked<std::string>("preferredServer", "");
  descClient.addUntracked<unsigned>("timeout");
  descClient.addUntracked<bool>("verbose", false);
  descClient.addUntracked<bool>("useSharedMemory", true);
  descClient.addUntracked<std::string>("compression", "");
  descClient.addUntracked<std::vector<std::string>>("outputs", {});
  iDesc.add<edm::ParameterSetDescription>("Client", descClient);
}
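
//usage sketch (illustration only): a module using this client would typically forward the call
//from its own fillDescriptions; "MyProducer" is a hypothetical module name
//  void MyProducer::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
//    edm::ParameterSetDescription desc;
//    TritonClient::fillPSetDescription(desc);
//    descriptions.addWithDefaultLabel(desc);
//  }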