CMS 3D CMS Logo

List of all members | Public Member Functions | Private Member Functions | Private Attributes
tfaot::Model< W > Class Template Reference

#include <Model.h>

Public Member Functions

const BatchStrategy & getBatchStrategy () const
 
 Model ()
 
const std::string & name () const
 
template<typename... Outputs, typename... Inputs>
std::tuple< Outputs... > run (size_t batchSize, Inputs &&... inputs)
 
void setBatchRule (size_t batchSize, const std::vector< size_t > &sizes, size_t lastPadding=0)
 
void setBatchRule (const std::string &batchRule)
 
void setBatchStrategy (const BatchStrategy &strategy)
 
 ~Model ()
 

Private Member Functions

const BatchRule & ensureRule (size_t batchSize)
 
template<typename T >
void extractBatchOutput (size_t batchSize, size_t batchIndex, size_t resultIndex, std::vector< T > &batchData) const
 
template<typename T >
void injectBatchInput (size_t batchSize, size_t batchIndex, size_t argIndex, const std::vector< T > &batchData)
 
template<typename T >
void reserveOutput (size_t batchSize, size_t resultIndex, std::vector< std::vector< T >> &data) const
 

Private Attributes

BatchStrategy batchStrategy_
 
std::unique_ptr< W > wrapper_
 

Detailed Description

template<class W>
class tfaot::Model< W >

Definition at line 19 of file Model.h.

Constructor & Destructor Documentation

◆ Model()

template<class W >
tfaot::Model< W >::Model ( )
inline explicit

Definition at line 22 of file Model.h.

22 : wrapper_(std::make_unique<W>()) {}
std::unique_ptr< W > wrapper_
Definition: Model.h:49

◆ ~Model()

template<class W >
tfaot::Model< W >::~Model ( )
inline

Definition at line 25 of file Model.h.

References tfaot::Model< W >::wrapper_.

25 { wrapper_.reset(); };
std::unique_ptr< W > wrapper_
Definition: Model.h:49

Member Function Documentation

◆ ensureRule()

template<class W >
const BatchRule & tfaot::Model< W >::ensureRule ( size_t  batchSize)
private

Definition at line 70 of file Model.h.

References tfaot::Model< W >::batchStrategy_, getRule(), hasRule(), setDefaultRule(), and tfaot::Model< W >::wrapper_.

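70  {
71  // register a default rule if there is none yet for that batch size
72–73  [not captured in this extract]
74  }
75  [not captured in this extract]
76  }
Note: source lines 72, 73 and 75 were dropped when this listing was extracted. According to the cross-references that follow, the body uses batchStrategy_ and wrapper_ and calls hasRule(), setDefaultRule() and getRule().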
const BatchRule & getRule(size_t batchSize) const
Definition: Batching.cc:85
std::unique_ptr< W > wrapper_
Definition: Model.h:49
BatchStrategy batchStrategy_
Definition: Model.h:50
bool hasRule(size_t batchSize) const
Definition: Batching.h:74
void setDefaultRule(size_t batchSize, const std::vector< size_t > &availableBatchSizes)
Definition: Batching.cc:101

◆ extractBatchOutput()

template<class W >
template<typename T >
void tfaot::Model< W >::extractBatchOutput ( size_t  batchSize,
size_t  batchIndex,
size_t  resultIndex,
std::vector< T > &  batchData 
) const
private

Definition at line 103 of file Model.h.

References tfaot::Model< W >::wrapper_.

106  {
107  size_t count = wrapper_->resultCountNoBatch(resultIndex);
108  const T* resPtr = wrapper_->template resultData<T>(batchSize, resultIndex) + batchIndex * count;
109  batchData.assign(resPtr, resPtr + count);
110  }
std::unique_ptr< W > wrapper_
Definition: Model.h:49
T — template parameter: element type of batchData

◆ getBatchStrategy()

template<class W >
const BatchStrategy& tfaot::Model< W >::getBatchStrategy ( ) const
inline

Definition at line 34 of file Model.h.

References tfaot::Model< W >::batchStrategy_.

34 { return batchStrategy_; }
BatchStrategy batchStrategy_
Definition: Model.h:50

◆ injectBatchInput()

template<class W >
template<typename T >
void tfaot::Model< W >::injectBatchInput ( size_t  batchSize,
size_t  batchIndex,
size_t  argIndex,
const std::vector< T > &  batchData 
)
private

Definition at line 86 of file Model.h.

References cms::Exception, tfaot::Model< W >::name(), and tfaot::Model< W >::wrapper_.

89  {
90  size_t count = wrapper_->argCountNoBatch(argIndex);
91  if (batchData.size() != count) {
92  throw cms::Exception("InputMismatch")
93  << "model '" << name() << "' received " << batchData.size() << " elements for argument " << argIndex
94  << ", but " << count << " are expected";
95  }
96  T* argPtr = wrapper_->template argData<T>(batchSize, argIndex) + batchIndex * count;
97  auto beg = batchData.cbegin();
98  std::copy(beg, beg + count, argPtr);
99  }
std::unique_ptr< W > wrapper_
Definition: Model.h:49
const std::string & name() const
Definition: Model.h:28
T — template parameter: element type of batchData

◆ name()

template<class W >
const std::string& tfaot::Model< W >::name ( void  ) const
inline

Definition at line 28 of file Model.h.

References tfaot::Model< W >::wrapper_.

Referenced by tfaot::Model< W >::injectBatchInput() and tfaot::Model< W >::run().

28 { return wrapper_->name(); }
std::unique_ptr< W > wrapper_
Definition: Model.h:49

◆ reserveOutput()

template<class W >
template<typename T >
void tfaot::Model< W >::reserveOutput ( size_t  batchSize,
size_t  resultIndex,
std::vector< std::vector< T >> &  data 
) const
private

Definition at line 80 of file Model.h.

References tfaot::Model< W >::wrapper_.

80  {
81  data.resize(batchSize, std::vector<T>(wrapper_->resultCountNoBatch(resultIndex)));
82  }
std::unique_ptr< W > wrapper_
Definition: Model.h:49
data — output container: resized to batchSize entries, each a std::vector<T> of wrapper_->resultCountNoBatch(resultIndex) elements

◆ run()

template<class W >
template<typename... Outputs, typename... Inputs>
std::tuple< Outputs... > tfaot::Model< W >::run ( size_t  batchSize,
Inputs &&...  inputs 
)

Definition at line 114 of file Model.h.

References cms::Exception, tfaot::createIndexLooper(), tfaot::Model< W >::ensureRule(), tfaot::Model< W >::extractBatchOutput(), tfaot::BatchRule::getLastPadding(), tfaot::BatchRule::getSize(), tfaot::Model< W >::injectBatchInput(), tfaot::Model< W >::name(), tfaot::BatchRule::nSizes(), tfaot::Model< W >::reserveOutput(), and tfaot::Model< W >::wrapper_.

114  {
115  // check number of inputs
116  size_t nInputs = sizeof...(Inputs);
117  if (nInputs != wrapper_->nArgs()) {
118  throw cms::Exception("InputMismatch")
119  << "model '" << name() << "' received " << nInputs << " inputs, but " << wrapper_->nArgs() << " are expected";
120  }
121 
122  // check number of outputs
123  size_t nOutputs = sizeof...(Outputs);
124  if (nOutputs != wrapper_->nResults()) {
125  throw cms::Exception("OutputMismatch") << "requested " << nOutputs << " from model '" << name() << "', but "
126  << wrapper_->nResults() << " are provided";
127  }
128 
129  // get the corresponding batch rule
130  const BatchRule& rule = ensureRule(batchSize);
131 
132  // create a callback that invokes lambdas over all outputs with normal indices
133  auto forEachOutput = createIndexLooper<sizeof...(Outputs)>();
134 
135  // reserve output arrays
136  std::tuple<Outputs...> outputs;
137  forEachOutput([&](auto resultIndex) { reserveOutput(batchSize, resultIndex, std::get<resultIndex>(outputs)); });
138 
139  // loop over particular batch sizes, copy input, evaluate and compose the output
140  size_t batchOffset = 0;
141  size_t nSizes = rule.nSizes();
142  for (size_t i = 0; i < nSizes; i++) {
143  // get actual model batch size and optional padding
144  size_t bs = rule.getSize(i);
145  size_t padding = (i == nSizes - 1) ? rule.getLastPadding() : 0;
146 
147  // fill inputs separately per batch element
148  for (size_t batchIndex = 0; batchIndex < bs - padding; batchIndex++) {
149  size_t argIndex = 0;
150  ([&] { injectBatchInput(bs, batchIndex, argIndex++, inputs[batchOffset + batchIndex]); }(), ...);
151  }
152 
153  // model evaluation
154  wrapper_->run(bs);
155 
156  // fill outputs separately per batch element
157  for (size_t batchIndex = 0; batchIndex < bs - padding; batchIndex++) {
158  forEachOutput([&](auto resultIndex) {
159  extractBatchOutput(bs, batchIndex, resultIndex, std::get<resultIndex>(outputs)[batchOffset + batchIndex]);
160  });
161  }
162 
163  batchOffset += bs;
164  }
165 
166  return outputs;
167  }
void extractBatchOutput(size_t batchSize, size_t batchIndex, size_t resultIndex, std::vector< T > &batchData) const
Definition: Model.h:103
std::unique_ptr< W > wrapper_
Definition: Model.h:49
void injectBatchInput(size_t batchSize, size_t batchIndex, size_t argIndex, const std::vector< T > &batchData)
Definition: Model.h:86
const std::string & name() const
Definition: Model.h:28
void reserveOutput(size_t batchSize, size_t resultIndex, std::vector< std::vector< T >> &data) const
Definition: Model.h:80
const BatchRule & ensureRule(size_t batchSize)
Definition: Model.h:70
auto createIndexLooper(std::index_sequence< Index... >)
Definition: Util.h:29

◆ setBatchRule() [1/2]

template<class W >
void tfaot::Model< W >::setBatchRule ( size_t  batchSize,
const std::vector< size_t > &  sizes,
size_t  lastPadding = 0 
)
inline

Definition at line 37 of file Model.h.

References HLT_FULL_cff::batchSize, tfaot::Model< W >::batchStrategy_, and tfaot::BatchStrategy::setRule().

37  {
38  batchStrategy_.setRule(BatchRule(batchSize, sizes, lastPadding));
39  }
void setRule(const BatchRule &rule)
Definition: Batching.h:68
BatchStrategy batchStrategy_
Definition: Model.h:50

◆ setBatchRule() [2/2]

template<class W >
void tfaot::Model< W >::setBatchRule ( const std::string &  batchRule)
inline

Definition at line 42 of file Model.h.

References tfaot::Model< W >::batchStrategy_, and tfaot::BatchStrategy::setRule().

42 { batchStrategy_.setRule(BatchRule(batchRule)); }
void setRule(const BatchRule &rule)
Definition: Batching.h:68
BatchStrategy batchStrategy_
Definition: Model.h:50

◆ setBatchStrategy()

template<class W >
void tfaot::Model< W >::setBatchStrategy ( const BatchStrategy & strategy)
inline

Definition at line 31 of file Model.h.

References tfaot::Model< W >::batchStrategy_.

BatchStrategy batchStrategy_
Definition: Model.h:50
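Note: the one-line definition at line 31 of Model.h was not captured in this extract; per the cross-reference above, it references batchStrategy_.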

Member Data Documentation

◆ batchStrategy_

template<class W >
BatchStrategy tfaot::Model< W >::batchStrategy_
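private

Definition at line 50 of file Model.h.

Referenced by tfaot::Model< W >::ensureRule(), tfaot::Model< W >::getBatchStrategy(), tfaot::Model< W >::setBatchRule(), and tfaot::Model< W >::setBatchStrategy().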

◆ wrapper_

template<class W >
std::unique_ptr<W> tfaot::Model< W >::wrapper_
private

Definition at line 49 of file Model.h.

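Referenced by tfaot::Model< W >::ensureRule(), tfaot::Model< W >::extractBatchOutput(), tfaot::Model< W >::injectBatchInput(), tfaot::Model< W >::Model(), tfaot::Model< W >::name(), tfaot::Model< W >::reserveOutput(), tfaot::Model< W >::run(), and tfaot::Model< W >::~Model().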