#include "grpc_client.h"
#include "grpc_service.pb.h"
#include "model_config.pb.h"

#include "google/protobuf/text_format.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"

#include <experimental/iterator>

grpc_compression_algorithm getCompressionAlgo(const std::string& name) {
  if (name.empty() or name.compare("none") == 0)
    return grpc_compression_algorithm::GRPC_COMPRESS_NONE;
  else if (name.compare("deflate") == 0)
    return grpc_compression_algorithm::GRPC_COMPRESS_DEFLATE;
  else if (name.compare("gzip") == 0)
    return grpc_compression_algorithm::GRPC_COMPRESS_GZIP;
  else
    throw cms::Exception("GrpcCompression")
        << "Unknown compression algorithm requested: " << name << " (choices: none, deflate, gzip)";
}
std::vector<std::shared_ptr<tc::InferResult>> convertToShared(const std::vector<tc::InferResult*>& tmp) {
  std::vector<std::shared_ptr<tc::InferResult>> results;
  results.reserve(tmp.size());
  std::transform(tmp.begin(), tmp.end(), std::back_inserter(results), [](tc::InferResult* ptr) {
    return std::shared_ptr<tc::InferResult>(ptr);
  });
  return results;
}
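// note: the Triton client library returns raw InferResult pointers that the caller owns;
// convertToShared wraps them in shared_ptr immediately so their lifetime is managed
// automatically and they can be shared safely (e.g. with the output data objects)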
TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& debugName)
    : SonicClient(params, debugName, "TritonClient"),
      batchMode_(TritonBatchMode::Rectangular),
      manualBatchMode_(false),
      verbose_(params.getUntrackedParameter<bool>("verbose")),
      useSharedMemory_(params.getUntrackedParameter<bool>("useSharedMemory")),
      compressionAlgo_(getCompressionAlgo(params.getUntrackedParameter<std::string>("compression"))) {
  // ... (the Server struct "server" is obtained from the TritonService via serverInfo)
  TRITON_THROW_IF_ERROR(tc::InferenceServerGrpcClient::Create(&client_, server.url),
                        "TritonClient(): unable to create inference context",
                        false);

  // ...

  // set options; the Triton client interprets client_timeout_ in microseconds
  options_[0].client_timeout_ = params.getUntrackedParameter<unsigned>("timeout");
  const auto& timeoutUnit = params.getUntrackedParameter<std::string>("timeoutUnit");
  unsigned conversion = 1;
  if (timeoutUnit == "seconds")
    conversion = 1000000;
  else if (timeoutUnit == "milliseconds")
    conversion = 1000;
  else if (timeoutUnit == "microseconds")
    conversion = 1;
  else
    throw cms::Exception("Configuration") << "Unknown timeout unit: " << timeoutUnit;
  options_[0].client_timeout_ *= conversion;
  // read fixed parameters from the local model config (protobuf text format)
  inference::ModelConfig localModelConfig;
  const std::string localModelConfigPath(
      params.getUntrackedParameter<edm::FileInPath>("modelConfigPath").fullPath());
  int fileDescriptor = open(localModelConfigPath.c_str(), O_RDONLY);
  if (fileDescriptor < 0)
    throw TritonException("LocalFailure")
        << "TritonClient(): unable to open local model config: " << localModelConfigPath;
  google::protobuf::io::FileInputStream localModelConfigInput(fileDescriptor);
  // the stream takes ownership of the file descriptor and closes it when it goes out of scope
  localModelConfigInput.SetCloseOnDelete(true);
  if (!google::protobuf::TextFormat::Parse(&localModelConfigInput, &localModelConfig))
    throw TritonException("LocalFailure")
        << "TritonClient(): unable to parse local model config: " << localModelConfigPath;
  inference::ModelConfigResponse modelConfigResponse;
  TRITON_THROW_IF_ERROR(client_->ModelConfig(&modelConfigResponse, options_[0].model_name_, options_[0].model_version_),
                        "TritonClient(): unable to get model config",
                        isLocal_);
  inference::ModelConfig remoteModelConfig(modelConfigResponse.config());
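  // note: both the local config (parsed from modelConfigPath above) and the config reported
  // by the server are kept, so that their "checksum" repository-agent parameters can be
  // compared below and a mismatch in model version or content can be detected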
  // compare the checksums recorded by the "checksum" model-repository agent in the local
  // and remote configs
  std::map<std::string, std::array<std::string, 2>> checksums;
  size_t fileCounter = 0;
  for (const auto& modelConfig : {localModelConfig, remoteModelConfig}) {
    const auto& agents = modelConfig.model_repository_agents().agents();
    auto agent = std::find_if(agents.begin(), agents.end(), [](auto const& a) { return a.name() == "checksum"; });
    if (agent != agents.end()) {
      const auto& params = agent->parameters();
      for (const auto& [key, val] : params) {
        // only consider files belonging to the requested model version
        if (key.compare(0, options_[0].model_version_.size() + 1, options_[0].model_version_ + "/") == 0)
          checksums[key][fileCounter] = val;
      }
    }
    ++fileCounter;
  }
  std::vector<std::string> incorrect;
  for (const auto& [key, val] : checksums) {
    if (checksums[key][0] != checksums[key][1])
      incorrect.push_back(key);
  }
  if (!incorrect.empty())
    throw TritonException("ModelVersioning")
        << "The following files have incorrect checksums on the remote server: "
        << triton_utils::printColl(incorrect);

  inference::ModelMetadataResponse modelMetadata;
  TRITON_THROW_IF_ERROR(client_->ModelMetadata(&modelMetadata, options_[0].model_name_, options_[0].model_version_),
                        "TritonClient(): unable to get model metadata",
                        isLocal_);

  // inputs and outputs as described by the server-side model metadata
  const auto& nicInputs = modelMetadata.inputs();
  const auto& nicOutputs = modelMetadata.outputs();
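  // note: each metadata entry provides the tensor name, datatype, and shape, which are used
  // below to construct the corresponding entries in the input_ and output_ maps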
  // collect any problems with the served model and report them together
  std::stringstream msg;
  std::string msg_str;

  if (nicInputs.empty())
    msg << "Model on server appears malformed (zero inputs)\n";

  if (nicOutputs.empty())
    msg << "Model on server appears malformed (zero outputs)\n";

  msg_str = msg.str();
  if (!msg_str.empty())
    throw cms::Exception("ModelErrors") << msg_str;
  // set up the input map
  std::stringstream io_msg;
  io_msg << "Model inputs: "
         << "\n";
  for (const auto& nicInput : nicInputs) {
    const auto& iname = nicInput.name();
    auto [curr_itr, success] = input_.emplace(std::piecewise_construct,
                                              std::forward_as_tuple(iname),
                                              std::forward_as_tuple(iname, nicInput, this, ts->pid()));
    auto& curr_input = curr_itr->second;
    io_msg << "  " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize() << " b)"
           << "\n";
  }

  // restrict the requested outputs if the "outputs" parameter is non-empty
  const auto& v_outputs = params.getUntrackedParameter<std::vector<std::string>>("outputs");
  std::unordered_set s_outputs(v_outputs.begin(), v_outputs.end());
  // set up the output map
  io_msg << "Model outputs: "
         << "\n";
  for (const auto& nicOutput : nicOutputs) {
    const auto& oname = nicOutput.name();
    if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end())
      continue;
    auto [curr_itr, success] = output_.emplace(std::piecewise_construct,
                                               std::forward_as_tuple(oname),
                                               std::forward_as_tuple(oname, nicOutput, this, ts->pid()));
    auto& curr_output = curr_itr->second;
    io_msg << "  " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize() << " b)"
           << "\n";
    if (!s_outputs.empty())
      s_outputs.erase(oname);
  }

  // any names still in s_outputs were requested but are not provided by the model
  if (!s_outputs.empty())
    throw cms::Exception("MissingOutput")
        << "Some requested outputs were not available on the server: " << triton_utils::printColl(s_outputs);
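  // note: the untracked "outputs" parameter acts as an allow-list; names are erased from
  // s_outputs as they are matched above, so anything left over is a requested output that
  // the served model does not provide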
  // print model information
  std::stringstream model_msg;
  model_msg << "Model name: " << options_[0].model_name_ << "\n"
            << "Model version: " << options_[0].model_version_ << "\n";
  edm::LogInfo(fullDebugName_) << model_msg.str() << io_msg.str();
  // ... (remainder of the constructor)
}

// ...

void TritonClient::resizeEntries(unsigned entry) {
  // ...
  for (auto& element : input_) {
    element.second.entries_.resize(entry);
  }
  for (auto& element : output_) {
    element.second.entries_.resize(entry);
  }
  // ...
}

void TritonClient::addEntry(unsigned entry) {
  for (auto& element : input_) {
    element.second.addEntryImpl(entry);
  }
  for (auto& element : output_) {
    element.second.addEntryImpl(entry);
  }
  // ...
}

void TritonClient::reset() {
  // ...
  for (auto& element : input_) {
    element.second.reset();
  }
  for (auto& element : output_) {
    element.second.reset();
  }
}
template <typename F>
bool TritonClient::handle_exception(F&& call) {
  try {
    call();
    return true;
  } catch (TritonException& e) {
    // TritonExceptions are considered recoverable: downgrade to a warning and report failure
    e.convertToWarning();
    finish(false);
    return false;
  } catch (...) {
    // any other exception is propagated to the framework
    finish(false, std::current_exception());
    return false;
  }
}

void TritonClient::getResults(const std::vector<std::shared_ptr<tc::InferResult>>& results) {
  for (unsigned i = 0; i < results.size(); ++i) {
    const auto& result = results[i];
    for (auto& [oname, output] : output_) {
      // for variable-dimension outputs, the shape must be read back from the result
      if (output.variableDims()) {
        std::vector<int64_t> tmp_shape;
        TRITON_THROW_IF_ERROR(result->Shape(oname, &tmp_shape),
                              "getResults(): unable to get output shape for " + oname,
                              false);
        // drop the outer (batch) dimension
        tmp_shape.erase(tmp_shape.begin());
        // ...
      }
      // ...
    }
  }
  // ...
}
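// note: for batched models the server reports output shapes including the outer (batch)
// dimension; erasing the first shape element above leaves the per-event shape that the
// output data object stores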
void TritonClient::evaluate() {
  // ...
  std::vector<std::shared_ptr<tc::InferResult>> empty_results;
  // ...

  // set up the Triton input pointers: one vector of inputs per entry
  const unsigned nEntriesVal = nEntries();
  std::vector<std::vector<triton::client::InferInput*>> inputsTriton(nEntriesVal);
  for (auto& inputTriton : inputsTriton) {
    inputTriton.reserve(input_.size());
  }
  for (auto& [iname, input] : input_) {
    for (unsigned i = 0; i < nEntriesVal; ++i) {
      inputsTriton[i].push_back(input.data(i));
    }
  }

  // set up the Triton output pointers in the same way
  std::vector<std::vector<const triton::client::InferRequestedOutput*>> outputsTriton(nEntriesVal);
  for (auto& outputTriton : outputsTriton) {
    outputTriton.reserve(output_.size());
  }
  for (auto& [oname, output] : output_) {
    for (unsigned i = 0; i < nEntriesVal; ++i) {
      outputsTriton[i].push_back(output.data(i));
    }
  }

  // prepare the outputs (e.g. shared memory) before launching the request
  auto success = handle_exception([&]() {
    for (auto& element : output_) {
      element.second.prepare();
    }
  });
  if (!success)
    return;
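  // note: InferMulti/AsyncInferMulti accept one vector of inputs and one vector of outputs
  // per request, so the client assembles nEntries() parallel vectors, one per entry to be
  // evaluated in this call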
  // snapshot the server-side statistics before the call, so per-call deltas can be computed
  inference::ModelStatistics start_status;
  success = handle_exception([&]() { start_status = getServerSideStatus(); });
  if (!success)
    return;

  // ... (asynchronous mode: launch the request and process the results in a callback)
  TRITON_THROW_IF_ERROR(
      client_->AsyncInferMulti(
          [start_status, this](std::vector<tc::InferResult*> resultsTmp) {
            // immediately take ownership of the raw result pointers
            const auto& results = convertToShared(resultsTmp);
            // check that each request finished successfully
            for (auto ptr : results) {
              auto success = handle_exception([&]() {
                TRITON_THROW_IF_ERROR(ptr->RequestStatus(), "evaluate(): unable to get result(s)", isLocal_);
              });
              if (!success)
                return;
            }
            // ...
            inference::ModelStatistics end_status;
            // ...
          },
          options_, inputsTriton, outputsTriton, headers_, compressionAlgo_),
      "evaluate(): unable to launch async run",
      false);
  // ... (synchronous mode: blocking call)
  std::vector<tc::InferResult*> resultsTmp;
  success = handle_exception([&]() {
    TRITON_THROW_IF_ERROR(
        client_->InferMulti(&resultsTmp, options_, inputsTriton, outputsTriton, headers_, compressionAlgo_),
        "evaluate(): unable to run and/or get result",
        isLocal_);
  });
  // immediately take ownership of the raw result pointers
  const auto& results = convertToShared(resultsTmp);
  if (!success)
    return;

  inference::ModelStatistics end_status;
  success = handle_exception([&]() { end_status = getServerSideStatus(); });
  if (!success)
    return;

  const auto& stats = summarizeServerStats(start_status, end_status);
  reportServerSideStats(stats);
  // ...
void TritonClient::reportServerSideStats(const TritonClient::ServerSideStats& stats) const {
  std::stringstream msg;

  const uint64_t count = stats.success_count_;
  msg << "  Inference count: " << stats.inference_count_ << "\n";
  msg << "  Execution count: " << stats.execution_count_ << "\n";
  msg << "  Successful request count: " << count << "\n";

  if (count > 0) {
    // server-side statistics are in nanoseconds; report averages in microseconds
    auto get_avg_us = [count](uint64_t tval) {
      constexpr uint64_t us_to_ns = 1000;
      return tval / us_to_ns / count;
    };

    const uint64_t cumm_avg_us = get_avg_us(stats.cumm_time_ns_);
    const uint64_t queue_avg_us = get_avg_us(stats.queue_time_ns_);
    const uint64_t compute_input_avg_us = get_avg_us(stats.compute_input_time_ns_);
    const uint64_t compute_infer_avg_us = get_avg_us(stats.compute_infer_time_ns_);
    const uint64_t compute_output_avg_us = get_avg_us(stats.compute_output_time_ns_);
    const uint64_t compute_avg_us = compute_input_avg_us + compute_infer_avg_us + compute_output_avg_us;
    // overhead = whatever is not accounted for by queueing or compute, clamped at zero
    // to avoid unsigned underflow
    const uint64_t overhead =
        (cumm_avg_us > queue_avg_us + compute_avg_us) ? (cumm_avg_us - queue_avg_us - compute_avg_us) : 0;

    msg << "  Avg request latency: " << cumm_avg_us << " usec"
        << "\n"
        << "  (overhead " << overhead << " usec + "
        << "queue " << queue_avg_us << " usec + "
        << "compute input " << compute_input_avg_us << " usec + "
        << "compute infer " << compute_infer_avg_us << " usec + "
        << "compute output " << compute_output_avg_us << " usec)" << std::endl;
  }

  edm::LogInfo(fullDebugName_) << msg.str();
}
TritonClient::ServerSideStats TritonClient::summarizeServerStats(const inference::ModelStatistics& start_status,
                                                                 const inference::ModelStatistics& end_status) const {
  TritonClient::ServerSideStats server_stats;

  server_stats.inference_count_ = end_status.inference_count() - start_status.inference_count();
  server_stats.execution_count_ = end_status.execution_count() - start_status.execution_count();
  server_stats.success_count_ =
      end_status.inference_stats().success().count() - start_status.inference_stats().success().count();
  server_stats.cumm_time_ns_ =
      end_status.inference_stats().success().ns() - start_status.inference_stats().success().ns();
  server_stats.queue_time_ns_ =
      end_status.inference_stats().queue().ns() - start_status.inference_stats().queue().ns();
  server_stats.compute_input_time_ns_ =
      end_status.inference_stats().compute_input().ns() - start_status.inference_stats().compute_input().ns();
  server_stats.compute_infer_time_ns_ =
      end_status.inference_stats().compute_infer().ns() - start_status.inference_stats().compute_infer().ns();
  server_stats.compute_output_time_ns_ =
      end_status.inference_stats().compute_output().ns() - start_status.inference_stats().compute_output().ns();

  return server_stats;
}
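// note: Triton reports per-model statistics cumulatively since server startup, so taking
// end-minus-start differences isolates the contribution of the requests issued here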
inference::ModelStatistics TritonClient::getServerSideStatus() const {
  if (verbose_) {
    inference::ModelStatisticsResponse resp;
    TRITON_THROW_IF_ERROR(
        client_->ModelInferenceStatistics(&resp, options_[0].model_name_, options_[0].model_version_),
        "getServerSideStatus(): unable to get model statistics",
        isLocal_);
    return *(resp.model_stats().begin());
  }
  return inference::ModelStatistics{};
}
  descClient.ifValue(edm::ParameterDescription<std::string>("timeoutUnit", "seconds", false),
                     edm::allowedValues<std::string>("seconds", "milliseconds", "microseconds"));
  // ...
  descClient.addUntracked<std::vector<std::string>>("outputs", {});