CMS 3D CMS Logo

LossFunctions.h
Go to the documentation of this file.
1 // LossFunctions.h
2 // Here we define the different loss functions that can be used
3 // with the BDT system.
4 
5 #ifndef L1Trigger_L1TMuonEndCap_emtf_LossFunctions
6 #define L1Trigger_L1TMuonEndCap_emtf_LossFunctions
7 
8 #include "Event.h"
9 #include <string>
10 #include <algorithm>
11 #include <cmath>
12 
13 // ========================================================
14 // ================ Define the Interface ==================
15 //=========================================================
16 
17 namespace emtf {
18 
19  // Define the Interface
20  class LossFunction {
21  public:
22  // The gradient of the loss function.
23  // Each tree is a step in the direction of the gradient
24  // towards the minimum of the Loss Function.
25  virtual double target(Event* e) = 0;
26 
27  // The fit should minimize the loss function in each
28  // terminal node at each iteration.
29  virtual double fit(std::vector<Event*>& v) = 0;
30  virtual std::string name() = 0;
31  virtual int id() = 0;
32  virtual ~LossFunction() = default;
33  };
34 
35  // ========================================================
36  // ================ Least Squares =========================
37  // ========================================================
38 
39  class LeastSquares : public LossFunction {
40  public:
42  ~LeastSquares() override {}
43 
44  double target(Event* e) override {
45  // Each tree fits the residuals when using LeastSquares.
46  return e->trueValue - e->predictedValue;
47  }
48 
49  double fit(std::vector<Event*>& v) override {
50  // The average of the residuals minmizes the Loss Function for LS.
51 
52  double SUM = 0;
53  for (unsigned int i = 0; i < v.size(); i++) {
54  Event* e = v[i];
55  SUM += e->trueValue - e->predictedValue;
56  }
57 
58  return SUM / v.size();
59  }
60  std::string name() override { return "Least_Squares"; }
61  int id() override { return 1; }
62  };
63 
64  // ========================================================
65  // ============== Absolute Deviation ===================
66  // ========================================================
67 
69  public:
71  ~AbsoluteDeviation() override {}
72 
73  double target(Event* e) override {
74  // The gradient.
75  if ((e->trueValue - e->predictedValue) >= 0)
76  return 1;
77  else
78  return -1;
79  }
80 
81  double fit(std::vector<Event*>& v) override {
82  // The median of the residuals minimizes absolute deviation.
83  if (v.empty())
84  return 0;
85  std::vector<double> residuals(v.size());
86 
87  // Load the residuals into a vector.
88  for (unsigned int i = 0; i < v.size(); i++) {
89  Event* e = v[i];
90  residuals[i] = (e->trueValue - e->predictedValue);
91  }
92 
93  // Get the median and return it.
94  int median_loc = (residuals.size() - 1) / 2;
95 
96  // Odd.
97  if (residuals.size() % 2 != 0) {
98  std::nth_element(residuals.begin(), residuals.begin() + median_loc, residuals.end());
99  return residuals[median_loc];
100  }
101 
102  // Even.
103  else {
104  std::nth_element(residuals.begin(), residuals.begin() + median_loc, residuals.end());
105  double low = residuals[median_loc];
106  std::nth_element(residuals.begin() + median_loc + 1, residuals.begin() + median_loc + 1, residuals.end());
107  double high = residuals[median_loc + 1];
108  return (high + low) / 2;
109  }
110  }
111  std::string name() override { return "Absolute_Deviation"; }
112  int id() override { return 2; }
113  };
114 
115  // ========================================================
116  // ============== Huber ================================
117  // ========================================================
118 
119  class Huber : public LossFunction {
120  public:
121  Huber() {}
122  ~Huber() override {}
123 
124  double quantile;
126 
127  double target(Event* e) override {
128  // The gradient of the loss function.
129 
130  if (std::abs(e->trueValue - e->predictedValue) <= quantile)
131  return (e->trueValue - e->predictedValue);
132  else
133  return quantile * (((e->trueValue - e->predictedValue) > 0) ? 1.0 : -1.0);
134  }
135 
136  double fit(std::vector<Event*>& v) override {
137  // The constant fit that minimizes Huber in a region.
138 
139  quantile = calculateQuantile(v, 0.7);
141 
142  double x = 0;
143  for (unsigned int i = 0; i < v.size(); i++) {
144  Event* e = v[i];
145  double residual = e->trueValue - e->predictedValue;
146  double diff = residual - residual_median;
147  x += ((diff > 0) ? 1.0 : -1.0) * std::min(quantile, std::abs(diff));
148  }
149 
150  return (residual_median + x / v.size());
151  }
152 
153  std::string name() override { return "Huber"; }
154  int id() override { return 3; }
155 
156  double calculateQuantile(std::vector<Event*>& v, double whichQuantile) {
157  // Container for the residuals.
158  std::vector<double> residuals(v.size());
159 
160  // Load the residuals into a vector.
161  for (unsigned int i = 0; i < v.size(); i++) {
162  Event* e = v[i];
163  residuals[i] = std::abs(e->trueValue - e->predictedValue);
164  }
165 
166  std::sort(residuals.begin(), residuals.end());
167  unsigned int quantile_location = whichQuantile * (residuals.size() - 1);
168  return residuals[quantile_location];
169  }
170  };
171 
172  // ========================================================
173  // ============== Percent Error ===========================
174  // ========================================================
175 
177  public:
179  ~PercentErrorSquared() override {}
180 
181  double target(Event* e) override {
182  // The gradient of the squared percent error.
183  return (e->trueValue - e->predictedValue) / (e->trueValue * e->trueValue);
184  }
185 
186  double fit(std::vector<Event*>& v) override {
187  // The average of the weighted residuals minimizes the squared percent error.
188  // Weight(i) = 1/true(i)^2.
189 
190  double SUMtop = 0;
191  double SUMbottom = 0;
192 
193  for (unsigned int i = 0; i < v.size(); i++) {
194  Event* e = v[i];
195  SUMtop += (e->trueValue - e->predictedValue) / (e->trueValue * e->trueValue);
196  SUMbottom += 1 / (e->trueValue * e->trueValue);
197  }
198 
199  return SUMtop / SUMbottom;
200  }
201  std::string name() override { return "Percent_Error"; }
202  int id() override { return 4; }
203  };
204 
205 } // namespace emtf
206 
207 #endif
double fit(std::vector< Event *> &v) override
std::string name() override
std::string name() override
Definition: Event.h:15
double fit(std::vector< Event *> &v) override
Definition: LossFunctions.h:49
double target(Event *e) override
Definition: LossFunctions.h:73
double target(Event *e) override
std::string name() override
virtual std::string name()=0
int id() override
Definition: LossFunctions.h:61
double quantile
Abs< T >::type abs(const T &t)
Definition: Abs.h:22
double target(Event *e) override
Definition: LossFunctions.h:44
~Huber() override
virtual double target(Event *e)=0
#define SUM(A, B)
double calculateQuantile(std::vector< Event *> &v, double whichQuantile)
virtual double fit(std::vector< Event *> &v)=0
~LeastSquares() override
Definition: LossFunctions.h:42
double fit(std::vector< Event *> &v) override
Definition: LossFunctions.h:81
std::string name() override
Definition: LossFunctions.h:60
double target(Event *e) override
float x
~AbsoluteDeviation() override
Definition: LossFunctions.h:71
int id() override
double fit(std::vector< Event *> &v) override
double residual_median
virtual int id()=0
virtual ~LossFunction()=default