CMS 3D CMS Logo

List of all members | Public Member Functions | Private Attributes
emtf::Forest Class Reference

#include <Forest.h>

Public Member Functions

void appendCorrection (std::vector< Event * > &eventsp, int treenum)
 
void appendCorrection (Event *e, int treenum)
 
void doRegression (int nodeLimit, int treeLimit, double learningRate, LossFunction *l, const char *savetreesdirectory, bool saveTrees)
 
void doStochasticRegression (int nodeLimit, int treeLimit, double learningRate, double fraction, LossFunction *l)
 
 Forest ()
 
 Forest (std::vector< Event * > &trainingEvents)
 
 Forest (const Forest &forest)
 
 Forest (Forest &&forest)=default
 
void generate (int numTrainEvents, int numTestEvents, double sigma)
 
std::vector< Event * > getTrainingEvents ()
 
Tree * getTree (unsigned int i)
 
void listEvents (std::vector< std::vector< Event * > > &e)
 Lists the events in each event vector. More...
 
void loadForestFromXML (const char *directory, unsigned int numTrees)
 
void loadFromCondPayload (const L1TMuonEndCapForest::DForest &payload)
 
Forest & operator= (const Forest &forest)
 
void predictEvent (Event *e, unsigned int trees)
 
void predictEvents (std::vector< Event * > &eventsp, unsigned int trees)
 
void prepareRandomSubsample (double fraction)
 
void rankVariables (std::vector< int > &rank)
 
void saveSplitValues (const char *savefilename)
 
void setTrainingEvents (std::vector< Event * > &trainingEvents)
 
unsigned int size ()
 
void sortEventVectors (std::vector< std::vector< Event * > > &e)
 
void updateEvents (Tree *tree)
 
void updateRegTargets (Tree *tree, double learningRate, LossFunction *l)
 
 ~Forest ()
 

Private Attributes

std::vector< std::vector< Event * > > events
 
std::vector< std::vector< Event * > > subSample
 
std::vector< Tree * > trees
 

Detailed Description

Definition at line 12 of file Forest.h.

Constructor & Destructor Documentation

Forest::Forest ( )

Definition at line 40 of file Forest.cc.

References events.

41 {
42  events = std::vector< std::vector<Event*> >(1);
43 }
std::vector< std::vector< Event * > > events
Definition: Forest.h:66
Forest::Forest ( std::vector< Event * > &  trainingEvents)

Definition at line 49 of file Forest.cc.

References setTrainingEvents().

50 {
51  setTrainingEvents(trainingEvents);
52 }
void setTrainingEvents(std::vector< Event * > &trainingEvents)
Definition: Forest.cc:100
Forest::~Forest ( )

Definition at line 58 of file Forest.cc.

References mps_fire::i, and trees.

59 {
60 // When the forest is destroyed it deletes the trees it holds.
61 // The events from the training and testing sets are not deleted here,
62 // so they remain allocated after the forest is destroyed.
63 // NOTE(review): confirm who is expected to release the events.
64 
65  for(unsigned int i=0; i < trees.size(); i++)
66  {
67  if(trees[i]) delete trees[i];
68  }
69 }
std::vector< Tree * > trees
Definition: Forest.h:68
Forest::Forest ( const Forest &  forest)

Definition at line 71 of file Forest.cc.

References create_public_lumi_plots::transform, compare::tree, and trees.

72 {
73  transform(forest.trees.cbegin(),
74  forest.trees.cend(),
75  back_inserter(trees),
76  [] (const Tree *tree) { return new Tree(*tree); }
77  );
78 }
std::vector< Tree * > trees
Definition: Forest.h:68
Definition: tree.py:1
emtf::Forest::Forest ( Forest &&  forest)
default

Member Function Documentation

void Forest::appendCorrection ( std::vector< Event * > &  eventsp,
int  treenum 
)

Definition at line 464 of file Forest.cc.

References emtf::Tree::filterEvents(), trees, and updateEvents().

Referenced by predictEvent(), and predictEvents().

