
TritonClient.cc

#include "HeterogeneousCore/SonicTriton/interface/TritonClient.h"
#include "HeterogeneousCore/SonicTriton/interface/triton_utils.h"
#include "FWCore/MessageLogger/interface/MessageLogger.h"
#include "FWCore/Utilities/interface/Exception.h"

#include "grpc_client.h"
#include "grpc_service.pb.h"

#include <string>
#include <cmath>
#include <chrono>
#include <exception>
#include <sstream>
#include <utility>
#include <tuple>

namespace ni = nvidia::inferenceserver;
namespace nic = ni::client;

//based on https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/examples/simple_grpc_async_infer_client.cc
//and https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/perf_client/perf_client.cc

TritonClient::TritonClient(const edm::ParameterSet& params)
    : SonicClient(params),
      verbose_(params.getUntrackedParameter<bool>("verbose")),
      options_(params.getParameter<std::string>("modelName")) {
  clientName_ = "TritonClient";
  //will get overwritten later, just used in constructor
  fullDebugName_ = clientName_;

  //connect to the server
  //TODO: add SSL options
  std::string url(params.getUntrackedParameter<std::string>("address") + ":" +
                  std::to_string(params.getUntrackedParameter<unsigned>("port")));
  triton_utils::throwIfError(nic::InferenceServerGrpcClient::Create(&client_, url, false),
                             "TritonClient(): unable to create inference context");

  //set options
  options_.model_version_ = params.getParameter<std::string>("modelVersion");
  //convert seconds to microseconds
  options_.client_timeout_ = params.getUntrackedParameter<unsigned>("timeout") * 1e6;

  //config needed for batch size
  inference::ModelConfigResponse modelConfigResponse;
  triton_utils::throwIfError(client_->ModelConfig(&modelConfigResponse, options_.model_name_, options_.model_version_),
                             "TritonClient(): unable to get model config");
  inference::ModelConfig modelConfig(modelConfigResponse.config());

  //check batch size limitations (after i/o setup)
  //triton uses max batch size = 0 to denote a model that does not support batching
  //but for models that do support batching, a given event may set batch size 0 to indicate no valid input is present
  //so set the local max to 1 and keep track of "no batch" case
  maxBatchSize_ = modelConfig.max_batch_size();
  noBatch_ = maxBatchSize_ == 0;
  maxBatchSize_ = std::max(1u, maxBatchSize_);

  //get model info
  inference::ModelMetadataResponse modelMetadata;
  triton_utils::throwIfError(client_->ModelMetadata(&modelMetadata, options_.model_name_, options_.model_version_),
                             "TritonClient(): unable to get model metadata");

  //get input and output (which know their sizes)
  const auto& nicInputs = modelMetadata.inputs();
  const auto& nicOutputs = modelMetadata.outputs();

  //report all model errors at once
  std::stringstream msg;
  std::string msg_str;

  //currently no use case is foreseen for a model with zero inputs or outputs
  if (nicInputs.empty())
    msg << "Model on server appears malformed (zero inputs)\n";

  if (nicOutputs.empty())
    msg << "Model on server appears malformed (zero outputs)\n";

  //stop if errors
  msg_str = msg.str();
  if (!msg_str.empty())
    throw cms::Exception("ModelErrors") << msg_str;

  //setup input map
  std::stringstream io_msg;
  if (verbose_)
    io_msg << "Model inputs: "
           << "\n";
  inputsTriton_.reserve(nicInputs.size());
  for (const auto& nicInput : nicInputs) {
    const auto& iname = nicInput.name();
    auto [curr_itr, success] = input_.emplace(
        std::piecewise_construct, std::forward_as_tuple(iname), std::forward_as_tuple(iname, nicInput, noBatch_));
    auto& curr_input = curr_itr->second;
    inputsTriton_.push_back(curr_input.data());
    if (verbose_) {
      io_msg << " " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize()
             << " b) : " << triton_utils::printColl(curr_input.shape()) << "\n";
    }
  }

  //allow selecting only some outputs from server
  const auto& v_outputs = params.getUntrackedParameter<std::vector<std::string>>("outputs");
  std::unordered_set s_outputs(v_outputs.begin(), v_outputs.end());

  //setup output map
  if (verbose_)
    io_msg << "Model outputs: "
           << "\n";
  outputsTriton_.reserve(nicOutputs.size());
  for (const auto& nicOutput : nicOutputs) {
    const auto& oname = nicOutput.name();
    if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end())
      continue;
    auto [curr_itr, success] = output_.emplace(
        std::piecewise_construct, std::forward_as_tuple(oname), std::forward_as_tuple(oname, nicOutput, noBatch_));
    auto& curr_output = curr_itr->second;
    outputsTriton_.push_back(curr_output.data());
    if (verbose_) {
      io_msg << " " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize()
             << " b) : " << triton_utils::printColl(curr_output.shape()) << "\n";
    }
    if (!s_outputs.empty())
      s_outputs.erase(oname);
  }

  //check if any requested outputs were not available
  if (!s_outputs.empty())
    throw cms::Exception("MissingOutput")
        << "Some requested outputs were not available on the server: " << triton_utils::printColl(s_outputs);

  //check requested batch size and propagate to inputs and outputs
  setBatchSize(params.getUntrackedParameter<unsigned>("batchSize"));

  //print model info
  std::stringstream model_msg;
  if (verbose_) {
    model_msg << "Model name: " << options_.model_name_ << "\n"
              << "Model version: " << options_.model_version_ << "\n"
              << "Model max batch size: " << (noBatch_ ? 0 : maxBatchSize_) << "\n";
    edm::LogInfo(fullDebugName_) << model_msg.str() << io_msg.str();
  }
}
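
// ---------------------------------------------------------------------------------------------
// Illustrative sketch (not part of TritonClient.cc): the server handshake performed by the
// constructor above, isolated as a standalone helper. "localhost:8001", "mymodel", and the
// helper name are placeholders; only the client-library calls already used above (Create,
// ModelConfig, ModelMetadata) and the includes/namespace aliases at the top of this file are
// assumed.
void exampleHandshake() {
  std::unique_ptr<nic::InferenceServerGrpcClient> client;
  triton_utils::throwIfError(nic::InferenceServerGrpcClient::Create(&client, "localhost:8001", false),
                             "exampleHandshake(): unable to create inference context");

  inference::ModelConfigResponse config;
  triton_utils::throwIfError(client->ModelConfig(&config, "mymodel", "1"),
                             "exampleHandshake(): unable to get model config");
  //config.config().max_batch_size() == 0 means the model does not support batching (the "no batch" case above)

  inference::ModelMetadataResponse metadata;
  triton_utils::throwIfError(client->ModelMetadata(&metadata, "mymodel", "1"),
                             "exampleHandshake(): unable to get model metadata");
  //metadata.inputs() and metadata.outputs() carry the tensor names, datatypes, and shapes
  //that the constructor uses to build the input_ and output_ maps
}
// ---------------------------------------------------------------------------------------------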

bool TritonClient::setBatchSize(unsigned bsize) {
  if (bsize > maxBatchSize_) {
    edm::LogWarning(fullDebugName_) << "Requested batch size " << bsize << " exceeds server-specified max batch size "
                                    << maxBatchSize_ << ". Batch size will remain as " << batchSize_;
    return false;
  } else {
    batchSize_ = bsize;
    //set for input and output
    for (auto& element : input_) {
      element.second.setBatchSize(bsize);
    }
    for (auto& element : output_) {
      element.second.setBatchSize(bsize);
    }
    return true;
  }
}
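
// ---------------------------------------------------------------------------------------------
// Illustrative sketch (not part of TritonClient.cc): how calling code might drive the batch size
// per event. "client" and "nCandidates" are hypothetical; the only behavior assumed is what
// setBatchSize() implements above (0 is a legal value meaning "nothing to process", and values
// above the server maximum are rejected, keeping the previous batch size).
void examplePrepareBatch(TritonClient& client, unsigned nCandidates) {
  if (!client.setBatchSize(nCandidates)) {
    //request exceeded the server-side maximum; the previous batch size is still in effect
    return;
  }
  //with nCandidates == 0, evaluate() below will simply call finish(true) and return
  //otherwise, fill each input with nCandidates entries before the framework calls evaluate()
}
// ---------------------------------------------------------------------------------------------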

void TritonClient::reset() {
  for (auto& element : input_) {
    element.second.reset();
  }
  for (auto& element : output_) {
    element.second.reset();
  }
}

bool TritonClient::getResults(std::shared_ptr<nic::InferResult> results) {
  for (auto& [oname, output] : output_) {
    //set shape here before output becomes const
    if (output.variableDims()) {
      std::vector<int64_t> tmp_shape;
      bool status = triton_utils::warnIfError(results->Shape(oname, &tmp_shape),
                                              "getResults(): unable to get output shape for " + oname);
      if (!status)
        return status;
      output.setShape(tmp_shape, false);
    }
    //extend lifetime
    output.setResult(results);
  }

  return true;
}

//default case for sync and pseudo async
void TritonClient::evaluate() {
  //in case there is nothing to process
  if (batchSize_ == 0) {
    finish(true);
    return;
  }

  // Get the status of the server prior to the request being made.
  const auto& start_status = getServerSideStatus();

  if (mode_ == SonicMode::Async) {
    //non-blocking call
    auto t1 = std::chrono::high_resolution_clock::now();
    bool status = triton_utils::warnIfError(
        client_->AsyncInfer(
            [t1, start_status, this](nic::InferResult* results) {
              //get results
              std::shared_ptr<nic::InferResult> results_ptr(results);
              bool status = triton_utils::warnIfError(results_ptr->RequestStatus(), "evaluate(): unable to get result");
              if (!status) {
                finish(false);
                return;
              }
              auto t2 = std::chrono::high_resolution_clock::now();

              if (!debugName_.empty())
                edm::LogInfo(fullDebugName_)
                    << "Remote time: " << std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();

              const auto& end_status = getServerSideStatus();

              if (verbose()) {
                const auto& stats = summarizeServerStats(start_status, end_status);
                reportServerSideStats(stats);
              }

              //check result
              status = getResults(results_ptr);

              //finish
              finish(status);
            },
            options_,
            inputsTriton_,
            outputsTriton_),
        "evaluate(): unable to launch async run");

    //if AsyncInfer failed, finish() wasn't called
    if (!status)
      finish(false);
  } else {
    //blocking call
    auto t1 = std::chrono::high_resolution_clock::now();
    nic::InferResult* results;
    bool status = triton_utils::warnIfError(client_->Infer(&results, options_, inputsTriton_, outputsTriton_),
                                            "evaluate(): unable to run and/or get result");
    if (!status) {
      finish(false);
      return;
    }

    auto t2 = std::chrono::high_resolution_clock::now();
    if (!debugName_.empty())
      edm::LogInfo(fullDebugName_) << "Remote time: "
                                   << std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();

    const auto& end_status = getServerSideStatus();

    if (verbose()) {
      const auto& stats = summarizeServerStats(start_status, end_status);
      reportServerSideStats(stats);
    }

    std::shared_ptr<nic::InferResult> results_ptr(results);
    status = getResults(results_ptr);

    finish(status);
  }
}
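
// ---------------------------------------------------------------------------------------------
// Illustrative sketch (not part of TritonClient.cc): the timing idiom used for the "Remote time"
// report in evaluate(), isolated as a generic helper. Only <chrono> (already included above) is
// needed; the helper name and template are hypothetical.
template <typename F>
long long exampleTimeCallMicroseconds(F&& call) {
  auto t1 = std::chrono::high_resolution_clock::now();
  call();  //e.g. a blocking client_->Infer(...) request
  auto t2 = std::chrono::high_resolution_clock::now();
  return std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();
}
// ---------------------------------------------------------------------------------------------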

void TritonClient::reportServerSideStats(const TritonClient::ServerSideStats& stats) const {
  std::stringstream msg;

  // https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/perf_client/inference_profiler.cc
  const uint64_t count = stats.success_count_;
  msg << " Inference count: " << stats.inference_count_ << "\n";
  msg << " Execution count: " << stats.execution_count_ << "\n";
  msg << " Successful request count: " << count << "\n";

  if (count > 0) {
    auto get_avg_us = [count](uint64_t tval) {
      constexpr uint64_t us_to_ns = 1000;
      return tval / us_to_ns / count;
    };

    const uint64_t cumm_avg_us = get_avg_us(stats.cumm_time_ns_);
    const uint64_t queue_avg_us = get_avg_us(stats.queue_time_ns_);
    const uint64_t compute_input_avg_us = get_avg_us(stats.compute_input_time_ns_);
    const uint64_t compute_infer_avg_us = get_avg_us(stats.compute_infer_time_ns_);
    const uint64_t compute_output_avg_us = get_avg_us(stats.compute_output_time_ns_);
    const uint64_t compute_avg_us = compute_input_avg_us + compute_infer_avg_us + compute_output_avg_us;
    const uint64_t overhead =
        (cumm_avg_us > queue_avg_us + compute_avg_us) ? (cumm_avg_us - queue_avg_us - compute_avg_us) : 0;

    msg << " Avg request latency: " << cumm_avg_us << " usec"
        << "\n"
        << " (overhead " << overhead << " usec + "
        << "queue " << queue_avg_us << " usec + "
        << "compute input " << compute_input_avg_us << " usec + "
        << "compute infer " << compute_infer_avg_us << " usec + "
        << "compute output " << compute_output_avg_us << " usec)" << std::endl;
  }

  if (!debugName_.empty())
    edm::LogInfo(fullDebugName_) << msg.str();
}
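
// ---------------------------------------------------------------------------------------------
// Worked example (illustrative numbers, not from a real run): with 100 successful requests,
// cumm_time_ns_ = 500,000,000, queue_time_ns_ = 100,000,000, and 350,000,000 ns of total compute
// time, reportServerSideStats() above would print an average request latency of 5000 usec,
// decomposed as queue 1000 usec + compute 3500 usec + overhead 5000 - 1000 - 3500 = 500 usec.
static_assert(500000000ull / 1000 / 100 == 5000, "cumm_avg_us for the worked example above");
static_assert(100000000ull / 1000 / 100 == 1000, "queue_avg_us for the worked example above");
static_assert(350000000ull / 1000 / 100 == 3500, "compute_avg_us for the worked example above");
// ---------------------------------------------------------------------------------------------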

TritonClient::ServerSideStats TritonClient::summarizeServerStats(const inference::ModelStatistics& start_status,
                                                                 const inference::ModelStatistics& end_status) const {
  TritonClient::ServerSideStats server_stats;

  server_stats.inference_count_ = end_status.inference_count() - start_status.inference_count();
  server_stats.execution_count_ = end_status.execution_count() - start_status.execution_count();
  server_stats.success_count_ =
      end_status.inference_stats().success().count() - start_status.inference_stats().success().count();
  server_stats.cumm_time_ns_ =
      end_status.inference_stats().success().ns() - start_status.inference_stats().success().ns();
  server_stats.queue_time_ns_ = end_status.inference_stats().queue().ns() - start_status.inference_stats().queue().ns();
  server_stats.compute_input_time_ns_ =
      end_status.inference_stats().compute_input().ns() - start_status.inference_stats().compute_input().ns();
  server_stats.compute_infer_time_ns_ =
      end_status.inference_stats().compute_infer().ns() - start_status.inference_stats().compute_infer().ns();
  server_stats.compute_output_time_ns_ =
      end_status.inference_stats().compute_output().ns() - start_status.inference_stats().compute_output().ns();

  return server_stats;
}

inference::ModelStatistics TritonClient::getServerSideStatus() const {
  if (verbose_) {
    inference::ModelStatisticsResponse resp;
    triton_utils::warnIfError(client_->ModelInferenceStatistics(&resp, options_.model_name_, options_.model_version_),
                              "getServerSideStatus(): unable to get model statistics");
    return *(resp.model_stats().begin());
  }
  return inference::ModelStatistics{};
}

//for fillDescriptions
void TritonClient::fillPSetDescription(edm::ParameterSetDescription& iDesc) {
  edm::ParameterSetDescription descClient;
  fillBasePSetDescription(descClient);
  descClient.add<std::string>("modelName");
  descClient.add<std::string>("modelVersion", "");
  //server parameters should not affect the physics results
  descClient.addUntracked<unsigned>("batchSize");
  descClient.addUntracked<std::string>("address");
  descClient.addUntracked<unsigned>("port");
  descClient.addUntracked<unsigned>("timeout");
  descClient.addUntracked<bool>("verbose", false);
  descClient.addUntracked<std::vector<std::string>>("outputs", {});
  iDesc.add<edm::ParameterSetDescription>("Client", descClient);
}
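
// ---------------------------------------------------------------------------------------------
// Illustrative sketch (not part of TritonClient.cc): how a module's fillDescriptions() might
// consume the description defined above. The free-function form and the name
// exampleFillDescriptions are hypothetical, and FWCore/ParameterSet/interface/
// ConfigurationDescriptions.h is assumed to be available; the nested "Client" PSet is the one
// added by TritonClient::fillPSetDescription().
void exampleFillDescriptions(edm::ConfigurationDescriptions& descriptions) {
  edm::ParameterSetDescription desc;
  TritonClient::fillPSetDescription(desc);  //adds the "Client" PSet with the parameters above
  descriptions.addWithDefaultLabel(desc);
}
// ---------------------------------------------------------------------------------------------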