CMS 3D CMS Logo

List of all members | Public Member Functions | Public Attributes
emtf::Huber Class Reference

#include <LossFunctions.h>

Inheritance diagram for emtf::Huber:
emtf::LossFunction

Public Member Functions

double calculateQuantile (std::vector< Event *> &v, double whichQuantile)
 
double fit (std::vector< Event *> &v) override
 
 Huber ()
 
int id () override
 
std::string name () override
 
double target (Event *e) override
 
 ~Huber () override
 
- Public Member Functions inherited from emtf::LossFunction
virtual ~LossFunction ()=default
 

Public Attributes

double quantile
 
double residual_median
 

Detailed Description

Definition at line 119 of file LossFunctions.h.

Constructor & Destructor Documentation

◆ Huber()

emtf::Huber::Huber ( )
inline

Definition at line 121 of file LossFunctions.h.

// Empty constructor body: quantile and residual_median get no value here.
// fit() assigns quantile and target() reads it, so fit() should run before target();
// NOTE(review): confirm callers guarantee that ordering.
121 {}

◆ ~Huber()

emtf::Huber::~Huber ( )
inline override

Definition at line 122 of file LossFunctions.h.

// Empty destructor: the only data members documented here are two doubles, so nothing is released.
// Marked override because emtf::LossFunction declares a virtual (defaulted) destructor.
122 {}

Member Function Documentation

◆ calculateQuantile()

double emtf::Huber::calculateQuantile ( std::vector< Event *> &  v,
double  whichQuantile 
)
inline

Definition at line 156 of file LossFunctions.h.

References funct::abs(), MillePedeFileConverter_cfg::e, mps_fire::i, jetsAK4_CHS_cff::sort, and findQualityFiles::v.

Referenced by fit().

// Returns the whichQuantile-quantile of the ABSOLUTE residuals |trueValue - predictedValue| over v.
// All |residuals| are copied, fully sorted, and indexed at whichQuantile * (n - 1);
// the double-to-unsigned conversion truncates, so there is no interpolation between neighbours.
// NOTE(review): whichQuantile is assumed to lie in [0, 1]; it is not checked.
// NOTE(review): for an empty v, residuals.size() - 1 wraps around to SIZE_MAX and
// residuals[quantile_location] indexes an empty vector (undefined behaviour);
// presumably callers never pass an empty region — confirm.
156  {
157  // Container for the residuals.
158  std::vector<double> residuals(v.size());
159 
160  // Load the residuals into a vector.
161  for (unsigned int i = 0; i < v.size(); i++) {
162  Event* e = v[i];
163  residuals[i] = std::abs(e->trueValue - e->predictedValue);
164  }
165 
166  std::sort(residuals.begin(), residuals.end());
167  unsigned int quantile_location = whichQuantile * (residuals.size() - 1);
168  return residuals[quantile_location];
169  }
Abs< T >::type abs(const T &t)
Definition: Abs.h:22

◆ fit()

double emtf::Huber::fit ( std::vector< Event *> &  v)
inline override virtual

Implements emtf::LossFunction.

Definition at line 136 of file LossFunctions.h.

References funct::abs(), calculateQuantile(), change_name::diff, MillePedeFileConverter_cfg::e, mps_fire::i, SiStripPI::min, quantile, residual_median, findQualityFiles::v, and x.

Referenced by trackingPlots.Iteration::modules().

// Constant (per-region) fit for the Huber loss.
// v: events in the region; trueValue and predictedValue of each are read.
// Returns residual_median + (1/|v|) * sum over v of sign(d) * min(quantile, |d|),
//   where d = (trueValue - predictedValue) - residual_median.
// Side effects: overwrites the members quantile and residual_median.
// Source line 140 was dropped from the rendered listing (the numbering jumped from
//   139 to 141) and is restored below. residual_median is read inside the loop and has
//   no other writer in this class, so it must be assigned before the loop.
// NOTE(review): calculateQuantile() ranks |residual|, so residual_median is the median of
//   the absolute residuals, not of the signed residuals used by the textbook Huber
//   M-TreeBoost step — confirm this is intended.
// NOTE(review): an empty v makes calculateQuantile() index an empty vector and the
//   final expression divide by v.size() == 0; presumably callers never pass an empty region.
136  {
137  // The constant fit that minimizes Huber in a region.
138 
139  quantile = calculateQuantile(v, 0.7);
140  residual_median = calculateQuantile(v, 0.5);
141 
142  double x = 0;
143  for (unsigned int i = 0; i < v.size(); i++) {
144  Event* e = v[i];
145  double residual = e->trueValue - e->predictedValue;
146  double diff = residual - residual_median;
147  x += ((diff > 0) ? 1.0 : -1.0) * std::min(quantile, std::abs(diff));
148  }
149 
150  return (residual_median + x / v.size());
151  }
double quantile
Abs< T >::type abs(const T &t)
Definition: Abs.h:22
double calculateQuantile(std::vector< Event *> &v, double whichQuantile)
double residual_median

◆ id()

int emtf::Huber::id ( void  )
inline override virtual

Implements emtf::LossFunction.

Definition at line 154 of file LossFunctions.h.

// Integer identifier of this loss function; Huber reports 3.
// NOTE(review): presumably used to tell loss types apart — confirm against callers.
154 { return 3; }

◆ name()

std::string emtf::Huber::name ( void  )
inline override virtual

Implements emtf::LossFunction.

Definition at line 153 of file LossFunctions.h.

Referenced by config.CFG::__str__(), validation.Sample::digest(), and VIDSelectorBase.VIDSelectorBase::initialize().

// Human-readable name of this loss function.
153 { return "Huber"; }

◆ target()

double emtf::Huber::target ( Event *  e)
inline override virtual

Implements emtf::LossFunction.

Definition at line 127 of file LossFunctions.h.

References funct::abs(), MillePedeFileConverter_cfg::e, and quantile.

// Per-event target for the next boosting step under the Huber loss.
// With r = trueValue - predictedValue: returns r when |r| <= quantile,
// otherwise quantile carrying the sign of r (r clipped to the band [-quantile, quantile]).
// quantile is the member assigned by fit(); the constructor leaves it uninitialized,
// so fit() presumably must run before target() — verify the call order in the caller.
127  {
128  // The gradient of the loss function.
129 
130  if (std::abs(e->trueValue - e->predictedValue) <= quantile)
131  return (e->trueValue - e->predictedValue);
132  else
133  return quantile * (((e->trueValue - e->predictedValue) > 0) ? 1.0 : -1.0);
134  }
double quantile
Abs< T >::type abs(const T &t)
Definition: Abs.h:22

Member Data Documentation

◆ quantile

double emtf::Huber::quantile

Definition at line 124 of file LossFunctions.h.

Referenced by fit(), and target().

◆ residual_median

double emtf::Huber::residual_median

Definition at line 125 of file LossFunctions.h.

Referenced by fit().