10 #include "grpc_client.h" 11 #include "grpc_service.pb.h" 23 grpc_compression_algorithm getCompressionAlgo(
const std::string&
name) {
25 return grpc_compression_algorithm::GRPC_COMPRESS_NONE;
26 else if (
name.compare(
"deflate") == 0)
27 return grpc_compression_algorithm::GRPC_COMPRESS_DEFLATE;
28 else if (
name.compare(
"gzip") == 0)
29 return grpc_compression_algorithm::GRPC_COMPRESS_GZIP;
32 <<
"Unknown compression algorithm requested: " <<
name <<
" (choices: none, deflate, gzip)";
std::vector<std::shared_ptr<tc::InferResult>> convertToShared(const std::vector<tc::InferResult*>& tmp) {
  std::vector<std::shared_ptr<tc::InferResult>> results;
  results.reserve(tmp.size());
  std::transform(tmp.begin(), tmp.end(), std::back_inserter(results), [](tc::InferResult* ptr) {
    return std::shared_ptr<tc::InferResult>(ptr);
  });
  return results;
}
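// Illustrative usage (assumption about the call sites, consistent with evaluate() below): the Triton
// client hands back raw InferResult pointers that the caller owns, so they are wrapped immediately to
// guarantee cleanup even if a later check throws:
//   std::vector<tc::InferResult*> resultsTmp;            // filled by client_->InferMulti(...)
//   const auto& results = convertToShared(resultsTmp);   // ownership now held by shared_ptr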
TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& debugName)
    : // ...
      manualBatchMode_(false),
      verbose_(params.getUntrackedParameter<bool>("verbose")),
      useSharedMemory_(params.getUntrackedParameter<bool>("useSharedMemory")),
      compressionAlgo_(getCompressionAlgo(params.getUntrackedParameter<std::string>("compression"))) {
  // ask the TritonService which server should handle this model, then connect to it
  edm::Service<TritonService> ts;
  const auto& server = ts->serverInfo(options_[0].model_name_);
  TRITON_THROW_IF_ERROR(tc::InferenceServerGrpcClient::Create(&client_, server.url),
                        "TritonClient(): unable to create inference context");
  // ...

  // convert the configured timeout from seconds to microseconds (client_timeout_ is in microseconds)
  options_[0].client_timeout_ = params.getUntrackedParameter<unsigned>("timeout") * 1e6;

  // query the server for the model configuration
  inference::ModelConfigResponse modelConfigResponse;
  TRITON_THROW_IF_ERROR(client_->ModelConfig(&modelConfigResponse, options_[0].model_name_, options_[0].model_version_),
                        "TritonClient(): unable to get model config");
  inference::ModelConfig modelConfig(modelConfigResponse.config());
  // ...

  // query the server for the model metadata (inputs and outputs)
  inference::ModelMetadataResponse modelMetadata;
  TRITON_THROW_IF_ERROR(client_->ModelMetadata(&modelMetadata, options_[0].model_name_, options_[0].model_version_),
                        "TritonClient(): unable to get model metadata");
  // ...
  const auto& nicInputs = modelMetadata.inputs();
  const auto& nicOutputs = modelMetadata.outputs();

  // collect any problems with the served model and report them all at once
  std::stringstream msg;
  if (nicInputs.empty())
    msg << "Model on server appears malformed (zero inputs)\n";
  if (nicOutputs.empty())
    msg << "Model on server appears malformed (zero outputs)\n";

  const auto& msg_str = msg.str();
  if (!msg_str.empty())
    throw cms::Exception("ModelErrors") << msg_str;

  // setup input map
  std::stringstream io_msg;
  if (verbose_)
    io_msg << "Model inputs: "
           << "\n";
  for (const auto& nicInput : nicInputs) {
    const auto& iname = nicInput.name();
    auto [curr_itr, success] = input_.emplace(std::piecewise_construct,
                                              std::forward_as_tuple(iname),
                                              std::forward_as_tuple(iname, nicInput, this, ts->pid()));
    auto& curr_input = curr_itr->second;
    if (verbose_)
      io_msg << " " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize() << " bytes)\n";
  }

  // load the requested outputs; an empty list means all model outputs will be used
  const auto& v_outputs = params.getUntrackedParameter<std::vector<std::string>>("outputs");
  std::unordered_set<std::string> s_outputs(v_outputs.begin(), v_outputs.end());

  // setup output map, keeping only the requested outputs (if any were specified)
  if (verbose_)
    io_msg << "Model outputs: "
           << "\n";
  for (const auto& nicOutput : nicOutputs) {
    const auto& oname = nicOutput.name();
    if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end())
      continue;
    auto [curr_itr, success] = output_.emplace(std::piecewise_construct,
                                               std::forward_as_tuple(oname),
                                               std::forward_as_tuple(oname, nicOutput, this, ts->pid()));
    auto& curr_output = curr_itr->second;
    if (verbose_)
      io_msg << " " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize() << " bytes)\n";
    if (!s_outputs.empty())
      s_outputs.erase(oname);
  }

  // any names left in the set were requested but not provided by the model
  if (!s_outputs.empty())
    throw cms::Exception("MissingOutput") << "Some requested outputs were not available on the server: "
                                          << triton_utils::printColl(s_outputs);

  // print model info
  std::stringstream model_msg;
  if (verbose_) {
    model_msg << "Model name: " << options_[0].model_name_ << "\n"
              << "Model version: " << options_[0].model_version_ << "\n";
    edm::LogInfo(fullDebugName_) << model_msg.str() << io_msg.str();
  }
}

// ...

void TritonClient::resizeEntries(unsigned entry) {
  // ...
  for (auto& element : input_) {
    element.second.entries_.resize(entry);
  }
  for (auto& element : output_) {
    element.second.entries_.resize(entry);
  }
}

void TritonClient::addEntry(unsigned entry) {
  for (auto& element : input_) {
    element.second.addEntryImpl(entry);
  }
  for (auto& element : output_) {
    element.second.addEntryImpl(entry);
  }
  // ...
}

void TritonClient::reset() {
  // ...
  for (auto& element : input_) {
    element.second.reset();
  }
  for (auto& element : output_) {
    element.second.reset();
  }
}
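// Illustrative usage (hypothetical caller, not part of the original file): a producer that fills
// entries with different shapes per event can grow the entry vectors before setting its data, e.g.
//   client.addEntry(0);   // make sure entry 0 exists in every input/output
//   client.addEntry(1);   // make sure entry 1 exists as well
// and client.reset() clears the per-call state of all inputs and outputs afterwards.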

template <typename F>
bool TritonClient::handle_exception(F&& call) {
  try {
    call();
    return true;
  } catch (TritonException& e) {
    // recoverable failure: downgrade to a warning so a retry can be attempted
    e.convertToWarning();
    finish(false);
    return false;
  } catch (...) {
    // unrecoverable failure: propagate the exception through finish()
    finish(false, std::current_exception());
    return false;
  }
}
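// Illustrative usage (assumption, mirroring the call sites in evaluate() below): any client call that
// may throw is wrapped so that failures are routed through finish() instead of escaping, e.g.
//   inference::ModelStatistics end_status;
//   auto ok = handle_exception([&]() { end_status = getServerSideStatus(); });
//   if (!ok)
//     return;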

void TritonClient::getResults(const std::vector<std::shared_ptr<tc::InferResult>>& results) {
  for (unsigned i = 0; i < results.size(); ++i) {
    const auto& result = results[i];
    for (auto& [oname, output] : output_) {
      // for outputs with variable dimensions, take the concrete shape from the result
      if (output.variableDims()) {
        std::vector<int64_t> tmp_shape;
        TRITON_THROW_IF_ERROR(result->Shape(oname, &tmp_shape),
                              "getResults(): unable to get output shape for " + oname);
        // remove the outer (batch) dimension reported by the server
        tmp_shape.erase(tmp_shape.begin());
        // ...
      }
      // ...
    }
  }
  // ...
}

void TritonClient::evaluate() {
  // if there is nothing to process, call getResults() on an empty result vector and finish
  if (batchSize() == 0) {
    std::vector<std::shared_ptr<tc::InferResult>> empty_results;
    getResults(empty_results);
    finish(true);
    return;
  }

  const unsigned nEntriesVal = nEntries();

  // set up the per-entry vectors of input pointers
  std::vector<std::vector<triton::client::InferInput*>> inputsTriton(nEntriesVal);
  for (auto& inputTriton : inputsTriton) {
    inputTriton.reserve(input_.size());
  }
  for (auto& [iname, input] : input_) {
    for (unsigned i = 0; i < nEntriesVal; ++i) {
      inputsTriton[i].push_back(input.data(i));
    }
  }

  // set up the per-entry vectors of requested-output pointers
  std::vector<std::vector<const triton::client::InferRequestedOutput*>> outputsTriton(nEntriesVal);
  for (auto& outputTriton : outputsTriton) {
    outputTriton.reserve(output_.size());
  }
  for (auto& [oname, output] : output_) {
    for (unsigned i = 0; i < nEntriesVal; ++i) {
      outputsTriton[i].push_back(output.data(i));
    }
  }

  // prepare the outputs (e.g. shared memory) before the request is sent
  auto success = handle_exception([&]() {
    for (auto& element : output_) {
      element.second.prepare();
    }
  });
  if (!success)
    return;

  // server-side statistics before the request, for the before/after comparison in summarizeServerStats()
  inference::ModelStatistics start_status;
  success = handle_exception([&]() { start_status = getServerSideStatus(); });
  if (!success)
    return;

  if (mode_ == SonicMode::Async) {
    // non-blocking call: the callback runs when all results are available
    success = handle_exception([&]() {
      TRITON_THROW_IF_ERROR(
          client_->AsyncInferMulti(
              [start_status, this](std::vector<tc::InferResult*> resultsTmp) {
                // immediately convert to shared_ptr to guarantee cleanup of the raw pointers
                const auto& results = convertToShared(resultsTmp);
                // check that each request succeeded
                for (auto ptr : results) {
                  auto success = handle_exception(
                      [&]() { TRITON_THROW_IF_ERROR(ptr->RequestStatus(), "evaluate(): unable to get result(s)"); });
                  if (!success)
                    return;
                }
                // compare server-side statistics after the request
                inference::ModelStatistics end_status;
                // ...
                getResults(results);
                finish(true);
              },
              options_,
              inputsTriton,
              outputsTriton,
              headers_,
              compressionAlgo_),
          "evaluate(): unable to launch async run");
    });
    if (!success)
      return;
  } else {
    // blocking call
    std::vector<tc::InferResult*> resultsTmp;
    success = handle_exception([&]() {
      TRITON_THROW_IF_ERROR(
          client_->InferMulti(&resultsTmp, options_, inputsTriton, outputsTriton, headers_, compressionAlgo_),
          "evaluate(): unable to run and/or get result");
    });
    // immediately convert to shared_ptr to guarantee cleanup of the raw pointers
    const auto& results = convertToShared(resultsTmp);
    if (!success)
      return;

    // compare server-side statistics after the request
    inference::ModelStatistics end_status;
    success = handle_exception([&]() { end_status = getServerSideStatus(); });
    if (!success)
      return;

    const auto& stats = summarizeServerStats(start_status, end_status);
    reportServerSideStats(stats);

    getResults(results);
    finish(true);
  }
}

void TritonClient::reportServerSideStats(const TritonClient::ServerSideStats& stats) const {
  std::stringstream msg;

  const uint64_t count = stats.success_count_;
  msg << " Inference count: " << stats.inference_count_ << "\n";
  msg << " Execution count: " << stats.execution_count_ << "\n";
  msg << " Successful request count: " << count << "\n";

  if (count > 0) {
    // averages are reported in microseconds; the server provides nanoseconds
    auto get_avg_us = [count](uint64_t tval) {
      constexpr uint64_t us_to_ns = 1000;
      return tval / us_to_ns / count;
    };

    const uint64_t cumm_avg_us = get_avg_us(stats.cumm_time_ns_);
    const uint64_t queue_avg_us = get_avg_us(stats.queue_time_ns_);
    const uint64_t compute_input_avg_us = get_avg_us(stats.compute_input_time_ns_);
    const uint64_t compute_infer_avg_us = get_avg_us(stats.compute_infer_time_ns_);
    const uint64_t compute_output_avg_us = get_avg_us(stats.compute_output_time_ns_);
    const uint64_t compute_avg_us = compute_input_avg_us + compute_infer_avg_us + compute_output_avg_us;
    const uint64_t overhead =
        (cumm_avg_us > queue_avg_us + compute_avg_us) ? (cumm_avg_us - queue_avg_us - compute_avg_us) : 0;

    msg << " Avg request latency: " << cumm_avg_us << " usec"
        << "\n"
        << " (overhead " << overhead << " usec + "
        << "queue " << queue_avg_us << " usec + "
        << "compute input " << compute_input_avg_us << " usec + "
        << "compute infer " << compute_infer_avg_us << " usec + "
        << "compute output " << compute_output_avg_us << " usec)" << std::endl;
  }

  if (verbose_)
    edm::LogInfo(fullDebugName_) << msg.str();
}

TritonClient::ServerSideStats TritonClient::summarizeServerStats(const inference::ModelStatistics& start_status,
                                                                 const inference::ModelStatistics& end_status) const {
  // the server reports cumulative counters, so per-call quantities are obtained by differencing
  TritonClient::ServerSideStats server_stats;
  server_stats.inference_count_ = end_status.inference_count() - start_status.inference_count();
  server_stats.execution_count_ = end_status.execution_count() - start_status.execution_count();
  server_stats.success_count_ =
      end_status.inference_stats().success().count() - start_status.inference_stats().success().count();
  server_stats.cumm_time_ns_ =
      end_status.inference_stats().success().ns() - start_status.inference_stats().success().ns();
  server_stats.queue_time_ns_ =
      end_status.inference_stats().queue().ns() - start_status.inference_stats().queue().ns();
  server_stats.compute_input_time_ns_ =
      end_status.inference_stats().compute_input().ns() - start_status.inference_stats().compute_input().ns();
  server_stats.compute_infer_time_ns_ =
      end_status.inference_stats().compute_infer().ns() - start_status.inference_stats().compute_infer().ns();
  server_stats.compute_output_time_ns_ =
      end_status.inference_stats().compute_output().ns() - start_status.inference_stats().compute_output().ns();
  return server_stats;
}

inference::ModelStatistics TritonClient::getServerSideStatus() const {
  if (verbose_) {
    inference::ModelStatisticsResponse resp;
    TRITON_THROW_IF_ERROR(
        client_->ModelInferenceStatistics(&resp, options_[0].model_name_, options_[0].model_version_),
        "getServerSideStatus(): unable to get model statistics");
    return *(resp.model_stats().begin());
  }
  return inference::ModelStatistics{};
}

void TritonClient::fillPSetDescription(edm::ParameterSetDescription& iDesc) {
  edm::ParameterSetDescription descClient;
  // ...
  descClient.addUntracked<std::vector<std::string>>("outputs", {});
  // ...
}