465 {
466 // Update the prediction by appending the next correction.
467 
468  Tree* tree = trees[treenum];
469  tree->filterEvents(eventsp);
470 
471  // Update the events with their new prediction.
472  updateEvents(tree);
473 }
void updateEvents(Tree *tree)
Definition: Forest.cc:364
std::vector< Tree * > trees
Definition: Forest.h:68
Definition: tree.py:1
void filterEvents(std::vector< Event * > &tEvents)
Definition: Tree.cc:264
void Forest::appendCorrection ( Event *  e,
int  treenum 
)

Definition at line 505 of file Forest.cc.

References emtf::Tree::filterEvent(), trackingPlots::fit, emtf::Node::getFitValue(), emtf::Event::predictedValue, and trees.

506 {
507 // Update the prediction by appending the next correction.
508 
509  Tree* tree = trees[treenum];
510  Node* terminalNode = tree->filterEvent(e);
511 
512  // Update the event with its new prediction.
513  double fit = terminalNode->getFitValue();
514  e->predictedValue += fit;
515 }
double getFitValue()
Definition: Node.cc:158
double predictedValue
Definition: Event.h:21
std::vector< Tree * > trees
Definition: Forest.h:68
Node * filterEvent(Event *e)
Definition: Tree.cc:298
Definition: tree.py:1
void Forest::doRegression ( int  nodeLimit,
int  treeLimit,
double  learningRate,
LossFunction *  l,
const char *  savetreesdirectory,
bool  saveTrees 
)

Definition at line 394 of file Forest.cc.

References emtf::Tree::buildTree(), EnergyCorrector::c, events, mps_fire::i, alignCSCRings::s, emtf::Tree::saveToXML(), sortEventVectors(), AlCaHLTBitMon_QueryRunRegistry::string, trees, and updateRegTargets().

395 {
396 // Build the forest using the training sample.
397 
398  //std::cout << std::endl << "--Building Forest..." << std::endl << std::endl;
399 
400  // The trees work with a matrix of events where the rows have the same set of events. Each row however
401  // is sorted according to the feature variable given by event->data[row].
402  // If we only had one set of events we would have to sort it according to the
403  // feature variable every time we want to calculate the best split point for that feature.
404  // By keeping sorted copies we avoid the sorting operation during split point calculation
405  // and save computation time. If we do not sort each of the rows the regression will fail.
406  //std::cout << "Sorting event vectors..." << std::endl;
407  sortEventVectors(events);
408 
409  // See how long the regression takes.
410  TStopwatch timer;
411  timer.Start(kTRUE);
412 
413  for(unsigned int i=0; i< (unsigned) treeLimit; i++)
414  {
415  // std::cout << "++Building Tree " << i << "... " << std::endl;
416  Tree* tree = new Tree(events);
417  trees.push_back(tree);
418  tree->buildTree(nodeLimit);
419 
420  // Update the targets for the next tree to fit.
421  updateRegTargets(tree, learningRate, l);
422 
423  // Save trees to xml in some directory.
424  std::ostringstream ss;
425  ss << savetreesdirectory << "/" << i << ".xml";
426  std::string s = ss.str();
427  const char* c = s.c_str();
428 
429  if(saveTrees) tree->saveToXML(c);
430  }
431  //std::cout << std::endl;
432  //std::cout << std::endl << "Done." << std::endl << std::endl;
433 
434 // std::cout << std::endl << "Total calculation time: " << timer.RealTime() << std::endl;
435 }
void buildTree(int nodeLimit)
Definition: Tree.cc:203
void sortEventVectors(std::vector< std::vector< Event * > > &e)
Definition: Forest.cc:203
void updateRegTargets(Tree *tree, double learningRate, LossFunction *l)
Definition: Forest.cc:322
std::vector< Tree * > trees
Definition: Forest.h:68
void saveToXML(const char *filename)
Definition: Tree.cc:428
Definition: tree.py:1
std::vector< std::vector< Event * > > events
Definition: Forest.h:66
void Forest::doStochasticRegression ( int  nodeLimit,
int  treeLimit,
double  learningRate,
double  fraction,
LossFunction *  l 
)

Definition at line 597 of file Forest.cc.

