CMS 3D CMS Logo

Batching.h
Go to the documentation of this file.
1 #ifndef PHYSICSTOOLS_TENSORFLOWAOT_BATCHING_H
2 #define PHYSICSTOOLS_TENSORFLOWAOT_BATCHING_H
3 
4 /*
5  * AOT batching rules and strategies.
6  *
7  * Author: Marcel Rieger, Bogdan Wiederspan
8  */
9 
#include <cstddef>
#include <map>
#include <ostream>
#include <string>
#include <utility>
#include <vector>
14 
15 namespace tfaot {
16 
17  // rule defining how a certain batch size should be composed of various smaller sizes plus an
18  // optional padding that is applied to the last size
19  class BatchRule {
20  public:
21  // constructor
22  explicit BatchRule(size_t batchSize, const std::vector<size_t>& sizes, size_t lastPadding = 0);
23 
24  // constructor taking a string in the format "batchSize:size1,...,sizeN" with lastPadding being
25  // inferred from the sum of sizes
26  BatchRule(const std::string& ruleString);
27 
28  // destructor
29  ~BatchRule() = default;
30 
31  // getter for the batch size
32  size_t getBatchSize() const { return batchSize_; }
33 
34  // getter for available sizes
35  const std::vector<size_t>& getSizes() const { return sizes_; }
36 
37  // getter for the last padding value
38  size_t getLastPadding() const { return lastPadding_; }
39 
40  // returns the number of available sizes
41  size_t nSizes() const { return sizes_.size(); }
42 
43  // getter for the registered size at index i
44  size_t getSize(size_t i) const { return sizes_[i]; }
45 
46  private:
47  size_t batchSize_;
48  std::vector<size_t> sizes_;
49  size_t lastPadding_;
50 
51  // validation helper
52  void validate() const;
53  };
54 
55  // stream operator
56  std::ostream& operator<<(std::ostream& out, const BatchRule& rule);
57 
58  // the batch strategy is a collection of batch rules registered to certain batch sizes
59  class BatchStrategy {
60  public:
61  // constructor
62  explicit BatchStrategy() = default;
63 
64  // destructor
65  ~BatchStrategy() = default;
66 
67  // registers a new rule for a batch size
68  void setRule(const BatchRule& rule) { rules_.insert_or_assign(rule.getBatchSize(), rule); }
69 
70  // registers a new rule for a batch size, given a rule string (see BatchRule constructor)
71  void setRule(const std::string& ruleString) { this->setRule(BatchRule(ruleString)); }
72 
73  // returns whether a rule was already registered for a certain batch size
74  bool hasRule(size_t batchSize) const { return rules_.find(batchSize) != rules_.end(); }
75 
76  // returns a rule registered previously for a certain batch size
77  const BatchRule& getRule(size_t batchSize) const;
78 
79  // registers a new rule for a certain batch size according to a certain algorithm
80  void setDefaultRule(size_t batchSize, const std::vector<size_t>& availableBatchSizes);
81 
82  private:
83  std::map<size_t, BatchRule> rules_;
84  };
85 
86 } // namespace tfaot
87 
88 #endif // PHYSICSTOOLS_TENSORFLOWAOT_BATCHING_H
const BatchRule & getRule(size_t batchSize) const
Definition: Batching.cc:85
size_t lastPadding_
Definition: Batching.h:49
void setRule(const BatchRule &rule)
Definition: Batching.h:68
~BatchStrategy()=default
Definition: Batching.h:65
BatchRule(size_t batchSize, const std::vector< size_t > &sizes, size_t lastPadding=0)
Definition: Batching.cc:16
size_t getLastPadding() const
Definition: Batching.h:38
void setRule(const std::string &ruleString)
Definition: Batching.h:71
~BatchRule()=default
size_t getBatchSize() const
Definition: Batching.h:32
bool hasRule(size_t batchSize) const
Definition: Batching.h:74
size_t nSizes() const
Definition: Batching.h:41
std::ostream & operator<<(std::ostream &out, const BatchRule &rule)
Definition: Batching.cc:93
const std::vector< size_t > & getSizes() const
Definition: Batching.h:35
void validate() const
Definition: Batching.cc:54
size_t getSize(size_t i) const
Definition: Batching.h:44
std::vector< size_t > sizes_
Definition: Batching.h:48
std::map< size_t, BatchRule > rules_
Definition: Batching.h:83
size_t batchSize_
Definition: Batching.h:47
void setDefaultRule(size_t batchSize, const std::vector< size_t > &availableBatchSizes)
Definition: Batching.cc:101