CMS 3D CMS Logo

LossFunctions.h
Go to the documentation of this file.
1 // LossFunctions.h
2 // Here we define the different loss functions that can be used
3 // with the BDT system.
4 
5 #ifndef L1Trigger_L1TMuonEndCap_emtf_LossFunctions
6 #define L1Trigger_L1TMuonEndCap_emtf_LossFunctions
7 
8 #include "Event.h"
9 #include <string>
10 #include <algorithm>
11 #include <cmath>
12 
13 // ========================================================
14 // ================ Define the Interface ==================
15 //=========================================================
16 
17 namespace emtf {
18 
19 // Define the Interface
21 {
22  public:
23 
24  // The gradient of the loss function.
25  // Each tree is a step in the direction of the gradient
26  // towards the minimum of the Loss Function.
27  virtual double target(Event* e) = 0;
28 
29  // The fit should minimize the loss function in each
30  // terminal node at each iteration.
31  virtual double fit(std::vector<Event*>& v) = 0;
32  virtual std::string name() = 0;
33  virtual int id() = 0;
34  virtual ~LossFunction() = default;
35 };
36 
37 // ========================================================
38 // ================ Least Squares =========================
39 // ========================================================
40 
41 class LeastSquares : public LossFunction
42 {
43  public:
45  ~LeastSquares() override{}
46 
47  double target(Event* e) override
48  {
49  // Each tree fits the residuals when using LeastSquares.
50  return e->trueValue - e->predictedValue;
51  }
52 
53  double fit(std::vector<Event*>& v) override
54  {
55  // The average of the residuals minmizes the Loss Function for LS.
56 
57  double SUM = 0;
58  for(unsigned int i=0; i<v.size(); i++)
59  {
60  Event* e = v[i];
61  SUM += e->trueValue - e->predictedValue;
62  }
63 
64  return SUM/v.size();
65  }
66  std::string name() override { return "Least_Squares"; }
67  int id() override{ return 1; }
68 
69 };
70 
71 // ========================================================
72 // ============== Absolute Deviation ===================
73 // ========================================================
74 
76 {
77  public:
79  ~AbsoluteDeviation() override{}
80 
81  double target(Event* e) override
82  {
83  // The gradient.
84  if ((e->trueValue - e->predictedValue) >= 0)
85  return 1;
86  else
87  return -1;
88  }
89 
90  double fit(std::vector<Event*>& v) override
91  {
92  // The median of the residuals minimizes absolute deviation.
93  if(v.empty()) return 0;
94  std::vector<double> residuals(v.size());
95 
96  // Load the residuals into a vector.
97  for(unsigned int i=0; i<v.size(); i++)
98  {
99  Event* e = v[i];
100  residuals[i] = (e->trueValue - e->predictedValue);
101  }
102 
103  // Get the median and return it.
104  int median_loc = (residuals.size()-1)/2;
105 
106  // Odd.
107  if(residuals.size()%2 != 0)
108  {
109  std::nth_element(residuals.begin(), residuals.begin()+median_loc, residuals.end());
110  return residuals[median_loc];
111  }
112 
113  // Even.
114  else
115  {
116  std::nth_element(residuals.begin(), residuals.begin()+median_loc, residuals.end());
117  double low = residuals[median_loc];
118  std::nth_element(residuals.begin()+median_loc+1, residuals.begin()+median_loc+1, residuals.end());
119  double high = residuals[median_loc+1];
120  return (high + low)/2;
121  }
122  }
123  std::string name() override { return "Absolute_Deviation"; }
124  int id() override{ return 2; }
125 };
126 
127 // ========================================================
128 // ============== Huber ================================
129 // ========================================================
130 
131 class Huber : public LossFunction
132 {
133  public:
134  Huber(){}
135  ~Huber() override{}
136 
137  double quantile;
139 
140  double target(Event* e) override
141  {
142  // The gradient of the loss function.
143 
144  if (std::abs(e->trueValue - e->predictedValue) <= quantile)
145  return (e->trueValue - e->predictedValue);
146  else
147  return quantile*(((e->trueValue - e->predictedValue) > 0)?1.0:-1.0);
148  }
149 
150  double fit(std::vector<Event*>& v) override
151  {
152  // The constant fit that minimizes Huber in a region.
153 
154  quantile = calculateQuantile(v, 0.7);
155  residual_median = calculateQuantile(v, 0.5);
156 
157  double x = 0;
158  for(unsigned int i=0; i<v.size(); i++)
159  {
160  Event* e = v[i];
161  double residual = e->trueValue - e->predictedValue;
162  double diff = residual - residual_median;
163  x += ((diff > 0)?1.0:-1.0)*std::min(quantile, std::abs(diff));
164  }
165 
166  return (residual_median + x/v.size());
167 
168  }
169 
170  std::string name() override { return "Huber"; }
171  int id() override{ return 3; }
172 
173  double calculateQuantile(std::vector<Event*>& v, double whichQuantile)
174  {
175  // Container for the residuals.
176  std::vector<double> residuals(v.size());
177 
178  // Load the residuals into a vector.
179  for(unsigned int i=0; i<v.size(); i++)
180  {
181  Event* e = v[i];
182  residuals[i] = std::abs(e->trueValue - e->predictedValue);
183  }
184 
185  std::sort(residuals.begin(), residuals.end());
186  unsigned int quantile_location = whichQuantile*(residuals.size()-1);
187  return residuals[quantile_location];
188  }
189 };
190 
191 // ========================================================
192 // ============== Percent Error ===========================
193 // ========================================================
194 
196 {
197  public:
199  ~PercentErrorSquared() override{}
200 
201  double target(Event* e) override
202  {
203  // The gradient of the squared percent error.
204  return (e->trueValue - e->predictedValue)/(e->trueValue * e->trueValue);
205  }
206 
207  double fit(std::vector<Event*>& v) override
208  {
209  // The average of the weighted residuals minimizes the squared percent error.
210  // Weight(i) = 1/true(i)^2.
211 
212  double SUMtop = 0;
213  double SUMbottom = 0;
214 
215  for(unsigned int i=0; i<v.size(); i++)
216  {
217  Event* e = v[i];
218  SUMtop += (e->trueValue - e->predictedValue)/(e->trueValue*e->trueValue);
219  SUMbottom += 1/(e->trueValue*e->trueValue);
220  }
221 
222  return SUMtop/SUMbottom;
223  }
224  std::string name() override { return "Percent_Error"; }
225  int id() override{ return 4; }
226 };
227 
228 } // end of emtf namespace
229 
230 #endif
double fit(std::vector< Event * > &v) override
Definition: LossFunctions.h:53
std::string name() override
std::string name() override
Definition: Event.h:15
double target(Event *e) override
Definition: LossFunctions.h:81
virtual double fit(std::vector< Event * > &v)=0
double target(Event *e) override
std::string name() override
virtual std::string name()=0
double fit(std::vector< Event * > &v) override
double predictedValue
Definition: Event.h:21
int id() override
Definition: LossFunctions.h:67
double quantile
Abs< T >::type abs(const T &t)
Definition: Abs.h:22
double target(Event *e) override
Definition: LossFunctions.h:47
T min(T a, T b)
Definition: MathUtil.h:58
~Huber() override
double trueValue
Definition: Event.h:20
double fit(std::vector< Event * > &v) override
virtual double target(Event *e)=0
#define SUM(A, B)
~LeastSquares() override
Definition: LossFunctions.h:45
std::string name() override
Definition: LossFunctions.h:66
double fit(std::vector< Event * > &v) override
Definition: LossFunctions.h:90
double target(Event *e) override
~AbsoluteDeviation() override
Definition: LossFunctions.h:79
int id() override
double calculateQuantile(std::vector< Event * > &v, double whichQuantile)
double residual_median
virtual int id()=0
virtual ~LossFunction()=default