References EnergyCorrector::c, events, mps_fire::i, prepareRandomSubsample(), alignCSCRings::s, sortEventVectors(), AlCaHLTBitMon_QueryRunRegistry::string, subSample, trees, and updateRegTargets().

598 {
599 // If the fraction of events to use is one then this algorithm is slower than doRegression due to the fact
600 // that we have to sort the events every time we extract a subsample. Without random sampling we simply
601 // use all of the events and keep them sorted.
602 
603 // Anyways, this algorithm uses a portion of the events to train each tree. All of the events are updated
604 // afterwards with the results from the subsample built tree.
605 
606  // Prepare some things.
607  sortEventVectors(events);
608  trees = std::vector<Tree*>(treeLimit);
609 
610  // See how long the regression takes.
611  TStopwatch timer;
612  timer.Start(kTRUE);
613 
614  // Output the current settings.
615  // std::cout << std::endl << "Running stochastic regression ... " << std::endl;
616  //std::cout << "# Nodes: " << nodeLimit << std::endl;
617  //std::cout << "Learning Rate: " << learningRate << std::endl;
618  //std::cout << "Bagging Fraction: " << fraction << std::endl;
619  //std::cout << std::endl;
620 
621 
622  for(unsigned int i=0; i< (unsigned) treeLimit; i++)
623  {
624  // Build the tree using a random subsample.
625  prepareRandomSubsample(fraction);
626  trees[i] = new Tree(subSample);
627  trees[i]->buildTree(nodeLimit);
628 
629  // Fit all of the events based upon the tree we built using
630  // the subsample of events.
631  trees[i]->filterEvents(events[0]);
632 
633  // Update the targets for the next tree to fit.
634  updateRegTargets(trees[i], learningRate, l);
635 
636  // Save trees to xml in some directory.
637  std::ostringstream ss;
638  ss << "trees/" << i << ".xml";
639  std::string s = ss.str();
640  const char* c = s.c_str();
641 
642  trees[i]->saveToXML(c);
643  }
644 
645  //std::cout << std::endl << "Done." << std::endl << std::endl;
646 
647  //std::cout << std::endl << "Total calculation time: " << timer.RealTime() << std::endl;
648 }
void prepareRandomSubsample(double fraction)
Definition: Forest.cc:568
std::vector< std::vector< Event * > > subSample
Definition: Forest.h:67
void sortEventVectors(std::vector< std::vector< Event * > > &e)
Definition: Forest.cc:203
void updateRegTargets(Tree *tree, double learningRate, LossFunction *l)
Definition: Forest.cc:322
std::vector< Tree * > trees
Definition: Forest.h:68
std::vector< std::vector< Event * > > events
Definition: Forest.h:66
void emtf::Forest::generate ( int  numTrainEvents,
int  numTestEvents,
double  sigma 
)
std::vector< Event * > Forest::getTrainingEvents ( )

Definition at line 122 of file Forest.cc.

References events.

122 { return events[0]; }
std::vector< std::vector< Event * > > events
Definition: Forest.h:66
Tree * Forest::getTree ( unsigned int  i)

Definition at line 129 of file Forest.cc.

References trees.

Referenced by L1TMuonEndCapForestESProducer::produce().

130 {
131  if(/*i>=0 && */i<trees.size()) return trees[i];
132  else
133  {
134  //std::cout << i << "is an invalid input for getTree. Out of range." << std::endl;
135  return nullptr;
136  }
137 }
std::vector< Tree * > trees
Definition: Forest.h:68
void Forest::listEvents ( std::vector< std::vector< Event * > > &  e)

Lists the events in each event vector. Multiple copies of the events vector are kept, each sorted according to a different determining variable.

Definition at line 159 of file Forest.cc.

References gather_cfg::cout, MillePedeFileConverter_cfg::e, mps_fire::i, and emtf::Event::sortingIndex.

160 {
161 // Simply list the events in each event vector. We have multiple copies
162 // of the events vector. Each copy is sorted according to a different
163 // determining variable.
164  std::cout << std::endl << "Listing Events... " << std::endl;
165 
166  for(unsigned int i=0; i < e.size(); i++)
167  {
168  std::cout << std::endl << "Variable " << i << " vector contents: " << std::endl;
169  for(unsigned int j=0; j<e[i].size(); j++)
170  {
171  e[i][j]->outputEvent();
172  }
173  std::cout << std::endl;
174  }
175 }
void Forest::loadForestFromXML ( const char *  directory,
unsigned int  numTrees 
)

