CMS 3D CMS Logo

List of all members | Public Member Functions | Private Member Functions | Private Attributes
tfaot::Model< W > Class Template Reference

#include <Model.h>

Public Member Functions

const BatchStrategy & getBatchStrategy () const
 
 Model ()
 
const std::string & name () const
 
template<typename... Outputs, typename... Inputs>
std::tuple< Outputs... > run (size_t batchSize, Inputs &&... inputs)
 
void setBatchRule (size_t batchSize, const std::vector< size_t > &sizes, size_t lastPadding=0)
 
void setBatchStrategy (const BatchStrategy &strategy)
 
 ~Model ()
 

Private Member Functions

const BatchRule & ensureRule (size_t batchSize)
 
template<typename T >
void extractBatchOutput (size_t batchSize, size_t batchIndex, size_t resultIndex, std::vector< T > &batchData) const
 
template<typename T >
void injectBatchInput (size_t batchSize, size_t batchIndex, size_t argIndex, const std::vector< T > &batchData)
 
template<typename T >
void reserveOutput (size_t batchSize, size_t resultIndex, std::vector< std::vector< T >> &data) const
 

Private Attributes

BatchStrategy batchStrategy_
 
std::unique_ptr< W > wrapper_
 

Detailed Description

template<class W>
class tfaot::Model< W >

Definition at line 19 of file Model.h.

Constructor & Destructor Documentation

◆ Model()

template<class W >
tfaot::Model< W >::Model ( )
inline explicit

Definition at line 22 of file Model.h.

22 : wrapper_(std::make_unique<W>()) {}
std::unique_ptr< W > wrapper_
Definition: Model.h:46

◆ ~Model()

template<class W >
tfaot::Model< W >::~Model ( )
inline

Definition at line 25 of file Model.h.

References tfaot::Model< W >::wrapper_.

25 { wrapper_.reset(); };
std::unique_ptr< W > wrapper_
Definition: Model.h:46

Member Function Documentation

◆ ensureRule()

template<class W >
const BatchRule & Model::ensureRule ( size_t  batchSize)
private

Definition at line 67 of file Model.h.

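References tfaot::Model< W >::batchStrategy_, tfaot::BatchStrategy::getRule(), tfaot::BatchStrategy::hasRule(), tfaot::BatchStrategy::setDefaultRule(), and tfaot::Model< W >::wrapper_.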

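67  {
68  // register a default rule if there is none yet for that batch size
69  if (!batchStrategy_.hasRule(batchSize)) {
70  batchStrategy_.setDefaultRule(batchSize, wrapper_->batchSizes());
71  }
72  return batchStrategy_.getRule(batchSize);
73  }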
const BatchRule & getRule(size_t batchSize) const
Definition: Batching.cc:48
std::unique_ptr< W > wrapper_
Definition: Model.h:46
BatchStrategy batchStrategy_
Definition: Model.h:47
bool hasRule(size_t batchSize) const
Definition: Batching.h:64
void setDefaultRule(size_t batchSize, const std::vector< size_t > &availableBatchSizes)
Definition: Batching.cc:64

◆ extractBatchOutput()

template<class W >
template<typename T >
void Model::extractBatchOutput ( size_t  batchSize,
size_t  batchIndex,
size_t  resultIndex,
std::vector< T > &  batchData 
) const
private

Definition at line 100 of file Model.h.

References tfaot::Model< W >::wrapper_.

103  {
104  size_t count = wrapper_->resultCountNoBatch(resultIndex);
105  const T* resPtr = wrapper_->template resultData<T>(batchSize, resultIndex) + batchIndex * count;
106  batchData.assign(resPtr, resPtr + count);
107  }
std::unique_ptr< W > wrapper_
Definition: Model.h:46
long double T

◆ getBatchStrategy()

template<class W >
const BatchStrategy& tfaot::Model< W >::getBatchStrategy ( ) const
inline

Definition at line 34 of file Model.h.

References tfaot::Model< W >::batchStrategy_.

34 { return batchStrategy_; }
BatchStrategy batchStrategy_
Definition: Model.h:47

◆ injectBatchInput()

template<class W >
template<typename T >
void Model::injectBatchInput ( size_t  batchSize,
size_t  batchIndex,
size_t  argIndex,
const std::vector< T > &  batchData 
)
private

Definition at line 83 of file Model.h.

References Exception, tfaot::Model< W >::name(), and tfaot::Model< W >::wrapper_.

86  {
87  size_t count = wrapper_->argCountNoBatch(argIndex);
88  if (batchData.size() != count) {
89  throw cms::Exception("InputMismatch")
90  << "model '" << name() << "' received " << batchData.size() << " elements for argument " << argIndex
91  << ", but " << count << " are expected";
92  }
93  T* argPtr = wrapper_->template argData<T>(batchSize, argIndex) + batchIndex * count;
94  auto beg = batchData.cbegin();
95  std::copy(beg, beg + count, argPtr);
96  }
std::unique_ptr< W > wrapper_
Definition: Model.h:46
const std::string & name() const
Definition: Model.h:28
long double T

◆ name()

template<class W >
const std::string& tfaot::Model< W >::name ( void  ) const
inline

Definition at line 28 of file Model.h.

References tfaot::Model< W >::wrapper_.

Referenced by tfaot::Model< W >::injectBatchInput(), and tfaot::Model< W >::run().

28 { return wrapper_->name(); }
std::unique_ptr< W > wrapper_
Definition: Model.h:46

◆ reserveOutput()

template<class W >
template<typename T >
void Model::reserveOutput ( size_t  batchSize,
size_t  resultIndex,
std::vector< std::vector< T >> &  data 
) const
private

Definition at line 77 of file Model.h.

References tfaot::Model< W >::wrapper_.

77  {
78  data.resize(batchSize, std::vector<T>(wrapper_->resultCountNoBatch(resultIndex)));
79  }
std::unique_ptr< W > wrapper_
Definition: Model.h:46
char data[epos_bytes_allocation]
Definition: EPOS_Wrapper.h:80

◆ run()

template<class W >
template<typename... Outputs, typename... Inputs>
std::tuple< Outputs... > Model::run ( size_t  batchSize,
Inputs &&...  inputs 
)

Definition at line 111 of file Model.h.

References tfaot::createIndexLooper(), tfaot::Model< W >::ensureRule(), Exception, tfaot::Model< W >::extractBatchOutput(), tfaot::BatchRule::getLastPadding(), tfaot::BatchRule::getSize(), tfaot::Model< W >::injectBatchInput(), tfaot::Model< W >::name(), tfaot::BatchRule::nSizes(), tfaot::Model< W >::reserveOutput(), and tfaot::Model< W >::wrapper_.

111  {
112  // check number of inputs
113  size_t nInputs = sizeof...(Inputs);
114  if (nInputs != wrapper_->nArgs()) {
115  throw cms::Exception("InputMismatch")
116  << "model '" << name() << "' received " << nInputs << " inputs, but " << wrapper_->nArgs() << " are expected";
117  }
118 
119  // check number of outputs
120  size_t nOutputs = sizeof...(Outputs);
121  if (nOutputs != wrapper_->nResults()) {
122  throw cms::Exception("OutputMismatch") << "requested " << nOutputs << " from model '" << name() << "', but "
123  << wrapper_->nResults() << " are provided";
124  }
125 
126  // get the corresponding batch rule
127  const BatchRule& rule = ensureRule(batchSize);
128 
129  // create a callback that invokes lambdas over all outputs with normal indices
130  auto forEachOutput = createIndexLooper<sizeof...(Outputs)>();
131 
132  // reserve output arrays
133  std::tuple<Outputs...> outputs;
134  forEachOutput([&](auto resultIndex) { reserveOutput(batchSize, resultIndex, std::get<resultIndex>(outputs)); });
135 
136  // loop over particular batch sizes, copy input, evaluate and compose the output
137  size_t batchOffset = 0;
138  size_t nSizes = rule.nSizes();
139  for (size_t i = 0; i < nSizes; i++) {
140  // get actual model batch size and optional padding
141  size_t bs = rule.getSize(i);
142  size_t padding = (i == nSizes - 1) ? rule.getLastPadding() : 0;
143 
144  // fill inputs separately per batch element
145  for (size_t batchIndex = 0; batchIndex < bs - padding; batchIndex++) {
146  size_t argIndex = 0;
147  ([&] { injectBatchInput(bs, batchIndex, argIndex++, inputs[batchOffset + batchIndex]); }(), ...);
148  }
149 
150  // model evaluation
151  wrapper_->run(bs);
152 
153  // fill outputs separately per batch element
154  for (size_t batchIndex = 0; batchIndex < bs - padding; batchIndex++) {
155  forEachOutput([&](auto resultIndex) {
156  extractBatchOutput(bs, batchIndex, resultIndex, std::get<resultIndex>(outputs)[batchOffset + batchIndex]);
157  });
158  }
159 
160  batchOffset += bs;
161  }
162 
163  return outputs;
164  }
void extractBatchOutput(size_t batchSize, size_t batchIndex, size_t resultIndex, std::vector< T > &batchData) const
Definition: Model.h:100
std::unique_ptr< W > wrapper_
Definition: Model.h:46
void injectBatchInput(size_t batchSize, size_t batchIndex, size_t argIndex, const std::vector< T > &batchData)
Definition: Model.h:83
const std::string & name() const
Definition: Model.h:28
void reserveOutput(size_t batchSize, size_t resultIndex, std::vector< std::vector< T >> &data) const
Definition: Model.h:77
const BatchRule & ensureRule(size_t batchSize)
Definition: Model.h:67
auto createIndexLooper(std::index_sequence< Index... >)
Definition: Util.h:29

◆ setBatchRule()

template<class W >
void tfaot::Model< W >::setBatchRule ( size_t  batchSize,
const std::vector< size_t > &  sizes,
size_t  lastPadding = 0 
)
inline

Definition at line 37 of file Model.h.

References tfaot::Model< W >::batchStrategy_, and tfaot::BatchStrategy::setRule().

37  {
38  batchStrategy_.setRule(BatchRule(batchSize, sizes, lastPadding));
39  }
void setRule(const BatchRule &rule)
Definition: Batching.h:61
BatchStrategy batchStrategy_
Definition: Model.h:47

◆ setBatchStrategy()

template<class W >
void tfaot::Model< W >::setBatchStrategy ( const BatchStrategy & strategy)
inline

Definition at line 31 of file Model.h.

References tfaot::Model< W >::batchStrategy_.

31 { batchStrategy_ = strategy; }

BatchStrategy batchStrategy_
Definition: Model.h:47
strategy
Definition: nnet_common.h:18

Member Data Documentation

◆ batchStrategy_

template<class W >
BatchStrategy tfaot::Model< W >::batchStrategy_
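private

Definition at line 47 of file Model.h.

Referenced by tfaot::Model< W >::ensureRule(), tfaot::Model< W >::getBatchStrategy(), tfaot::Model< W >::setBatchRule(), and tfaot::Model< W >::setBatchStrategy().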

◆ wrapper_

template<class W >
std::unique_ptr<W> tfaot::Model< W >::wrapper_
private

Definition at line 46 of file Model.h.

Referenced by tfaot::Model< W >::name(), and tfaot::Model< W >::~Model().