CMS 3D CMS Logo

/data/refman/pasoursint/CMSSW_5_2_9/src/TrackingTools/TrajectoryCleaning/src/TrajectoryCleanerMerger.cc

Go to the documentation of this file.
00001 #include "TrackingTools/TrajectoryCleaning/interface/TrajectoryCleanerMerger.h"
00002 #include "TrackingTools/TransientTrackingRecHit/interface/TransientTrackingRecHit.h"
00003 
00004 #include <map>
00005 #include <vector>
00006 
00007 using namespace std;
00008 
00009 #include "DataFormats/TrackingRecHit/interface/TrackingRecHit.h"
00010 
00011 #include "DataFormats/TrackerRecHit2D/interface/SiPixelRecHit.h"
00012 #include "DataFormats/TrackerRecHit2D/interface/SiStripRecHit2D.h"
00013 #include "DataFormats/TrackerRecHit2D/interface/SiStripMatchedRecHit2D.h"
00014 #include "DataFormats/TrackerRecHit2D/interface/ProjectedSiStripRecHit2D.h"
00015 
00016 #include "DataFormats/SiPixelDetId/interface/PXBDetId.h"
00017 #include "DataFormats/SiPixelDetId/interface/PXFDetId.h"
00018 
00019 #include "DataFormats/SiStripDetId/interface/TIBDetId.h"
00020 #include "DataFormats/SiStripDetId/interface/TOBDetId.h"
00021 #include "DataFormats/SiStripDetId/interface/TIDDetId.h"
00022 #include "DataFormats/SiStripDetId/interface/TECDetId.h"
00023 
00024 #include "DataFormats/SiPixelDetId/interface/PixelSubdetector.h"
00025 #include "DataFormats/SiStripDetId/interface/StripSubdetector.h"
00026 
00027 #include <fstream>
00028 
00029 /*****************************************************************************/
00030 class HitComparator
00031 {
00032   public:
00033     bool operator() (const TransientTrackingRecHit* ta,
00034                      const TransientTrackingRecHit* tb) const
00035     {
00036       const TrackingRecHit* a = ta->hit();
00037       const TrackingRecHit* b = tb->hit();
00038 
00039       if(getId(a) < getId(b)) return true;
00040       if(getId(b) < getId(a)) return false;
00041 
00042       if(a->geographicalId() < b->geographicalId()) return true;
00043       if(b->geographicalId() < a->geographicalId()) return false;
00044 
00045       const SiPixelRecHit* a_ = dynamic_cast<const SiPixelRecHit*>(a);
00046       if(a_ != 0)
00047       {
00048         const SiPixelRecHit* b_ = dynamic_cast<const SiPixelRecHit*>(b);
00049         return less(a_, b_);
00050       }
00051       else
00052       {
00053         const SiStripMatchedRecHit2D* a_ =
00054           dynamic_cast<const SiStripMatchedRecHit2D*>(a);
00055 
00056         if(a_ != 0)
00057         {
00058           const SiStripMatchedRecHit2D* b_ =
00059             dynamic_cast<const SiStripMatchedRecHit2D*>(b);
00060           return less(a_, b_);
00061         }
00062         else
00063         {
00064           const SiStripRecHit2D* a_ =
00065             dynamic_cast<const SiStripRecHit2D*>(a);
00066 
00067           if(a_ != 0)
00068           {
00069             const SiStripRecHit2D* b_ =
00070               dynamic_cast<const SiStripRecHit2D*>(b);
00071             return less(a_, b_);
00072           }
00073           else 
00074           {
00075             const ProjectedSiStripRecHit2D* a_ =
00076               dynamic_cast<const ProjectedSiStripRecHit2D*>(a); 
00077 
00078 //std::cerr << " comp proj" << std::endl;
00079 
00080             if(a_ != 0)
00081             {
00082               const ProjectedSiStripRecHit2D* b_ =
00083                 dynamic_cast<const ProjectedSiStripRecHit2D*>(b);
00084 
00085               return less(&(a_->originalHit()), &(b_->originalHit()));
00086             }
00087             else
00088               return false;
00089           }
00090         }
00091       }
00092     }
00093 
00094   private:
00095     int getId(const TrackingRecHit* a) const
00096     {
00097       if(dynamic_cast<const SiPixelRecHit*>(a)            != 0) return 0;
00098       if(dynamic_cast<const SiStripRecHit2D*>(a)          != 0) return 1;
00099       if(dynamic_cast<const SiStripMatchedRecHit2D*>(a)   != 0) return 2;
00100       if(dynamic_cast<const ProjectedSiStripRecHit2D*>(a) != 0) return 3;
00101       return -1;
00102     }
00103 
00104     bool less(const SiPixelRecHit* a,
00105               const SiPixelRecHit* b) const
00106     {
00107 //std::cerr << " comp pixel" << std::endl;
00108       return a->cluster() < b->cluster();
00109     }
00110 
00111     bool less(const SiStripRecHit2D* a,
00112               const SiStripRecHit2D *b) const
00113     {
00114 //std::cerr << " comp strip" << std::endl;
00115       return a->cluster() < b->cluster();
00116     }
00117 
00118     bool less(const SiStripMatchedRecHit2D* a,
00119               const SiStripMatchedRecHit2D *b) const
00120     {
00121 //std::cerr << " comp matched strip" << std::endl;
00122       if(a->monoClusterRef() < b->monoClusterRef()) return true;
00123       if(b->monoClusterRef() < a->monoClusterRef()) return false;
00124       if(a->stereoClusterRef() < b->stereoClusterRef()) return true;
00125       return false;
00126     }
00127 };
00128 
00129 /*****************************************************************************/
00130 void TrajectoryCleanerMerger::clean( TrajectoryPointerContainer&) const
00131 {
00132 }
00133 
00134 /*****************************************************************************/
00135 void TrajectoryCleanerMerger::reOrderMeasurements(Trajectory& traj)const
00136 {
00137   std::vector<TrajectoryMeasurement> meas_ = traj.measurements();
00138   std::vector<TrajectoryMeasurement> meas;
00139 
00140   for(std::vector<TrajectoryMeasurement>::iterator
00141        im = meas_.begin();
00142        im!= meas_.end(); im++)
00143     if(im->recHit()->isValid())
00144        meas.push_back(*im);
00145 
00146   bool changed;
00147 
00148   do
00149   {
00150     changed = false;
00151 
00152     for(std::vector<TrajectoryMeasurement>::iterator im = meas.begin();
00153                                                 im!= meas.end()-1; im++)
00154     if(    (*im).recHit()->globalPosition().mag2() >
00155        (*(im+1)).recHit()->globalPosition().mag2() + 1e-6)
00156     {
00157       swap(*im,*(im+1));
00158       changed = true;
00159     }
00160   }
00161   while(changed);
00162 
00163   for(unsigned int i = 0 ; i < meas.size(); i++)
00164      traj.pop();
00165 
00166   for(std::vector<TrajectoryMeasurement>::iterator im = meas.begin();
00167                                               im!= meas.end(); im++)
00168     traj.push(*im);
00169 }
00170 /*****************************************************************************/
00171 bool TrajectoryCleanerMerger::sameSeed  (const TrajectorySeed & s1,   const TrajectorySeed & s2)const
00172 {
00173   if(s1.nHits() != s2.nHits()) return false;
00174 
00175   TrajectorySeed::range r1 = s1.recHits();
00176   TrajectorySeed::range r2 = s2.recHits();
00177 
00178   TrajectorySeed::const_iterator h1 = r1.first;
00179   TrajectorySeed::const_iterator h2 = r2.first;
00180 
00181   do
00182   {
00183     if(!(h1->sharesInput(&(*h2),TrackingRecHit::all)))
00184       return false;
00185 
00186     h1++; h2++;
00187   }
00188   while(h1 != s1.recHits().second && 
00189         h2 != s2.recHits().second);
00190 
00191   return true;
00192 }
00193 
00194 /*****************************************************************************/
00195 int TrajectoryCleanerMerger::getLayer(const DetId & id)const
00196 {
00197   // PXB layer, ladder -> (layer - 1)<<2 + (ladder-1)%2
00198   // PXF disk , panel
00199   // TIB layer, module 
00200   // TOB layer, module
00201   // TID wheel, ring
00202   // TEC wheel, ring
00203 
00204   if(id.subdetId() == (unsigned int) PixelSubdetector::PixelBarrel)
00205   { PXBDetId pid(id); return (100 * id.subdetId()+ ((pid.layer() - 1)<<1) + (pid.ladder() - 1)%2); }
00206 
00207   if(id.subdetId() == (unsigned int) PixelSubdetector::PixelEndcap)
00208   { PXFDetId pid(id); return (100 * id.subdetId()+ ((pid.disk()  - 1)<<1) + (pid.panel()  - 1)%2); }
00209 
00210   if(id.subdetId() == StripSubdetector::TIB)
00211   { TIBDetId pid(id); return (100 * id.subdetId()+ ((pid.layer() - 1)<<1) + (pid.module() - 1)%2); }
00212   if(id.subdetId() == StripSubdetector::TOB)
00213   { TOBDetId pid(id); return (100 * id.subdetId()+ ((pid.layer() - 1)<<1) + (pid.module() - 1)%2); }
00214 
00215   if(id.subdetId() == StripSubdetector::TID)
00216   { TIDDetId pid(id); return (100 * id.subdetId()+ ((pid.wheel() - 1)<<1) + (pid.ring()   - 1)%2); }
00217   if(id.subdetId() == StripSubdetector::TEC)
00218   { TECDetId pid(id); return (100 * id.subdetId()+ ((pid.wheel() - 1)<<1) + (pid.ring()   - 1)%2); }
00219 
00220   return 0;
00221 }
00222 
00223 /***************************************************************************/
00224 
00225 void TrajectoryCleanerMerger::clean
00226   (TrajectoryContainer& trajs) const 
00227 {
00228   if(trajs.size() == 0) return;
00229 
00230   // Fill the rechit map
00231   typedef std::map<const TransientTrackingRecHit*,
00232               std::vector<unsigned int>, HitComparator> RecHitMap; 
00233   RecHitMap recHitMap;
00234 
00235   std::vector<bool> keep(trajs.size(),true);
00236 
00237   for(unsigned int i = 0; i < trajs.size(); i++) 
00238   {
00239     std::vector<TrajectoryMeasurement> meas = trajs[i].measurements();
00240 
00241     for(std::vector<TrajectoryMeasurement>::iterator im = meas.begin();
00242                                                 im!= meas.end(); im++)
00243       if(im->recHit()->isValid())
00244       {
00245         const TransientTrackingRecHit* recHit = &(*(im->recHit()));
00246         if(recHit->isValid())
00247           recHitMap[recHit].push_back(i);
00248       }
00249   }
00250 
00251   // Look at each track
00252   typedef std::map<unsigned int,int,less<unsigned int> > TrajMap;
00253 
00254   for(unsigned int i = 0; i < trajs.size(); i++)
00255   if(keep[i])
00256   {  
00257     TrajMap trajMap;
00258     std::vector<DetId> detIds;
00259     std::vector<int> detLayers;
00260 
00261     // Go trough all rechits of this track
00262     std::vector<TrajectoryMeasurement> meas = trajs[i].measurements();
00263     for(std::vector<TrajectoryMeasurement>::iterator im = meas.begin();
00264                                                 im!= meas.end(); im++)
00265     {
00266       if(im->recHit()->isValid())
00267       {
00268         // Get trajs sharing this rechit
00269         const TransientTrackingRecHit* recHit = &(*(im->recHit()));
00270         const std::vector<unsigned int>& sharing(recHitMap[recHit]);
00271 
00272         for(std::vector<unsigned int>::const_iterator j = sharing.begin(); 
00273                                                  j!= sharing.end(); j++)
00274           if(i < *j) trajMap[*j]++;
00275 
00276         // Fill detLayers vector
00277         detIds.push_back(recHit->geographicalId());
00278         detLayers.push_back(getLayer(recHit->geographicalId()));
00279       }
00280     }
00281 
00282     // Check for trajs with shared rechits
00283     for(TrajMap::iterator sharing = trajMap.begin();
00284                           sharing!= trajMap.end(); sharing++)
00285     {
00286       unsigned int j = (*sharing).first;
00287       if(!keep[i] || !keep[j]) continue;
00288 
00289       // More than 50% shared
00290       if((*sharing).second > min(trajs[i].foundHits(),
00291                                  trajs[j].foundHits())/2)
00292       {
00293         if( sameSeed(trajs[i].seed(), trajs[j].seed()) )
00294         {
00295         bool hasCommonLayer = false;
00296 
00297 /*
00298         std::vector<TrajectoryMeasurement> measi = trajs[i].measurements();
00299         std::vector<TrajectoryMeasurement> measj = trajs[j].measurements();
00300         for(std::vector<TrajectoryMeasurement>::iterator
00301               tmj = measj.begin(); tmj!= measj.end(); tmj++)
00302             if(find(measi.begin(), measi.end(), tmj) == measi.end())
00303             if(find(detLayers.begin(),detLayers.end(),
00304                     getLayer(tmj->recHit()->geographicalId()))
00305                                    != detLayers.end())
00306              hasCommonLayer = true;
00307 */
00308 
00309         if(hasCommonLayer == false)
00310         { // merge tracks, add separate hits of the second to the first one
00311         std::vector<TrajectoryMeasurement> measj = trajs[j].measurements();
00312         for(std::vector<TrajectoryMeasurement>::iterator
00313              tmj = measj.begin(); tmj!= measj.end(); tmj++)
00314         if(tmj->recHit()->isValid())
00315         {
00316           bool match = false;
00317 
00318           std::vector<TrajectoryMeasurement> measi = trajs[i].measurements();
00319           for(std::vector<TrajectoryMeasurement>::iterator
00320              tmi = measi.begin(); tmi!= measi.end(); tmi++)
00321           if(tmi->recHit()->isValid())
00322             if(!HitComparator()(&(*(tmi->recHit())),
00323                                 &(*(tmj->recHit()))) &&
00324                !HitComparator()(&(*(tmj->recHit())),
00325                                 &(*(tmi->recHit()))))
00326             { match = true ; break; }
00327 
00328           if(!match)
00329             trajs[i].push(*tmj);
00330         }
00331 
00332         // Remove second track
00333         keep[j] = false;
00334         }
00335         else
00336         {
00337           // remove track with higher impact / chi2
00338           if(trajs[i].chiSquared() < trajs[j].chiSquared())
00339             keep[j] = false;
00340           else
00341             keep[i] = false;
00342         }
00343         }
00344       }
00345     } 
00346   }
00347 
00348   // Final copy
00349   int ok = 0;
00350   for(unsigned int i = 0; i < trajs.size(); i++)
00351     if(keep[i])
00352     {
00353       reOrderMeasurements(trajs[i]);
00354       ok++;
00355     }
00356     else
00357       trajs[i].invalidate();
00358 
00359   std::cerr << " [TrajecCleaner] cleaned trajs : " << ok << "/" << trajs.size() <<
00360 " (with " << trajs[0].measurements().size() << "/" << recHitMap.size() << " hits)" << std::endl;
00361 }
00362