Definition at line 520 of file Forest.cc.

References mps_fire::i, and trees.

521 {
522 // Load a forest that has already been created and stored into XML somewhere.
523 
524  // Initialize the vector of trees.
525  trees = std::vector<Tree*>(numTrees);
526 
527  // Load the Forest.
528  // std::cout << std::endl << "Loading Forest from XML ... " << std::endl;
529  for(unsigned int i=0; i < numTrees; i++)
530  {
531  trees[i] = new Tree();
532 
533  std::stringstream ss;
534  ss << directory << "/" << i << ".xml";
535 
536  trees[i]->loadFromXML(edm::FileInPath(ss.str().c_str()).fullPath().c_str());
537  }
538 
539  //std::cout << "Done." << std::endl << std::endl;
540 }
std::vector< Tree * > trees
Definition: Forest.h:68
void Forest::loadFromCondPayload ( const L1TMuonEndCapForest::DForest &  payload)

Definition at line 542 of file Forest.cc.

References mps_fire::i, and trees.

543 {
544 // Load a forest that has already been created and stored in CondDB.
545  // Initialize the vector of trees.
546  unsigned int numTrees = forest.size();
547 
548  // clean-up leftovers from previous initialization (if any)
549  for(unsigned int i=0; i < trees.size(); i++)
550  {
551  if(trees[i]) delete trees[i];
552  }
553 
554  trees = std::vector<Tree*>(numTrees);
555 
556  // Load the Forest.
557  for(unsigned int i=0; i < numTrees; i++)
558  {
559  trees[i] = new Tree();
560  trees[i]->loadFromCondPayload(forest[i]);
561  }
562 }
std::vector< Tree * > trees
Definition: Forest.h:68
Forest & Forest::operator= ( const Forest &  forest)

Definition at line 80 of file Forest.cc.

References mps_fire::i, create_public_lumi_plots::transform, compare::tree, and trees.

81 {
82  for(unsigned int i=0; i < trees.size(); i++)
83  {
84  if(trees[i]) delete trees[i];
85  }
86  trees.resize(0);
87 
88  transform(forest.trees.cbegin(),
89  forest.trees.cend(),
90  back_inserter(trees),
91  [] (const Tree *tree) { return new Tree(*tree); }
92  );
93  return *this;
94 }
std::vector< Tree * > trees
Definition: Forest.h:68
Definition: tree.py:1
void Forest::predictEvent ( Event *  e,
unsigned int  trees 
)

Definition at line 479 of file Forest.cc.

References appendCorrection(), mps_fire::i, emtf::Event::predictedValue, and trees.

Referenced by PtAssignmentEngine2016::calculate_pt_xml(), and PtAssignmentEngine2017::calculate_pt_xml().

480 {
481 // Predict the value for e by running it through the forest up to numtrees.
482 
483  //std::cout << "Using " << numtrees << " trees from the forest to predict events ... " << std::endl;
484  if(numtrees > trees.size())
485  {
486  //std::cout << std::endl << "!! Input greater than the forest size. Using forest.size() = " << trees.size() << " to predict instead." << std::endl;
487  numtrees = trees.size();
488  }
489 
490  // just like in line #2470 of https://root.cern.ch/doc/master/MethodBDT_8cxx_source.html for gradient boosting
491  e->predictedValue = trees[0]->getBoostWeight();
492 
493  // i iterates through the trees in the forest. Each tree corrects the last prediction.
494  for(unsigned int i=0; i < numtrees; i++)
495  {
496  //std::cout << "++Tree " << i << "..." << std::endl;
497  appendCorrection(e, i);
498  }
499 }
void appendCorrection(std::vector< Event * > &eventsp, int treenum)
Definition: Forest.cc:464
double predictedValue
Definition: Event.h:21
std::vector< Tree * > trees
Definition: Forest.h:68
void Forest::predictEvents ( std::vector< Event * > &  eventsp,
unsigned int  trees 
)

