
TritonClient.cc
#include "HeterogeneousCore/SonicTriton/interface/TritonClient.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonService.h"
#include "HeterogeneousCore/SonicTriton/interface/triton_utils.h"
#include "FWCore/MessageLogger/interface/MessageLogger.h"
#include "FWCore/ParameterSet/interface/FileInPath.h"
#include "FWCore/ServiceRegistry/interface/Service.h"
#include "FWCore/Utilities/interface/Exception.h"

#include "grpc_client.h"
#include "grpc_service.pb.h"

#include <string>
#include <cmath>
#include <chrono>
#include <exception>
#include <sstream>
#include <utility>
#include <tuple>
#include <unordered_set>

namespace ni = nvidia::inferenceserver;
namespace nic = ni::client;

//based on https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/examples/simple_grpc_async_infer_client.cc
//and https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/perf_client/perf_client.cc

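//The constructor resolves the server address via TritonService, opens a gRPC connection,
//queries the model configuration and metadata, and builds the input/output maps that
//evaluate() reuses for every request.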
TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& debugName)
    : SonicClient(params, debugName, "TritonClient"),
      verbose_(params.getUntrackedParameter<bool>("verbose")),
      options_(params.getParameter<std::string>("modelName")) {
  //get appropriate server for this model
  edm::Service<TritonService> ts;
  const auto& [url, isFallbackCPU] =
      ts->serverAddress(options_.model_name_, params.getUntrackedParameter<std::string>("preferredServer"));
  if (verbose_)
    edm::LogInfo(fullDebugName_) << "Using server: " << url;
  //enforce sync mode for fallback CPU server to avoid contention
  //todo: could enforce async mode otherwise (unless mode was specified by user?)
  if (isFallbackCPU)
    setMode(SonicMode::Sync);

  //connect to the server
  //TODO: add SSL options
  triton_utils::throwIfError(nic::InferenceServerGrpcClient::Create(&client_, url, false),
                             "TritonClient(): unable to create inference context");

  //set options
  options_.model_version_ = params.getParameter<std::string>("modelVersion");
  //convert seconds to microseconds
  options_.client_timeout_ = params.getUntrackedParameter<unsigned>("timeout") * 1e6;
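  //e.g. a configured "timeout" of 10 (seconds) yields client_timeout_ = 10 * 1e6 = 10,000,000 microseconds
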
  //config needed for batch size
  inference::ModelConfigResponse modelConfigResponse;
  triton_utils::throwIfError(client_->ModelConfig(&modelConfigResponse, options_.model_name_, options_.model_version_),
                             "TritonClient(): unable to get model config");
  inference::ModelConfig modelConfig(modelConfigResponse.config());

  //check batch size limitations (after i/o setup)
  //triton uses max batch size = 0 to denote a model that does not support batching
  //but for models that do support batching, a given event may set batch size 0 to indicate no valid input is present
  //so set the local max to 1 and keep track of "no batch" case
  maxBatchSize_ = modelConfig.max_batch_size();
  noBatch_ = maxBatchSize_ == 0;
  maxBatchSize_ = std::max(1u, maxBatchSize_);
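  //e.g. a model served with max_batch_size = 0 (no batching) gives noBatch_ = true and a local maxBatchSize_ of 1,
  //while a model served with max_batch_size = 8 gives noBatch_ = false and maxBatchSize_ = 8
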
  //get model info
  inference::ModelMetadataResponse modelMetadata;
  triton_utils::throwIfError(client_->ModelMetadata(&modelMetadata, options_.model_name_, options_.model_version_),
                             "TritonClient(): unable to get model metadata");

  //get input and output (which know their sizes)
  const auto& nicInputs = modelMetadata.inputs();
  const auto& nicOutputs = modelMetadata.outputs();

  //report all model errors at once
  std::stringstream msg;
  std::string msg_str;

  //currently no use case is foreseen for a model with zero inputs or outputs
  if (nicInputs.empty())
    msg << "Model on server appears malformed (zero inputs)\n";

  if (nicOutputs.empty())
    msg << "Model on server appears malformed (zero outputs)\n";

  //stop if errors
  msg_str = msg.str();
  if (!msg_str.empty())
    throw cms::Exception("ModelErrors") << msg_str;

  //setup input map
  std::stringstream io_msg;
  if (verbose_)
    io_msg << "Model inputs: "
           << "\n";
  inputsTriton_.reserve(nicInputs.size());
  for (const auto& nicInput : nicInputs) {
    const auto& iname = nicInput.name();
    auto [curr_itr, success] = input_.emplace(
        std::piecewise_construct, std::forward_as_tuple(iname), std::forward_as_tuple(iname, nicInput, noBatch_));
    auto& curr_input = curr_itr->second;
    inputsTriton_.push_back(curr_input.data());
    if (verbose_) {
      io_msg << " " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize()
             << " b) : " << triton_utils::printColl(curr_input.shape()) << "\n";
    }
  }

  //allow selecting only some outputs from server
  const auto& v_outputs = params.getUntrackedParameter<std::vector<std::string>>("outputs");
  std::unordered_set s_outputs(v_outputs.begin(), v_outputs.end());

  //setup output map
  if (verbose_)
    io_msg << "Model outputs: "
           << "\n";
  outputsTriton_.reserve(nicOutputs.size());
  for (const auto& nicOutput : nicOutputs) {
    const auto& oname = nicOutput.name();
    if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end())
      continue;
    auto [curr_itr, success] = output_.emplace(
        std::piecewise_construct, std::forward_as_tuple(oname), std::forward_as_tuple(oname, nicOutput, noBatch_));
    auto& curr_output = curr_itr->second;
    outputsTriton_.push_back(curr_output.data());
    if (verbose_) {
      io_msg << " " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize()
             << " b) : " << triton_utils::printColl(curr_output.shape()) << "\n";
    }
    if (!s_outputs.empty())
      s_outputs.erase(oname);
  }

  //check if any requested outputs were not available
  if (!s_outputs.empty())
    throw cms::Exception("MissingOutput")
        << "Some requested outputs were not available on the server: " << triton_utils::printColl(s_outputs);

  //propagate batch size to inputs and outputs
  setBatchSize(1);

  //print model info
  std::stringstream model_msg;
  if (verbose_) {
    model_msg << "Model name: " << options_.model_name_ << "\n"
              << "Model version: " << options_.model_version_ << "\n"
              << "Model max batch size: " << (noBatch_ ? 0 : maxBatchSize_) << "\n";
    edm::LogInfo(fullDebugName_) << model_msg.str() << io_msg.str();
  }
}

bool TritonClient::setBatchSize(unsigned bsize) {
  if (bsize > maxBatchSize_) {
    edm::LogWarning(fullDebugName_) << "Requested batch size " << bsize << " exceeds server-specified max batch size "
                                    << maxBatchSize_ << ". Batch size will remain as " << batchSize_;
    return false;
  } else {
    batchSize_ = bsize;
    //set for input and output
    for (auto& element : input_) {
      element.second.setBatchSize(bsize);
    }
    for (auto& element : output_) {
      element.second.setBatchSize(bsize);
    }
    return true;
  }
}
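
//Hypothetical usage sketch (not from this file): a module that owns a TritonClient typically
//calls setBatchSize(N) once per event with N = number of objects to evaluate; N = 0 is allowed
//and makes evaluate() finish immediately without contacting the server.
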
void TritonClient::reset() {
  for (auto& element : input_) {
    element.second.reset();
  }
  for (auto& element : output_) {
    element.second.reset();
  }
}

bool TritonClient::getResults(std::shared_ptr<nic::InferResult> results) {
  for (auto& [oname, output] : output_) {
    //set shape here before output becomes const
    if (output.variableDims()) {
      std::vector<int64_t> tmp_shape;
      bool status = triton_utils::warnIfError(results->Shape(oname, &tmp_shape),
                                              "getResults(): unable to get output shape for " + oname);
      if (!status)
        return status;
      output.setShape(tmp_shape, false);
    }
    //extend lifetime
    output.setResult(results);
  }

  return true;
}

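//evaluate() handles both modes: in Async mode the request is launched with AsyncInfer() and the
//results are processed in the completion callback, while in the blocking case Infer() is called
//directly; both paths record the remote call time, optionally collect server-side statistics
//(when verbose), and always end by calling finish().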
//default case for sync and pseudo async
void TritonClient::evaluate() {
  //in case there is nothing to process
  if (batchSize_ == 0) {
    finish(true);
    return;
  }

  // Get the status of the server prior to the request being made.
  const auto& start_status = getServerSideStatus();

  if (mode_ == SonicMode::Async) {
    //non-blocking call
    auto t1 = std::chrono::high_resolution_clock::now();
    bool status = triton_utils::warnIfError(
        client_->AsyncInfer(
            [t1, start_status, this](nic::InferResult* results) {
              //get results
              std::shared_ptr<nic::InferResult> results_ptr(results);
              bool status = triton_utils::warnIfError(results_ptr->RequestStatus(), "evaluate(): unable to get result");
              if (!status) {
                finish(false);
                return;
              }
              auto t2 = std::chrono::high_resolution_clock::now();

              if (!debugName_.empty())
                edm::LogInfo(fullDebugName_)
                    << "Remote time: " << std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();

              const auto& end_status = getServerSideStatus();

              if (verbose()) {
                const auto& stats = summarizeServerStats(start_status, end_status);
                reportServerSideStats(stats);
              }

              //check result
              status = getResults(results_ptr);

              //finish
              finish(status);
            },
            options_,
            inputsTriton_,
            outputsTriton_),
        "evaluate(): unable to launch async run");

    //if AsyncRun failed, finish() wasn't called
    if (!status)
      finish(false);
  } else {
    //blocking call
    auto t1 = std::chrono::high_resolution_clock::now();
    nic::InferResult* results;
    bool status = triton_utils::warnIfError(client_->Infer(&results, options_, inputsTriton_, outputsTriton_),
                                            "evaluate(): unable to run and/or get result");
    if (!status) {
      finish(false);
      return;
    }

    auto t2 = std::chrono::high_resolution_clock::now();
    if (!debugName_.empty())
      edm::LogInfo(fullDebugName_) << "Remote time: "
                                   << std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();

    const auto& end_status = getServerSideStatus();

    if (verbose()) {
      const auto& stats = summarizeServerStats(start_status, end_status);
      reportServerSideStats(stats);
    }

    std::shared_ptr<nic::InferResult> results_ptr(results);
    status = getResults(results_ptr);

    finish(status);
  }
}

void TritonClient::reportServerSideStats(const TritonClient::ServerSideStats& stats) const {
  std::stringstream msg;

  // https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/perf_client/inference_profiler.cc
  const uint64_t count = stats.success_count_;
  msg << " Inference count: " << stats.inference_count_ << "\n";
  msg << " Execution count: " << stats.execution_count_ << "\n";
  msg << " Successful request count: " << count << "\n";

  if (count > 0) {
    auto get_avg_us = [count](uint64_t tval) {
      constexpr uint64_t us_to_ns = 1000;
      return tval / us_to_ns / count;
    };

    const uint64_t cumm_avg_us = get_avg_us(stats.cumm_time_ns_);
    const uint64_t queue_avg_us = get_avg_us(stats.queue_time_ns_);
    const uint64_t compute_input_avg_us = get_avg_us(stats.compute_input_time_ns_);
    const uint64_t compute_infer_avg_us = get_avg_us(stats.compute_infer_time_ns_);
    const uint64_t compute_output_avg_us = get_avg_us(stats.compute_output_time_ns_);
    const uint64_t compute_avg_us = compute_input_avg_us + compute_infer_avg_us + compute_output_avg_us;
    const uint64_t overhead =
        (cumm_avg_us > queue_avg_us + compute_avg_us) ? (cumm_avg_us - queue_avg_us - compute_avg_us) : 0;

    msg << " Avg request latency: " << cumm_avg_us << " usec"
        << "\n"
        << " (overhead " << overhead << " usec + "
        << "queue " << queue_avg_us << " usec + "
        << "compute input " << compute_input_avg_us << " usec + "
        << "compute infer " << compute_infer_avg_us << " usec + "
        << "compute output " << compute_output_avg_us << " usec)" << std::endl;
  }

  if (!debugName_.empty())
    edm::LogInfo(fullDebugName_) << msg.str();
}
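
//Worked example of the decomposition above (illustrative numbers): with cumm_avg_us = 1500,
//queue_avg_us = 200, and compute input/infer/output averages of 100/1000/100 usec,
//compute_avg_us = 1200 and overhead = 1500 - 200 - 1200 = 100 usec; if queue + compute
//exceeded the total, overhead would be reported as 0.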

TritonClient::ServerSideStats TritonClient::summarizeServerStats(const inference::ModelStatistics& start_status,
                                                                 const inference::ModelStatistics& end_status) const {
  TritonClient::ServerSideStats server_stats;

  server_stats.inference_count_ = end_status.inference_count() - start_status.inference_count();
  server_stats.execution_count_ = end_status.execution_count() - start_status.execution_count();
  server_stats.success_count_ =
      end_status.inference_stats().success().count() - start_status.inference_stats().success().count();
  server_stats.cumm_time_ns_ =
      end_status.inference_stats().success().ns() - start_status.inference_stats().success().ns();
  server_stats.queue_time_ns_ = end_status.inference_stats().queue().ns() - start_status.inference_stats().queue().ns();
  server_stats.compute_input_time_ns_ =
      end_status.inference_stats().compute_input().ns() - start_status.inference_stats().compute_input().ns();
  server_stats.compute_infer_time_ns_ =
      end_status.inference_stats().compute_infer().ns() - start_status.inference_stats().compute_infer().ns();
  server_stats.compute_output_time_ns_ =
      end_status.inference_stats().compute_output().ns() - start_status.inference_stats().compute_output().ns();

  return server_stats;
}

inference::ModelStatistics TritonClient::getServerSideStatus() const {
  if (verbose_) {
    inference::ModelStatisticsResponse resp;
    bool success = triton_utils::warnIfError(
        client_->ModelInferenceStatistics(&resp, options_.model_name_, options_.model_version_),
        "getServerSideStatus(): unable to get model statistics");
    if (success)
      return *(resp.model_stats().begin());
  }
  return inference::ModelStatistics{};
}

//for fillDescriptions
void TritonClient::fillPSetDescription(edm::ParameterSetDescription& iDesc) {
  edm::ParameterSetDescription descClient;
  fillBasePSetDescription(descClient);
  descClient.add<std::string>("modelName");
  descClient.add<std::string>("modelVersion", "");
  descClient.add<edm::FileInPath>("modelConfigPath");
  //server parameters should not affect the physics results
  descClient.addUntracked<std::string>("preferredServer", "");
  descClient.addUntracked<unsigned>("timeout");
  descClient.addUntracked<bool>("verbose", false);
  descClient.addUntracked<std::vector<std::string>>("outputs", {});
  iDesc.add<edm::ParameterSetDescription>("Client", descClient);
}
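
//Hypothetical sketch (not part of this file) of how a producer's fillDescriptions() could use
//the helper above; the module class and label are made-up names:
//  void MyTritonProducer::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
//    edm::ParameterSetDescription desc;
//    TritonClient::fillPSetDescription(desc);
//    //add module-specific parameters here
//    descriptions.add("myTritonProducer", desc);
//  }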