CMS 3D CMS Logo

Wrapper.cc
Go to the documentation of this file.
1 /*
2  * AOT wrapper interface for interacting with models compiled for different batch sizes.
3  *
4  * Author: Marcel Rieger, Bogdan Wiederspan
5  */
6 
7 #include <vector>
8 #include <map>
9 
11 
12 namespace tfaot {
13 
14  int Wrapper::argCount(size_t batchSize, size_t argIndex) const {
15  const auto& counts = argCounts();
16  const auto it = counts.find(batchSize);
17  if (it == counts.end()) {
18  unknownBatchSize(batchSize, "argCount()");
19  }
20  if (argIndex >= it->second.size()) {
21  unknownArgument(argIndex, "argCount()");
22  }
23  return it->second.at(argIndex);
24  }
25 
26  int Wrapper::argCountNoBatch(size_t argIndex) const {
27  const auto& counts = argCountsNoBatch();
28  if (argIndex >= counts.size()) {
29  unknownArgument(argIndex, "argCountNoBatch()");
30  }
31  return counts.at(argIndex);
32  }
33 
34  int Wrapper::resultCount(size_t batchSize, size_t resultIndex) const {
35  const auto& counts = resultCounts();
36  const auto it = counts.find(batchSize);
37  if (it == counts.end()) {
38  unknownBatchSize(batchSize, "resultCount()");
39  }
40  if (resultIndex >= it->second.size()) {
41  unknownResult(resultIndex, "resultCount()");
42  }
43  return it->second.at(resultIndex);
44  }
45 
46  int Wrapper::resultCountNoBatch(size_t resultIndex) const {
47  const auto& counts = resultCountsNoBatch();
48  if (resultIndex >= counts.size()) {
49  unknownResult(resultIndex, "resultCountNoBatch()");
50  }
51  return counts[resultIndex];
52  }
53 
54  void Wrapper::run(size_t batchSize) {
55  if (!runSilent(batchSize)) {
56  throw cms::Exception("FailedRun") << "evaluation with batch size " << batchSize << " failed for model '" << name_;
57  }
58  }
59 
60 } // namespace tfaot
void unknownResult(size_t resultIndex, const std::string &method) const
Definition: Wrapper.h:136
int resultCount(size_t batchSize, size_t resultIndex) const
Definition: Wrapper.cc:34
virtual const std::map< size_t, std::vector< size_t > > & argCounts() const =0
void run(size_t batchSize)
Definition: Wrapper.cc:54
virtual const std::vector< size_t > & resultCountsNoBatch() const =0
int argCount(size_t batchSize, size_t argIndex) const
Definition: Wrapper.cc:14
void unknownBatchSize(size_t batchSize, const std::string &method) const
Definition: Wrapper.h:124
tfaot
Definition: Batching.h:15
virtual bool runSilent(size_t batchSize)=0
virtual const std::map< size_t, std::vector< size_t > > & resultCounts() const =0
virtual const std::vector< size_t > & argCountsNoBatch() const =0
std::string name_
Definition: Wrapper.h:142
int argCountNoBatch(size_t argIndex) const
Definition: Wrapper.cc:26
void unknownArgument(size_t argIndex, const std::string &method) const
Definition: Wrapper.h:130
int resultCountNoBatch(size_t resultIndex) const
Definition: Wrapper.cc:46