10 #include "grpc_client.h" 11 #include "grpc_service.pb.h" 23 grpc_compression_algorithm getCompressionAlgo(
const std::string&
name) {
25 return grpc_compression_algorithm::GRPC_COMPRESS_NONE;
26 else if (
name.compare(
"deflate") == 0)
27 return grpc_compression_algorithm::GRPC_COMPRESS_DEFLATE;
28 else if (
name.compare(
"gzip") == 0)
29 return grpc_compression_algorithm::GRPC_COMPRESS_GZIP;
32 <<
"Unknown compression algorithm requested: " <<
name <<
" (choices: none, deflate, gzip)";
TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& debugName)
    : // ...
      verbose_(params.getUntrackedParameter<bool>("verbose")),
      useSharedMemory_(params.getUntrackedParameter<bool>("useSharedMemory")),
      compressionAlgo_(getCompressionAlgo(params.getUntrackedParameter<std::string>("compression"))),
      // ...
{
  // ...
  // create the gRPC client used to communicate with the inference server
  TRITON_THROW_IF_ERROR(tc::InferenceServerGrpcClient::Create(&client_, /* ... */),
                        "TritonClient(): unable to create inference context");
  // ...

  // the Triton client expects the timeout in microseconds
  options_.client_timeout_ = params.getUntrackedParameter<unsigned>("timeout") * 1e6;
  // fetch the model configuration from the server
  inference::ModelConfigResponse modelConfigResponse;
  TRITON_THROW_IF_ERROR(client_->ModelConfig(&modelConfigResponse, options_.model_name_, options_.model_version_),
                        "TritonClient(): unable to get model config");
  inference::ModelConfig modelConfig(modelConfigResponse.config());
  // ...

  // fetch the model metadata, which describes the input and output tensors
  inference::ModelMetadataResponse modelMetadata;
  TRITON_THROW_IF_ERROR(client_->ModelMetadata(&modelMetadata, options_.model_name_, options_.model_version_),
                        "TritonClient(): unable to get model metadata");
87 const auto& nicInputs = modelMetadata.inputs();
88 const auto& nicOutputs = modelMetadata.outputs();
91 std::stringstream
msg;
95 if (nicInputs.empty())
96 msg <<
"Model on server appears malformed (zero inputs)\n";
98 if (nicOutputs.empty())
99 msg <<
"Model on server appears malformed (zero outputs)\n";
103 if (!msg_str.empty())
  // collect info about the model inputs
  std::stringstream io_msg;
  // ...
  io_msg << "Model inputs: "
         << "\n";
  for (const auto& nicInput : nicInputs) {
    const auto& iname = nicInput.name();
    auto [curr_itr, success] = input_.emplace(std::piecewise_construct,
                                              std::forward_as_tuple(iname),
                                              std::forward_as_tuple(iname, nicInput, this, ts->pid()));
    auto& curr_input = curr_itr->second;
    // ...
    io_msg << "  " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize()
           << " b) : " << triton_utils::printColl(curr_input.shape()) << "\n";
  }
126 const auto& v_outputs =
params.getUntrackedParameter<std::vector<std::string>>(
"outputs");
127 std::unordered_set s_outputs(v_outputs.begin(), v_outputs.end());
  // collect info about the model outputs
  // ...
  io_msg << "Model outputs: "
         << "\n";
  for (const auto& nicOutput : nicOutputs) {
    const auto& oname = nicOutput.name();
    // skip this output if a non-empty list was requested and it is not in the list
    if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end())
      continue;
    auto [curr_itr, success] = output_.emplace(std::piecewise_construct,
                                               std::forward_as_tuple(oname),
                                               std::forward_as_tuple(oname, nicOutput, this, ts->pid()));
    auto& curr_output = curr_itr->second;
    // ...
    io_msg << "  " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize()
           << " b) : " << triton_utils::printColl(curr_output.shape()) << "\n";
    // mark this requested output as found on the server
    if (!s_outputs.empty())
      s_outputs.erase(oname);
  }

  // any requested outputs not provided by the model are an error
  if (!s_outputs.empty())
    throw cms::Exception("MissingOutput") << "Some requested outputs were not available on the server: "
                                          << triton_utils::printColl(s_outputs);
  // print up-to-date info about the model itself
  std::stringstream model_msg;
  model_msg << "Model name: " << options_.model_name_ << "\n"
            << "Model version: " << options_.model_version_ << "\n";
  // ...
}

bool TritonClient::setBatchSize(unsigned bsize) {
  // ...
  for (auto& element : input_) {
    element.second.setBatchSize(bsize);
  }
  for (auto& element : output_) {
    element.second.setBatchSize(bsize);
  }
  // ...
}

void TritonClient::reset() {
  for (auto& element : input_) {
    element.second.reset();
  }
  for (auto& element : output_) {
    element.second.reset();
  }
}
template <typename F>
bool TritonClient::handle_exception(F&& call) {
  // ...
  // a recoverable Triton error is downgraded to a warning
  e.convertToWarning();
  // ...
  // any other exception is handed to the completion callback
  finish(false, std::current_exception());
  // ...
}
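// A self-contained sketch (not part of TritonClient) of the error-handling pattern suggested by
// the fragments above: run a callable and report either success or the in-flight exception through
// a completion callback. The names run_guarded and finish are illustrative assumptions, not taken
// from the original file.
#include <exception>
#include <functional>
#include <iostream>
#include <stdexcept>

template <typename F>
bool run_guarded(F&& call, const std::function<void(bool, std::exception_ptr)>& finish) {
  try {
    call();       // attempt the wrapped operation
    return true;  // success: nothing to report
  } catch (...) {
    // capture the active exception and pass it to the callback instead of letting it escape
    finish(false, std::current_exception());
    return false;
  }
}

int main() {
  auto finish = [](bool ok, std::exception_ptr eptr) {
    if (!ok && eptr) {
      try {
        std::rethrow_exception(eptr);
      } catch (const std::exception& e) {
        std::cout << "call failed: " << e.what() << '\n';
      }
    }
  };
  run_guarded([] { throw std::runtime_error("simulated inference failure"); }, finish);
  return 0;
}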
void TritonClient::getResults(std::shared_ptr<tc::InferResult> results) {
  // ...
  // outputs with variable dimensions get their concrete shape from the inference result
  if (output.variableDims()) {
    std::vector<int64_t> tmp_shape;
    // ...
    tmp_shape.erase(tmp_shape.begin());  // drop the leading (batch) dimension
    output.setShape(tmp_shape);
    // ...
  }
  // ...
}
void TritonClient::evaluate() {
  // ...
  // prepare the output objects before running inference
  auto success = handle_exception([&]() {
    for (auto& element : output_) {
      element.second.prepare();
    }
  });
  // ...

  // snapshot of the server-side statistics, used later to compute per-request deltas
  inference::ModelStatistics start_status;
  // ...

  // asynchronous mode: launch the request and handle the result in a callback
  TRITON_THROW_IF_ERROR(
      client_->AsyncInfer(
          [start_status, this](tc::InferResult* results) {
            // take ownership of the result and check its status
            std::shared_ptr<tc::InferResult> results_ptr(results);
            auto success = handle_exception(
                [&]() { TRITON_THROW_IF_ERROR(results_ptr->RequestStatus(), "evaluate(): unable to get result"); });
            // ...
            inference::ModelStatistics end_status;
            // ...
          },
          options_, inputsTriton_, outputsTriton_, headers_, compressionAlgo_),
      "evaluate(): unable to launch async run");
  // ...

  // synchronous mode: blocking call
  tc::InferResult* results;
  success = handle_exception([&]() {
    TRITON_THROW_IF_ERROR(
        client_->Infer(&results, options_, inputsTriton_, outputsTriton_, headers_, compressionAlgo_),
        "evaluate(): unable to run and/or get result");
  });
  // ...

  inference::ModelStatistics end_status;
  success = handle_exception([&]() { end_status = getServerSideStatus(); });
  // ...

  const auto& stats = summarizeServerStats(start_status, end_status);
  reportServerSideStats(stats);
  // ...

  std::shared_ptr<tc::InferResult> results_ptr(results);
  success = handle_exception([&]() { getResults(results_ptr); });
  // ...
}
void TritonClient::reportServerSideStats(const ServerSideStats& stats) const {
  std::stringstream msg;
  // ...
  const uint64_t count = stats.success_count_;
  msg << "  Inference count: " << stats.inference_count_ << "\n";
  msg << "  Execution count: " << stats.execution_count_ << "\n";
  msg << "  Successful request count: " << count << "\n";

  if (count > 0) {
    // convert cumulative nanosecond totals into per-request averages in microseconds
    auto get_avg_us = [count](uint64_t tval) {
      constexpr uint64_t us_to_ns = 1000;
      return tval / us_to_ns / count;
    };

    const uint64_t cumm_avg_us = get_avg_us(stats.cumm_time_ns_);
    const uint64_t queue_avg_us = get_avg_us(stats.queue_time_ns_);
    const uint64_t compute_input_avg_us = get_avg_us(stats.compute_input_time_ns_);
    const uint64_t compute_infer_avg_us = get_avg_us(stats.compute_infer_time_ns_);
    const uint64_t compute_output_avg_us = get_avg_us(stats.compute_output_time_ns_);
    const uint64_t compute_avg_us = compute_input_avg_us + compute_infer_avg_us + compute_output_avg_us;
    // overhead = latency not accounted for by queueing or compute, clamped at zero
    const uint64_t overhead =
        (cumm_avg_us > queue_avg_us + compute_avg_us) ? (cumm_avg_us - queue_avg_us - compute_avg_us) : 0;

    msg << "  Avg request latency: " << cumm_avg_us << " usec"
        << " (overhead " << overhead << " usec + "
        << "queue " << queue_avg_us << " usec + "
        << "compute input " << compute_input_avg_us << " usec + "
        << "compute infer " << compute_infer_avg_us << " usec + "
        << "compute output " << compute_output_avg_us << " usec)" << std::endl;
  }
  // ...
}
TritonClient::ServerSideStats TritonClient::summarizeServerStats(const inference::ModelStatistics& start_status,
                                                                 const inference::ModelStatistics& end_status) const {
  // each entry is the difference between the server's cumulative counters after and before the request
  ServerSideStats server_stats;
  server_stats.inference_count_ = end_status.inference_count() - start_status.inference_count();
  server_stats.execution_count_ = end_status.execution_count() - start_status.execution_count();
  server_stats.success_count_ =
      end_status.inference_stats().success().count() - start_status.inference_stats().success().count();
  server_stats.cumm_time_ns_ =
      end_status.inference_stats().success().ns() - start_status.inference_stats().success().ns();
  server_stats.queue_time_ns_ =
      end_status.inference_stats().queue().ns() - start_status.inference_stats().queue().ns();
  server_stats.compute_input_time_ns_ =
      end_status.inference_stats().compute_input().ns() - start_status.inference_stats().compute_input().ns();
  server_stats.compute_infer_time_ns_ =
      end_status.inference_stats().compute_infer().ns() - start_status.inference_stats().compute_infer().ns();
  server_stats.compute_output_time_ns_ =
      end_status.inference_stats().compute_output().ns() - start_status.inference_stats().compute_output().ns();
  return server_stats;
}
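// A sketch of the ServerSideStats helper struct that summarizeServerStats() fills and
// reportServerSideStats() reads. The field list is inferred from the usage above; the exact
// declaration order and any additional members are assumptions.
struct ServerSideStats {
  uint64_t inference_count_;         // inferences performed (counting batch elements)
  uint64_t execution_count_;         // model executions (batched requests)
  uint64_t success_count_;           // successful requests
  uint64_t cumm_time_ns_;            // cumulative end-to-end request time
  uint64_t queue_time_ns_;           // time spent waiting in the server queue
  uint64_t compute_input_time_ns_;   // time spent preparing inputs
  uint64_t compute_infer_time_ns_;   // time spent in model execution
  uint64_t compute_output_time_ns_;  // time spent producing outputs
};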
inference::ModelStatistics TritonClient::getServerSideStatus() const {
  if (verbose_) {
    inference::ModelStatisticsResponse resp;
    TRITON_THROW_IF_ERROR(client_->ModelInferenceStatistics(&resp, options_.model_name_, options_.model_version_),
                          "getServerSideStatus(): unable to get model statistics");
    return *(resp.model_stats().begin());
  }
  return inference::ModelStatistics{};
}
void TritonClient::fillPSetDescription(edm::ParameterSetDescription& iDesc) {
  edm::ParameterSetDescription descClient;
  // ...
  descClient.addUntracked<std::vector<std::string>>("outputs", {});