#include "grpc_client.h"
#include "grpc_service.pb.h"
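// Standard-library includes and the namespace alias assumed by the snippets below, so they are
// self-contained (the exact header block of the file may differ):
#include <cstdint>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

namespace tc = triton::client;  // short alias used throughout (e.g. tc::InferResult)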
grpc_compression_algorithm getCompressionAlgo(const std::string& name) {
  if (name.empty() or name.compare("none") == 0)
    return grpc_compression_algorithm::GRPC_COMPRESS_NONE;
  else if (name.compare("deflate") == 0)
    return grpc_compression_algorithm::GRPC_COMPRESS_DEFLATE;
  else if (name.compare("gzip") == 0)
    return grpc_compression_algorithm::GRPC_COMPRESS_GZIP;
  else
    // reject anything else (exception category assumed)
    throw cms::Exception("GrpcCompression") << "Unknown compression algorithm requested: " << name
                                            << " (choices: none, deflate, gzip)";
}
TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& debugName)
    : SonicClient(params, debugName, "TritonClient"),  // base-class initialization assumed
      verbose_(params.getUntrackedParameter<bool>("verbose")),
      useSharedMemory_(params.getUntrackedParameter<bool>("useSharedMemory")),
      compressionAlgo_(getCompressionAlgo(params.getUntrackedParameter<std::string>("compression"))),
      options_(params.getParameter<std::string>("modelName")) {
  // the TritonService supplies the server to contact and a per-process id (ts->pid()) used below;
  // this lookup is an assumed sketch based on the service interface
  edm::Service<TritonService> ts;
  const auto& server = ts->serverInfo(options_.model_name_);
  // connect to the server; throwIfError() turns a failed status into an exception
  // (client-creation and query arguments below are assumptions based on the Triton client API)
  triton_utils::throwIfError(tc::InferenceServerGrpcClient::Create(&client_, server.url),
                             "TritonClient(): unable to create inference context");
  // fixed properties from the server-side model configuration
  inference::ModelConfigResponse modelConfigResponse;
  triton_utils::throwIfError(client_->ModelConfig(&modelConfigResponse, options_.model_name_, options_.model_version_, headers_),
                             "TritonClient(): unable to get model config");
  inference::ModelConfig modelConfig(modelConfigResponse.config());
  // a model that does not support batching reports max batch size 0; keep a local minimum of 1
  maxBatchSize_ = std::max(1u, maxBatchSize_);
  // model metadata describes the inputs and outputs
  inference::ModelMetadataResponse modelMetadata;
  triton_utils::throwIfError(client_->ModelMetadata(&modelMetadata, options_.model_name_, options_.model_version_, headers_),
                             "TritonClient(): unable to get model metadata");
  const auto& nicInputs = modelMetadata.inputs();
  const auto& nicOutputs = modelMetadata.outputs();
  // accumulate all model problems and report them together
  std::stringstream msg;
  if (nicInputs.empty())
    msg << "Model on server appears malformed (zero inputs)\n";
  if (nicOutputs.empty())
    msg << "Model on server appears malformed (zero outputs)\n";
  // stop if any problems were found (exception category assumed)
  const std::string msg_str = msg.str();
  if (!msg_str.empty())
    throw cms::Exception("ModelErrors") << msg_str;
  // human-readable description of the model i/o, used for the verbose printout at the end
  std::stringstream io_msg;
  io_msg << "Model inputs: "
         << "\n";
  for (const auto& nicInput : nicInputs) {
    const auto& iname = nicInput.name();
    auto [curr_itr, success] = input_.emplace(std::piecewise_construct,
                                              std::forward_as_tuple(iname),
                                              std::forward_as_tuple(iname, nicInput, this, ts->pid()));
    auto& curr_input = curr_itr->second;
    io_msg << "  " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize() << " b)\n";
  }

  // user-selected subset of outputs, taken from the untracked "outputs" parameter
  const auto& v_outputs = params.getUntrackedParameter<std::vector<std::string>>("outputs");
  std::unordered_set s_outputs(v_outputs.begin(), v_outputs.end());
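  // Illustration (output names are made up): with "outputs" = {"softmax"} only the "softmax" entry
  // of the model metadata is kept in the loop below; with the default empty list, every advertised
  // output is kept.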
  io_msg << "Model outputs: "
         << "\n";
  for (const auto& nicOutput : nicOutputs) {
    const auto& oname = nicOutput.name();
    // skip outputs that the user did not request
    if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end())
      continue;
    auto [curr_itr, success] = output_.emplace(std::piecewise_construct,
                                               std::forward_as_tuple(oname),
                                               std::forward_as_tuple(oname, nicOutput, this, ts->pid()));
    auto& curr_output = curr_itr->second;
    io_msg << "  " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize() << " b)\n";
    // remove found outputs from the requested set to detect missing ones afterward
    if (!s_outputs.empty())
      s_outputs.erase(oname);
  }
  // names left in the set were requested but not provided by the model (exception content assumed)
  if (!s_outputs.empty())
    throw cms::Exception("MissingOutput") << "Requested outputs not available: " << triton_utils::printColl(s_outputs);
  // summary of the model, printed together with the i/o description when verbose
  std::stringstream model_msg;
  model_msg << "Model name: " << options_.model_name_ << "\n"
            << "Model version: " << options_.model_version_ << "\n";
  if (verbose_)
    edm::LogInfo(fullDebugName_) << model_msg.str() << io_msg.str();  // logging call assumed
}
  // setBatchSize(): propagate the new batch size to every input and output
  for (auto& element : input_) {
    element.second.setBatchSize(bsize);
  }
  for (auto& element : output_) {
    element.second.setBatchSize(bsize);
  }
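  // Note: setBatchSize() returns a bool (per the class declaration); presumably a request above the
  // server-reported maxBatchSize_ is rejected before the loops above are reached.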
  // clear the stored input and output data between requests
  for (auto& element : input_) {
    element.second.reset();
  }
  for (auto& element : output_) {
    element.second.reset();
  }
// handle_exception(): condensed sketch (signature per the class declaration); any exception thrown
// by the wrapped call is reported back to the framework via finish()
template <typename F>
bool TritonClient::handle_exception(F&& call) {
  try { call(); return true; }
  catch (...) { finish(false, std::current_exception()); return false; }
}
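// Note: the Triton-specific exception type also provides convertToWarning(); presumably recoverable
// failures are downgraded to warnings so the request can be retried instead of failing the job.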
void TritonClient::getResults(std::shared_ptr<tc::InferResult> results) {
  for (auto& [oname, output] : output_) {  // loop reconstruction assumed
    if (output.variableDims()) {
      // outputs with variable dimensions get their concrete shape from the result
      std::vector<int64_t> tmp_shape;
      triton_utils::throwIfError(results->Shape(oname, &tmp_shape), "getResults(): unable to get output shape for " + oname);
      output.setShape(tmp_shape);
    }
    output.setResult(results);  // keep the result alive while outputs reference its data
  }
}
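// getResults() is used by both evaluation paths below: the AsyncInfer() callback and the blocking
// Infer() call each wrap the returned InferResult in a shared_ptr and hand it to this function.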
  // evaluate(): prepare every output before the request is sent
  for (auto& element : output_) {
    element.second.prepare();
  }
  // server-side statistics before the request (used later for the timing report)
  inference::ModelStatistics start_status;
  // asynchronous path: condensed sketch (the per-step handle_exception wrappers are omitted here)
  triton_utils::throwIfError(
      client_->AsyncInfer(
          [start_status, this](tc::InferResult* results) {
            // take ownership of the result object handed to the callback
            std::shared_ptr<tc::InferResult> results_ptr(results);
            triton_utils::throwIfError(results_ptr->RequestStatus(), "evaluate(): unable to get result");
            inference::ModelStatistics end_status = getServerSideStatus();
            reportServerSideStats(summarizeServerStats(start_status, end_status));
            getResults(results_ptr);
            finish(true);
          },
          options_, inputsTriton_, outputsTriton_, headers_, compressionAlgo_),
      "evaluate(): unable to launch async run");
  // synchronous path: blocking Infer() call
  tc::InferResult* results = nullptr;
  success = handle_exception([&]() {
    triton_utils::throwIfError(client_->Infer(&results, options_, inputsTriton_, outputsTriton_, headers_, compressionAlgo_),
                               "evaluate(): unable to run and/or get result");
  });
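  // Which path runs is presumably selected by the client mode (setMode()/SonicMode in the class
  // interface): Async uses AsyncInfer() with the callback above, Sync uses the blocking Infer()
  // call here; both pass the same options, inputs, outputs, headers, and compression algorithm.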
  // query the server-side statistics again and report the difference from the pre-request snapshot
  inference::ModelStatistics end_status;
  success = handle_exception([&]() { end_status = getServerSideStatus(); });

  const auto& stats = summarizeServerStats(start_status, end_status);
  reportServerSideStats(stats);

  // unpack the outputs of the blocking call
  std::shared_ptr<tc::InferResult> results_ptr(results);
  success = handle_exception([&]() { getResults(results_ptr); });
// reportServerSideStats(): format the per-request timing summary from the server-side counters
void TritonClient::reportServerSideStats(const ServerSideStats& stats) const {
  std::stringstream msg;
  // number of successful requests covered by the summary (field name assumed)
  const uint64_t count = stats.success_count_;
  msg << " Successful request count: " << count << "\n";
  // helper (reconstructed) to turn a cumulative time in ns into a per-request average in usec
  auto get_avg_us = [count](uint64_t tval) {
    constexpr uint64_t us_to_ns = 1000;
    return tval / us_to_ns / count;
  };
  const uint64_t compute_avg_us = compute_input_avg_us + compute_infer_avg_us + compute_output_avg_us;
  const uint64_t overhead =
      (cumm_avg_us > queue_avg_us + compute_avg_us) ? (cumm_avg_us - queue_avg_us - compute_avg_us) : 0;
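  // Worked example (made-up numbers): with cumm_avg_us = 1000, queue_avg_us = 200, and
  // compute_avg_us = 700, the overhead is 1000 - 200 - 700 = 100 usec; if queue + compute exceeded
  // the total, the overhead would be clamped to 0.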
  msg << " Avg request latency: " << cumm_avg_us << " usec"
      << " (overhead " << overhead << " usec + "
      << "queue " << queue_avg_us << " usec + "
      << "compute input " << compute_input_avg_us << " usec + "
      << "compute infer " << compute_infer_avg_us << " usec + "
      << "compute output " << compute_output_avg_us << " usec)" << std::endl;
TritonClient::ServerSideStats TritonClient::summarizeServerStats(const inference::ModelStatistics& start_status,
                                                                 const inference::ModelStatistics& end_status) const {
  // every entry is the difference of the cumulative server-side counters between the two snapshots
  // (the names of the success-count and cumulative-time fields are assumed)
  ServerSideStats server_stats;
  server_stats.inference_count_ = end_status.inference_count() - start_status.inference_count();
  server_stats.execution_count_ = end_status.execution_count() - start_status.execution_count();
  server_stats.success_count_ =
      end_status.inference_stats().success().count() - start_status.inference_stats().success().count();
  server_stats.cumm_time_ns_ =
      end_status.inference_stats().success().ns() - start_status.inference_stats().success().ns();
  server_stats.queue_time_ns_ = end_status.inference_stats().queue().ns() - start_status.inference_stats().queue().ns();
  server_stats.compute_input_time_ns_ =
      end_status.inference_stats().compute_input().ns() - start_status.inference_stats().compute_input().ns();
  server_stats.compute_infer_time_ns_ =
      end_status.inference_stats().compute_infer().ns() - start_status.inference_stats().compute_infer().ns();
  server_stats.compute_output_time_ns_ =
      end_status.inference_stats().compute_output().ns() - start_status.inference_stats().compute_output().ns();
  return server_stats;
}
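// These ns-level deltas are what reportServerSideStats() averages over the successful request
// count, e.g. an average queue time of roughly queue_time_ns_ / 1000 / count microseconds
// (per the conversion helper sketched above).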
inference::ModelStatistics TritonClient::getServerSideStatus() const {
  if (verbose_) {  // guard assumed: statistics are only fetched when verbose reporting is enabled
    inference::ModelStatisticsResponse resp;
    triton_utils::throwIfError(client_->ModelInferenceStatistics(&resp, options_.model_name_, options_.model_version_), "getServerSideStatus(): unable to get model statistics");
    return *(resp.model_stats().begin());
  }
  return inference::ModelStatistics{};
}
  // client-level parameter declarations (the full set presumably also covers "modelName", "verbose",
  // "useSharedMemory", and "compression")
  descClient.addUntracked<std::vector<std::string>>("outputs", {});