Definition at line 441 of file Forest.cc.

References appendCorrection(), mps_fire::i, and trees.

442 {
443 // Predict values for eventsp by running them through the forest up to numtrees.
444 
445  //std::cout << "Using " << numtrees << " trees from the forest to predict events ... " << std::endl;
446  if(numtrees > trees.size())
447  {
448  //std::cout << std::endl << "!! Input greater than the forest size. Using forest.size() = " << trees.size() << " to predict instead." << std::endl;
449  numtrees = trees.size();
450  }
451 
452  // i iterates through the trees in the forest. Each tree corrects the last prediction.
453  for(unsigned int i=0; i < numtrees; i++)
454  {
455  //std::cout << "++Tree " << i << "..." << std::endl;
456  appendCorrection(eventsp, i);
457  }
458 }
void appendCorrection(std::vector< Event * > &eventsp, int treenum)
Definition: Forest.cc:464
std::vector< Tree * > trees
Definition: Forest.h:68
void Forest::prepareRandomSubsample ( double  fraction)

Definition at line 568 of file Forest.cc.

References begin, end, events, mps_fire::i, emtf::shuffle(), sortEventVectors(), subSample, and findQualityFiles::v.

Referenced by doStochasticRegression().

569 {
570 // We use this for Stochastic Gradient Boosting. Basically you
571 // take a subsample of the training events and build a tree using
572 // those. Then use the tree built from the subsample to update
573 // the predictions for all the events.
574 
575  subSample = std::vector< std::vector<Event*> >(events.size()) ;
576  size_t subSampleSize = fraction*events[0].size();
577 
578  // Randomize the first subSampleSize events in events[0].
579  shuffle(events[0].begin(), events[0].end(), subSampleSize);
580 
581  // Get a copy of the random subset we just made.
582  std::vector<Event*> v(events[0].begin(), events[0].begin()+subSampleSize);
583 
584  // Initialize and sort the subSample collection.
585  for(unsigned int i=0; i<subSample.size(); i++)
586  {
587  subSample[i] = v;
588  }
589 
590  sortEventVectors(subSample);
591 }
std::vector< std::vector< Event * > > subSample
Definition: Forest.h:67
void sortEventVectors(std::vector< std::vector< Event * > > &e)
Definition: Forest.cc:203
#define end
Definition: vmac.h:37
#define begin
Definition: vmac.h:30
std::vector< std::vector< Event * > > events
Definition: Forest.h:66
bidiiter shuffle(bidiiter begin, bidiiter end, size_t num_random)
Definition: Utilities.h:27
void Forest::rankVariables ( std::vector< int > &  rank)

Definition at line 219 of file Forest.cc.

References events, mps_fire::i, SiStripPI::max, edm::second(), trees, findQualityFiles::v, and w.

220 {
221 // This function ranks the determining variables according to their importance
222 // in determining the fit. Use a low learning rate for better results.
223 // Separates completely useless variables from useful ones well,
224 // but isn't the best at separating variables of similar importance.
225 // This is calculated using the error reduction on the training set. The function
226 // should be changed to use the testing set, but this works fine for now.
227 // I will try to change this in the future.
228 
229  // Initialize the vector v, which will store the total error reduction
230  // for each variable i in v[i].
231  std::vector<double> v(events.size(), 0);
232 
233  //std::cout << std::endl << "Ranking Variables by Net Error Reduction... " << std::endl;
234 
235  for(unsigned int j=0; j < trees.size(); j++)
236  {
237  trees[j]->rankVariables(v);
238  }
239 
240  double max = *std::max_element(v.begin(), v.end());
241 
242  // Scale the importance. Maximum importance = 100.
243  for(unsigned int i=0; i < v.size(); i++)
244  {
245  v[i] = 100*v[i]/max;
246  }
247 
248  // Change the storage format so that we can keep the index
249  // and the value associated after sorting.
250  std::vector< std::pair<double, int> > w(events.size());
251 
252  for(unsigned int i=0; i<v.size(); i++)
253  {
254  w[i] = std::pair<double, int>(v[i],i);
255  }
256 
257  // Sort so that we can output in order of importance.
258  std::sort(w.begin(),w.end());
259 
260  // Output the results.
261  for(int i=(v.size()-1); i>=0; i--)
262  {
263  rank.push_back(w[i].second);
264  // std::cout << "x" << w[i].second << ": " << w[i].first << std::endl;
265  }
266 
267  // std::cout << std::endl << "Done." << std::endl << std::endl;
268 }
const double w
Definition: UKUtility.cc:23
U second(std::pair< T, U > const &p)
std::vector< Tree * > trees
Definition: Forest.h:68
std::vector< std::vector< Event * > > events
Definition: Forest.h:66
void Forest::saveSplitValues ( const char *  savefilename)

