CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
EgammaDNNHelper.h
Go to the documentation of this file.
1 #ifndef RecoEgamma_ElectronTools_EgammaDNNHelper_h
2 #define RecoEgamma_ElectronTools_EgammaDNNHelper_h
3 
5 #include <vector>
6 #include <memory>
7 #include <string>
8 #include <functional>
9 
10 //author: Davide Valsecchi
11 //description:
12 // Handles TensorFlow DNN graphs and the configuration of the variable scalers.
13 // To be used for PFID egamma DNNs
14 
15 namespace egammaTools {
16 
20  std::vector<std::string> modelsFiles;
21  std::vector<std::string> scalersFiles;
23  };
24 
26  /* Each input variable is represented by the tuple <varname, standardization_type, par1, par2>.
27  * The standardization_type can be:
28  * 0 = do not scale the variable
29  * 1 = standard normalization: par1 = mean, par2 = std
30  * 2 = MinMax scaling: par1 = min, par2 = max */
33  float par1;
34  float par2;
35  };
36 
37  // Signature of the function applied to a candidate's variables (name -> value map)
38  // that returns the index of the model to be used for the evaluation.
39  typedef std::function<uint(const std::map<std::string, float>&)> ModelSelector;
40 
42  public:
43  EgammaDNNHelper(const DNNConfiguration&, const ModelSelector& sel, const std::vector<std::string>& availableVars);
44 
45  std::vector<tensorflow::Session*> getSessions() const;
46  // Returns the already-scaled input vector for a specific electron,
47  // together with the index of the model it has to be evaluated with.
48  // The model index is determined by the ModelSelector functor passed to the constructor,
49  // which has access to all the variables.
50  std::pair<uint, std::vector<float>> getScaledInputs(const std::map<std::string, float>& variables) const;
51 
52  std::vector<std::vector<float>> evaluate(const std::vector<std::map<std::string, float>>& candidates,
53  const std::vector<tensorflow::Session*>& sessions) const;
54 
55  private:
56  void initTensorFlowGraphs();
57  void initScalerFiles(const std::vector<std::string>& availableVars);
58 
61  // Number of models handled by the object
63  // Number of inputs for each loaded model
64  std::vector<uint> nInputs_;
65 
66  std::vector<std::unique_ptr<const tensorflow::GraphDef>> graphDefs_;
67 
68  // List of input variables for each of the models.
69  std::vector<std::vector<ScalerConfiguration>> featuresMap_;
70  };
71 
72 }; // namespace egammaTools
73 
74 #endif
std::vector< std::unique_ptr< const tensorflow::GraphDef > > graphDefs_
std::vector< std::vector< float > > evaluate(const std::vector< std::map< std::string, float >> &candidates, const std::vector< tensorflow::Session * > &sessions) const
EgammaDNNHelper(const DNNConfiguration &, const ModelSelector &sel, const std::vector< std::string > &availableVars)
std::vector< std::string > modelsFiles
std::vector< uint > nInputs_
std::pair< uint, std::vector< float > > getScaledInputs(const std::map< std::string, float > &variables) const
std::vector< std::vector< ScalerConfiguration > > featuresMap_
std::vector< tensorflow::Session * > getSessions() const
void initScalerFiles(const std::vector< std::string > &availableVars)
std::vector< std::string > scalersFiles
std::function< uint(const std::map< std::string, float > &)> ModelSelector
const ModelSelector modelSelector_
const DNNConfiguration cfg_