CMS 3D CMS Logo

/data/refman/pasoursint/CMSSW_4_1_8_patch9/src/TrackingTools/TrajectoryCleaning/src/TrajectoryCleanerMerger.cc

Go to the documentation of this file.
00001 #include "TrackingTools/TrajectoryCleaning/interface/TrajectoryCleanerMerger.h"
00002 #include "TrackingTools/TransientTrackingRecHit/interface/TransientTrackingRecHit.h"
00003 
00004 #include <map>
00005 #include <vector>
00006 
00007 using namespace std;
00008 
00009 #include "DataFormats/TrackingRecHit/interface/TrackingRecHit.h"
00010 
00011 #include "DataFormats/TrackerRecHit2D/interface/SiPixelRecHit.h"
00012 #include "DataFormats/TrackerRecHit2D/interface/SiStripRecHit2D.h"
00013 #include "DataFormats/TrackerRecHit2D/interface/SiStripMatchedRecHit2D.h"
00014 #include "DataFormats/TrackerRecHit2D/interface/ProjectedSiStripRecHit2D.h"
00015 
00016 #include "DataFormats/SiPixelDetId/interface/PXBDetId.h"
00017 #include "DataFormats/SiPixelDetId/interface/PXFDetId.h"
00018 
00019 #include "DataFormats/SiStripDetId/interface/TIBDetId.h"
00020 #include "DataFormats/SiStripDetId/interface/TOBDetId.h"
00021 #include "DataFormats/SiStripDetId/interface/TIDDetId.h"
00022 #include "DataFormats/SiStripDetId/interface/TECDetId.h"
00023 
00024 #include "DataFormats/SiPixelDetId/interface/PixelSubdetector.h"
00025 #include "DataFormats/SiStripDetId/interface/StripSubdetector.h"
00026 
00027 #include <fstream>
00028 
00029 /*****************************************************************************/
00030 class HitComparator
00031 {
00032   public:
00033     bool operator() (const TransientTrackingRecHit* ta,
00034                      const TransientTrackingRecHit* tb) const
00035     {
00036       const TrackingRecHit* a = ta->hit();
00037       const TrackingRecHit* b = tb->hit();
00038 
00039       if(getId(a) < getId(b)) return true;
00040       if(getId(b) < getId(a)) return false;
00041 
00042       if(a->geographicalId() < b->geographicalId()) return true;
00043       if(b->geographicalId() < a->geographicalId()) return false;
00044 
00045       const SiPixelRecHit* a_ = dynamic_cast<const SiPixelRecHit*>(a);
00046       if(a_ != 0)
00047       {
00048         const SiPixelRecHit* b_ = dynamic_cast<const SiPixelRecHit*>(b);
00049         return less(a_, b_);
00050       }
00051       else
00052       {
00053         const SiStripMatchedRecHit2D* a_ =
00054           dynamic_cast<const SiStripMatchedRecHit2D*>(a);
00055 
00056         if(a_ != 0)
00057         {
00058           const SiStripMatchedRecHit2D* b_ =
00059             dynamic_cast<const SiStripMatchedRecHit2D*>(b);
00060           return less(a_, b_);
00061         }
00062         else
00063         {
00064           const SiStripRecHit2D* a_ =
00065             dynamic_cast<const SiStripRecHit2D*>(a);
00066 
00067           if(a_ != 0)
00068           {
00069             const SiStripRecHit2D* b_ =
00070               dynamic_cast<const SiStripRecHit2D*>(b);
00071             return less(a_, b_);
00072           }
00073           else 
00074           {
00075             const ProjectedSiStripRecHit2D* a_ =
00076               dynamic_cast<const ProjectedSiStripRecHit2D*>(a); 
00077 
00078 //std::cerr << " comp proj" << std::endl;
00079 
00080             if(a_ != 0)
00081             {
00082               const ProjectedSiStripRecHit2D* b_ =
00083                 dynamic_cast<const ProjectedSiStripRecHit2D*>(b);
00084 
00085               return less(&(a_->originalHit()), &(b_->originalHit()));
00086             }
00087             else
00088               return false;
00089           }
00090         }
00091       }
00092     }
00093 
00094   private:
00095     int getId(const TrackingRecHit* a) const
00096     {
00097       if(dynamic_cast<const SiPixelRecHit*>(a)            != 0) return 0;
00098       if(dynamic_cast<const SiStripRecHit2D*>(a)          != 0) return 1;
00099       if(dynamic_cast<const SiStripMatchedRecHit2D*>(a)   != 0) return 2;
00100       if(dynamic_cast<const ProjectedSiStripRecHit2D*>(a) != 0) return 3;
00101       return -1;
00102     }
00103 
00104     bool less(const SiPixelRecHit* a,
00105               const SiPixelRecHit* b) const
00106     {
00107 //std::cerr << " comp pixel" << std::endl;
00108       return a->cluster() < b->cluster();
00109     }
00110 
00111     bool less(const SiStripRecHit2D* a,
00112               const SiStripRecHit2D *b) const
00113     {
00114 //std::cerr << " comp strip" << std::endl;
00115       return a->cluster() < b->cluster();
00116     }
00117 
00118     bool less(const SiStripMatchedRecHit2D* a,
00119               const SiStripMatchedRecHit2D *b) const
00120     {
00121 //std::cerr << " comp matched strip" << std::endl;
00122       if(less(a->monoHit(), b->monoHit())) return true;
00123       if(less(b->monoHit(), a->monoHit())) return false;
00124 
00125       if(less(a->stereoHit(), b->stereoHit())) return true;
00126       return false;
00127     }
00128 };
00129 
00130 /*****************************************************************************/
00131 void TrajectoryCleanerMerger::clean( TrajectoryPointerContainer&) const
00132 {
00133 }
00134 
00135 /*****************************************************************************/
00136 void TrajectoryCleanerMerger::reOrderMeasurements(Trajectory& traj)const
00137 {
00138   std::vector<TrajectoryMeasurement> meas_ = traj.measurements();
00139   std::vector<TrajectoryMeasurement> meas;
00140 
00141   for(std::vector<TrajectoryMeasurement>::iterator
00142        im = meas_.begin();
00143        im!= meas_.end(); im++)
00144     if(im->recHit()->isValid())
00145        meas.push_back(*im);
00146 
00147   bool changed;
00148 
00149   do
00150   {
00151     changed = false;
00152 
00153     for(std::vector<TrajectoryMeasurement>::iterator im = meas.begin();
00154                                                 im!= meas.end()-1; im++)
00155     if(    (*im).recHit()->globalPosition().mag2() >
00156        (*(im+1)).recHit()->globalPosition().mag2() + 1e-6)
00157     {
00158       swap(*im,*(im+1));
00159       changed = true;
00160     }
00161   }
00162   while(changed);
00163 
00164   for(unsigned int i = 0 ; i < meas.size(); i++)
00165      traj.pop();
00166 
00167   for(std::vector<TrajectoryMeasurement>::iterator im = meas.begin();
00168                                               im!= meas.end(); im++)
00169     traj.push(*im);
00170 }
00171 /*****************************************************************************/
00172 bool TrajectoryCleanerMerger::sameSeed  (const TrajectorySeed & s1,   const TrajectorySeed & s2)const
00173 {
00174   if(s1.nHits() != s2.nHits()) return false;
00175 
00176   TrajectorySeed::range r1 = s1.recHits();
00177   TrajectorySeed::range r2 = s2.recHits();
00178 
00179   TrajectorySeed::const_iterator h1 = r1.first;
00180   TrajectorySeed::const_iterator h2 = r2.first;
00181 
00182   do
00183   {
00184     if(!(h1->sharesInput(&(*h2),TrackingRecHit::all)))
00185       return false;
00186 
00187     h1++; h2++;
00188   }
00189   while(h1 != s1.recHits().second && 
00190         h2 != s2.recHits().second);
00191 
00192   return true;
00193 }
00194 
00195 /*****************************************************************************/
00196 int TrajectoryCleanerMerger::getLayer(const DetId & id)const
00197 {
00198   // PXB layer, ladder -> (layer - 1)<<2 + (ladder-1)%2
00199   // PXF disk , panel
00200   // TIB layer, module 
00201   // TOB layer, module
00202   // TID wheel, ring
00203   // TEC wheel, ring
00204 
00205   if(id.subdetId() == (unsigned int) PixelSubdetector::PixelBarrel)
00206   { PXBDetId pid(id); return (100 * id.subdetId()+ ((pid.layer() - 1)<<1) + (pid.ladder() - 1)%2); }
00207 
00208   if(id.subdetId() == (unsigned int) PixelSubdetector::PixelEndcap)
00209   { PXFDetId pid(id); return (100 * id.subdetId()+ ((pid.disk()  - 1)<<1) + (pid.panel()  - 1)%2); }
00210 
00211   if(id.subdetId() == StripSubdetector::TIB)
00212   { TIBDetId pid(id); return (100 * id.subdetId()+ ((pid.layer() - 1)<<1) + (pid.module() - 1)%2); }
00213   if(id.subdetId() == StripSubdetector::TOB)
00214   { TOBDetId pid(id); return (100 * id.subdetId()+ ((pid.layer() - 1)<<1) + (pid.module() - 1)%2); }
00215 
00216   if(id.subdetId() == StripSubdetector::TID)
00217   { TIDDetId pid(id); return (100 * id.subdetId()+ ((pid.wheel() - 1)<<1) + (pid.ring()   - 1)%2); }
00218   if(id.subdetId() == StripSubdetector::TEC)
00219   { TECDetId pid(id); return (100 * id.subdetId()+ ((pid.wheel() - 1)<<1) + (pid.ring()   - 1)%2); }
00220 
00221   return 0;
00222 }
00223 
00224 /***************************************************************************/
00225 
00226 void TrajectoryCleanerMerger::clean
00227   (TrajectoryContainer& trajs) const 
00228 {
00229   if(trajs.size() == 0) return;
00230 
00231   // Fill the rechit map
00232   typedef std::map<const TransientTrackingRecHit*,
00233               std::vector<unsigned int>, HitComparator> RecHitMap; 
00234   RecHitMap recHitMap;
00235 
00236   std::vector<bool> keep(trajs.size(),true);
00237 
00238   for(unsigned int i = 0; i < trajs.size(); i++) 
00239   {
00240     std::vector<TrajectoryMeasurement> meas = trajs[i].measurements();
00241 
00242     for(std::vector<TrajectoryMeasurement>::iterator im = meas.begin();
00243                                                 im!= meas.end(); im++)
00244       if(im->recHit()->isValid())
00245       {
00246         const TransientTrackingRecHit* recHit = &(*(im->recHit()));
00247         if(recHit->isValid())
00248           recHitMap[recHit].push_back(i);
00249       }
00250   }
00251 
00252   // Look at each track
00253   typedef std::map<unsigned int,int,less<unsigned int> > TrajMap;
00254 
00255   for(unsigned int i = 0; i < trajs.size(); i++)
00256   if(keep[i])
00257   {  
00258     TrajMap trajMap;
00259     std::vector<DetId> detIds;
00260     std::vector<int> detLayers;
00261 
00262     // Go trough all rechits of this track
00263     std::vector<TrajectoryMeasurement> meas = trajs[i].measurements();
00264     for(std::vector<TrajectoryMeasurement>::iterator im = meas.begin();
00265                                                 im!= meas.end(); im++)
00266     {
00267       if(im->recHit()->isValid())
00268       {
00269         // Get trajs sharing this rechit
00270         const TransientTrackingRecHit* recHit = &(*(im->recHit()));
00271         const std::vector<unsigned int>& sharing(recHitMap[recHit]);
00272 
00273         for(std::vector<unsigned int>::const_iterator j = sharing.begin(); 
00274                                                  j!= sharing.end(); j++)
00275           if(i < *j) trajMap[*j]++;
00276 
00277         // Fill detLayers vector
00278         detIds.push_back(recHit->geographicalId());
00279         detLayers.push_back(getLayer(recHit->geographicalId()));
00280       }
00281     }
00282 
00283     // Check for trajs with shared rechits
00284     for(TrajMap::iterator sharing = trajMap.begin();
00285                           sharing!= trajMap.end(); sharing++)
00286     {
00287       unsigned int j = (*sharing).first;
00288       if(!keep[i] || !keep[j]) continue;
00289 
00290       // More than 50% shared
00291       if((*sharing).second > min(trajs[i].foundHits(),
00292                                  trajs[j].foundHits())/2)
00293       {
00294         if( sameSeed(trajs[i].seed(), trajs[j].seed()) )
00295         {
00296         bool hasCommonLayer = false;
00297 
00298 /*
00299         std::vector<TrajectoryMeasurement> measi = trajs[i].measurements();
00300         std::vector<TrajectoryMeasurement> measj = trajs[j].measurements();
00301         for(std::vector<TrajectoryMeasurement>::iterator
00302               tmj = measj.begin(); tmj!= measj.end(); tmj++)
00303             if(find(measi.begin(), measi.end(), tmj) == measi.end())
00304             if(find(detLayers.begin(),detLayers.end(),
00305                     getLayer(tmj->recHit()->geographicalId()))
00306                                    != detLayers.end())
00307              hasCommonLayer = true;
00308 */
00309 
00310         if(hasCommonLayer == false)
00311         { // merge tracks, add separate hits of the second to the first one
00312         std::vector<TrajectoryMeasurement> measj = trajs[j].measurements();
00313         for(std::vector<TrajectoryMeasurement>::iterator
00314              tmj = measj.begin(); tmj!= measj.end(); tmj++)
00315         if(tmj->recHit()->isValid())
00316         {
00317           bool match = false;
00318 
00319           std::vector<TrajectoryMeasurement> measi = trajs[i].measurements();
00320           for(std::vector<TrajectoryMeasurement>::iterator
00321              tmi = measi.begin(); tmi!= measi.end(); tmi++)
00322           if(tmi->recHit()->isValid())
00323             if(!HitComparator()(&(*(tmi->recHit())),
00324                                 &(*(tmj->recHit()))) &&
00325                !HitComparator()(&(*(tmj->recHit())),
00326                                 &(*(tmi->recHit()))))
00327             { match = true ; break; }
00328 
00329           if(!match)
00330             trajs[i].push(*tmj);
00331         }
00332 
00333         // Remove second track
00334         keep[j] = false;
00335         }
00336         else
00337         {
00338           // remove track with higher impact / chi2
00339           if(trajs[i].chiSquared() < trajs[j].chiSquared())
00340             keep[j] = false;
00341           else
00342             keep[i] = false;
00343         }
00344         }
00345       }
00346     } 
00347   }
00348 
00349   // Final copy
00350   int ok = 0;
00351   for(unsigned int i = 0; i < trajs.size(); i++)
00352     if(keep[i])
00353     {
00354       reOrderMeasurements(trajs[i]);
00355       ok++;
00356     }
00357     else
00358       trajs[i].invalidate();
00359 
00360   std::cerr << " [TrajecCleaner] cleaned trajs : " << ok << "/" << trajs.size() <<
00361 " (with " << trajs[0].measurements().size() << "/" << recHitMap.size() << " hits)" << std::endl;
00362 }
00363