Definition at line 274 of file Forest.cc.

References begin, end, events, mps_fire::i, trees, tier0::unique(), and findQualityFiles::v.

275 {
276 // This function gathers all of the split values from the forest and puts them into lists.
277 
278  std::ofstream splitvaluefile;
279  splitvaluefile.open(savefilename);
280 
281  // Initialize the matrix v, which will store the list of split values
282  // for each variable i in v[i].
283  std::vector<std::vector<double>> v(events.size(), std::vector<double>());
284 
285  //std::cout << std::endl << "Gathering split values... " << std::endl;
286 
287  // Gather the split values from each tree in the forest.
288  for(unsigned int j=0; j<trees.size(); j++)
289  {
290  trees[j]->getSplitValues(v);
291  }
292 
293  // Sort the lists of split values and remove the duplicates.
294  for(unsigned int i=0; i<v.size(); i++)
295  {
296  std::sort(v[i].begin(),v[i].end());
297  v[i].erase( unique( v[i].begin(), v[i].end() ), v[i].end() );
298  }
299 
300  // Output the results after removing duplicates.
301  // The 0th variable is special and is not used for splitting, so we start at 1.
302  for(unsigned int i=1; i<v.size(); i++)
303  {
304  TString splitValues;
305  for(unsigned int j=0; j<v[i].size(); j++)
306  {
307  std::stringstream ss;
308  ss.precision(14);
309  ss << std::scientific << v[i][j];
310  splitValues+=",";
311  splitValues+=ss.str().c_str();
312  }
313 
314  splitValues=splitValues(1,splitValues.Length());
315  splitvaluefile << splitValues << std::endl << std::endl;;
316  }
317 }
def unique(seq, keepstr=True)
Definition: tier0.py:24
#define end
Definition: vmac.h:37
std::vector< Tree * > trees
Definition: Forest.h:68
#define begin
Definition: vmac.h:30
std::vector< std::vector< Event * > > events
Definition: Forest.h:66
void Forest::setTrainingEvents ( std::vector< Event * > &  trainingEvents)

Definition at line 100 of file Forest.cc.

References emtf::Event::data, MillePedeFileConverter_cfg::e, events, and mps_fire::i.

Referenced by Forest().

101 {
102 // tell the forest which events to use for training
103 
104  Event* e = trainingEvents[0];
105  // Unused variable
106  // unsigned int numrows = e->data.size();
107 
108  // Reset the events matrix.
109  events = std::vector< std::vector<Event*> >();
110 
111  for(unsigned int i=0; i<e->data.size(); i++)
112  {
113  events.push_back(trainingEvents);
114  }
115 }
std::vector< std::vector< Event * > > events
Definition: Forest.h:66
std::vector< double > data
Definition: Event.h:31
unsigned int Forest::size ( void  )

Definition at line 143 of file Forest.cc.

References trees.

Referenced by ntupleDataFormat._Collection::__iter__(), ntupleDataFormat._Collection::__len__(), and L1TMuonEndCapForestESProducer::produce().

144 {
145 // Return the number of trees in the forest.
146  return trees.size();
147 }
std::vector< Tree * > trees
Definition: Forest.h:68
void Forest::sortEventVectors ( std::vector< std::vector< Event * > > &  e)

Definition at line 203 of file Forest.cc.

References begin, compareEvents(), MillePedeFileConverter_cfg::e, end, mps_fire::i, and emtf::Event::sortingIndex.

