CMS 3D CMS Logo

nnet_mult.h
Go to the documentation of this file.
1 #ifndef NNET_MULT_H_
2 #define NNET_MULT_H_
3 
4 #include "nnet_common.h"
5 #include <iostream>
6 #include <math.h>
7 
8 namespace nnet {
9 
// Compile-time ceil(log2(x)), clamped below at 1.
//
// Returns 1 for every x <= 2 (including zero and negative values); for larger
// x it returns the smallest k such that 2^k >= x.
//
// The halving step is written as x / 2 + x % 2 instead of (x + 1) / 2: both
// compute ceil(x / 2) for positive x, but the latter overflows int when
// x == INT_MAX. Kept as a single return statement so it remains a valid
// C++11 constexpr function.
constexpr int ceillog2(int x) { return (x <= 2) ? 1 : 1 + ceillog2(x / 2 + x % 2); }
11 
12  namespace product {
13 
14  /* ---
15  * different methods to perform the product of input and weight, depending on the
16  * types of each.
17  * --- */
18 
// Empty common base class for the product strategy classes in this namespace.
// It declares no members; each strategy supplies its own static product().
class Product {
};
20 
21  template <class x_T, class w_T>
22  class both_binary : public Product {
23  public:
24  static x_T product(x_T a, w_T w) { return a == w; }
25  };
26 
27  template <class x_T, class w_T>
28  class weight_binary : public Product {
29  public:
30  static auto product(x_T a, w_T w) -> decltype(-a) {
31  if (w == 0)
32  return -a;
33  else
34  return a;
35  }
36  };
37 
38  template <class x_T, class w_T>
39  class data_binary : public Product {
40  public:
41  static auto product(x_T a, w_T w) -> decltype(-w) {
42  if (a == 0)
43  return -w;
44  else
45  return w;
46  }
47  };
48 
49  template <class x_T, class w_T>
50  class weight_ternary : public Product {
51  public:
52  static auto product(x_T a, w_T w) -> decltype(-a) {
53  if (w == 0)
54  return 0;
55  else if (w == -1)
56  return -a;
57  else
58  return a; // if(w == 1)
59  }
60  };
61 
62  template <class x_T, class w_T>
63  class mult : public Product {
64  public:
65  static auto product(x_T a, w_T w) -> decltype(a * w) { return a * w; }
66  };
67 
68  template <class x_T, class w_T>
69  class weight_exponential : public Product {
70  public:
71  using r_T =
72  ap_fixed<2 * (decltype(w_T::weight)::width + x_T::width), (decltype(w_T::weight)::width + x_T::width)>;
73  static r_T product(x_T a, w_T w) {
74  // Shift by the exponent. Negative weights shift right
75  r_T y = static_cast<r_T>(a) << w.weight;
76 
77  // Negate or not depending on weight sign
78  return w.sign == 1 ? y : static_cast<r_T>(-y);
79  }
80  };
81 
82  } // namespace product
83 
84  template <class data_T, class res_T, typename CONFIG_T>
85  inline typename std::enable_if<std::is_same<data_T, ap_uint<1>>::value &&
86  std::is_same<typename CONFIG_T::weight_t, ap_uint<1>>::value,
87  ap_int<nnet::ceillog2(CONFIG_T::n_in) + 2>>::type
88  cast(typename CONFIG_T::accum_t x) {
89  return (ap_int<nnet::ceillog2(CONFIG_T::n_in) + 2>)(x - CONFIG_T::n_in / 2) * 2;
90  }
91 
92  template <class data_T, class res_T, typename CONFIG_T>
93  inline typename std::enable_if<std::is_same<data_T, ap_uint<1>>::value &&
94  !std::is_same<typename CONFIG_T::weight_t, ap_uint<1>>::value,
95  res_T>::type
96  cast(typename CONFIG_T::accum_t x) {
97  return (res_T)x;
98  }
99 
100  template <class data_T, class res_T, typename CONFIG_T>
101  inline typename std::enable_if<(!std::is_same<data_T, ap_uint<1>>::value), res_T>::type cast(
102  typename CONFIG_T::accum_t x) {
103  return (res_T)x;
104  }
105 
106 } // namespace nnet
107 
108 #endif
constexpr int ceillog2(int x)
Definition: nnet_mult.h:10
static r_T product(x_T a, w_T w)
Definition: nnet_mult.h:73
T w() const
std::enable_if< std::is_same< data_T, ap_uint< 1 > >::value &&std::is_same< typename CONFIG_T::weight_t, ap_uint< 1 > >::value, ap_int< nnet::ceillog2(CONFIG_T::n_in)+2 > >::type cast(typename CONFIG_T::accum_t x)
Definition: nnet_mult.h:88
static auto product(x_T a, w_T w) -> decltype(a *w)
Definition: nnet_mult.h:65
ap_fixed< 2 *(decltype(w_T::weight)::width+x_T::width),(decltype(w_T::weight)::width+x_T::width)> r_T
Definition: nnet_mult.h:72
static x_T product(x_T a, w_T w)
Definition: nnet_mult.h:24
static auto product(x_T a, w_T w) -> decltype(-a)
Definition: nnet_mult.h:30
double a
Definition: hdecay.h:121
static auto product(x_T a, w_T w) -> decltype(-w)
Definition: nnet_mult.h:41
float x
static auto product(x_T a, w_T w) -> decltype(-a)
Definition: nnet_mult.h:52