Referenced by doRegression(), doStochasticRegression(), and prepareRandomSubsample().

204 {
205 // When a node chooses the optimum split point and split variable it needs
206 // the events to be sorted according to the variable it is considering.
207 
208  for(unsigned int i=0; i<e.size(); i++)
209  {
210  Event::sortingIndex = i;
211  std::sort(e[i].begin(), e[i].end(), compareEvents);
212  }
213 }
#define end
Definition: vmac.h:37
bool compareEvents(Event *e1, Event *e2)
Definition: Forest.cc:185
static int sortingIndex
Definition: Event.h:29
#define begin
Definition: vmac.h:30
void Forest::updateEvents ( Tree *  tree)

Definition at line 364 of file Forest.cc.

References MillePedeFileConverter_cfg::e, trackingPlots::fit, emtf::Tree::getTerminalNodes(), emtf::Event::predictedValue, and findQualityFiles::v.

Referenced by appendCorrection().

365 {
366 // Prepare the test events for the next tree.
367 
368  // Get the list of terminal nodes for this tree.
369  std::list<Node*>& tn = tree->getTerminalNodes();
370 
371  // Loop through the terminal nodes.
372  for(std::list<Node*>::iterator it=tn.begin(); it!=tn.end(); it++)
373  {
374  std::vector<Event*>& v = (*it)->getEvents()[0];
375  double fit = (*it)->getFitValue();
376 
377  // Loop through each event in the terminal region and update the
378  // the global event it maps to.
379  for(unsigned int j=0; j<v.size(); j++)
380  {
381  Event* e = v[j];
382  e->predictedValue += fit;
383  }
384 
385  // Release memory.
386  (*it)->getEvents() = std::vector< std::vector<Event*> >();
387  }
388 }
double predictedValue
Definition: Event.h:21
std::list< Node * > & getTerminalNodes()
Definition: Tree.cc:171
void Forest::updateRegTargets ( Tree *  tree,
double  learningRate,
LossFunction *  l 
)

Definition at line 322 of file Forest.cc.

References emtf::Event::data, MillePedeFileConverter_cfg::e, emtf::LossFunction::fit(), trackingPlots::fit, emtf::Tree::getTerminalNodes(), emtf::Event::predictedValue, emtf::LossFunction::target(), and findQualityFiles::v.

Referenced by doRegression(), and doStochasticRegression().

323 {
324 // Prepare the global vector of events for the next tree.
325 // Update the fit for each event and set the new target value
326 // for the next tree.
327 
328  // Get the list of terminal nodes for this tree.
329  std::list<Node*>& tn = tree->getTerminalNodes();
330 
331  // Loop through the terminal nodes.
332  for(std::list<Node*>::iterator it=tn.begin(); it!=tn.end(); it++)
333  {
334  // Get the events in the current terminal region.
335  std::vector<Event*>& v = (*it)->getEvents()[0];
336 
337  // Fit the events depending on the loss function criteria.
338  double fit = l->fit(v);
339 
340  // Scale the rate at which the algorithm converges.
341  fit = learningRate*fit;
342 
343  // Store the official fit value in the terminal node.
344  (*it)->setFitValue(fit);
345 
346  // Loop through each event in the terminal region and update the
347  // the target for the next tree.
348  for(unsigned int j=0; j<v.size(); j++)
349  {
350  Event* e = v[j];
351  e->predictedValue += fit;
352  e->data[0] = l->target(e);
353  }
354 
355  // Release memory.
356  (*it)->getEvents() = std::vector< std::vector<Event*> >();
357  }
358 }
virtual double fit(std::vector< Event * > &v)=0
double predictedValue
Definition: Event.h:21
std::list< Node * > & getTerminalNodes()
Definition: Tree.cc:171
virtual double target(Event *e)=0
std::vector< double > data
Definition: Event.h:31

Member Data Documentation

std::vector< std::vector<Event*> > emtf::Forest::events
private
std::vector< std::vector<Event*> > emtf::Forest::subSample
private

Definition at line 67 of file Forest.h.

Referenced by doStochasticRegression(), and prepareRandomSubsample().

std::vector<Tree*> emtf::Forest::trees
private