00001 #include "RecoVertex/MultiVertexFit/interface/MultiVertexFitter.h"
00002
00003 #include <map>
00004 #include <algorithm>
00005 #include <iomanip>
00006
00007
00008 #include "RecoVertex/KalmanVertexFit/interface/KalmanVertexFitter.h"
00009 #include "RecoVertex/VertexTools/interface/LinearizedTrackStateFactory.h"
00010 #include "RecoVertex/VertexTools/interface/VertexTrackFactory.h"
00011 #include "RecoVertex/AdaptiveVertexFit/interface/KalmanChiSquare.h"
00012 #include "RecoVertex/VertexPrimitives/interface/VertexState.h"
00013 #include "RecoVertex/VertexPrimitives/interface/VertexException.h"
00014 #include "RecoVertex/KalmanVertexFit/interface/KalmanVertexTrackCompatibilityEstimator.h"
00015
00016
00017 #ifdef MVFHarvestingDebug
00018 #include "Vertex/VertexSimpleVis/interface/PrimitivesHarvester.h"
00019 #endif
00020
00021 using namespace std;
00022 using namespace reco;
00023
00024 namespace
00025 {
00026 typedef MultiVertexFitter::TrackAndWeight TrackAndWeight;
00027 typedef MultiVertexFitter::TrackAndSeedToWeightMap TrackAndSeedToWeightMap;
00028 typedef MultiVertexFitter::SeedToWeightMap SeedToWeightMap;
00029 typedef CachingVertex<5>::RefCountedVertexTrack RefCountedVertexTrack;
00030
00031 int verbose()
00032 {
00033 static const int ret = 0;
00034
00035 return ret;
00036 }
00037
00038 double minWeightFraction()
00039 {
00040
00041
00042
00043
00044 static const float ret = 1e-6;
00045
00046 return ret;
00047 }
00048
00049 bool discardLightWeights()
00050 {
00051 static const bool ret = true;
00052
00053 return ret;
00054 }
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 CachingVertex<5> createSeedFromLinPt ( const GlobalPoint & gp )
00067 {
00068 return CachingVertex<5> ( gp, GlobalError (),
00069 vector<RefCountedVertexTrack> (), 0.0 );
00070 }
00071
00072 double validWeight ( double weight )
00073 {
00074 if ( weight > 1.0 )
00075 {
00076 cout << "[MultiVertexFitter] weight=" << weight << "??" << endl;
00077 return 1.0;
00078 };
00079
00080 if ( weight < 0.0 )
00081 {
00082 cout << "[MultiVertexFitter] weight=" << weight << "??" << endl;
00083 return 0.0;
00084 };
00085 return weight;
00086 }
00087 }
00088
00089 void MultiVertexFitter::clear()
00090 {
00091 theAssComp->resetAnnealing();
00092 theTracks.clear();
00093 thePrimaries.clear();
00094 theVertexStates.clear();
00095 theWeights.clear();
00096 theCache.clear();
00097 }
00098
00099
00100
00101
00102 void MultiVertexFitter::createSeed( const vector < TransientTrack > & tracks )
00103 {
00104 if ( tracks.size() > 1 )
00105 {
00106 CachingVertex<5> vtx = createSeedFromLinPt (
00107 theSeeder->getLinearizationPoint ( tracks ) );
00108 int snr= seedNr();
00109 theVertexStates.push_back ( pair < int, CachingVertex<5> > ( snr, vtx ) );
00110 for ( vector< TransientTrack >::const_iterator track=tracks.begin();
00111 track!=tracks.end() ; ++track )
00112 {
00113 theWeights[*track][snr]=1.;
00114 theTracks.push_back ( *track );
00115 };
00116 };
00117 }
00118
00119 void MultiVertexFitter::createPrimaries ( const std::vector < reco::TransientTrack > tracks )
00120 {
00121
00122 for ( vector< reco::TransientTrack >::const_iterator i=tracks.begin();
00123 i!=tracks.end() ; ++i )
00124 {
00125 thePrimaries.insert ( *i );
00126
00127 }
00128
00129 }
00130
00131 int MultiVertexFitter::seedNr()
00132 {
00133 return theVertexStateNr++;
00134 }
00135
00136 void MultiVertexFitter::resetSeedNr()
00137 {
00138 theVertexStateNr=0;
00139 }
00140
00141 void MultiVertexFitter::createSeed( const vector < TrackAndWeight > & tracks )
00142 {
00143
00144 vector < RefCountedVertexTrack> newTracks;
00145
00146 for ( vector< TrackAndWeight >::const_iterator track=tracks.begin();
00147 track!=tracks.end() ; ++track )
00148 {
00149 double weight = validWeight ( track->second );
00150 const GlobalPoint & pos = track->first.impactPointState().globalPosition();
00151 GlobalError err;
00152 VertexState realseed ( pos, err );
00153
00154 RefCountedLinearizedTrackState lTrData =
00155 theCache.linTrack ( pos, track->first );
00156
00157 VertexTrackFactory<5> vTrackFactory;
00158 RefCountedVertexTrack vTrData = vTrackFactory.vertexTrack(
00159 lTrData, realseed, weight );
00160 newTracks.push_back ( vTrData );
00161 };
00162
00163 if ( newTracks.size() > 1 )
00164 {
00165 CachingVertex<5> vtx = KalmanVertexFitter().vertex ( newTracks );
00166 int snr = seedNr();
00167 theVertexStates.push_back ( pair < int, CachingVertex<5> > ( snr, vtx ) );
00168
00169
00170
00171 for ( vector< TrackAndWeight >::const_iterator track=tracks.begin();
00172 track!=tracks.end() ; ++track )
00173 {
00174 if ( thePrimaries.count ( track->first ) )
00175 {
00176
00177
00178
00179
00180
00181
00182 theWeights[track->first][theVertexStates[0].first]=track->second;
00183 continue;
00184 };
00185 float weight = track->second;
00186 if ( weight > 1.0 )
00187 {
00188 cout << "[MultiVertexFitter] error weight " << weight << " > 1.0 given."
00189 << endl;
00190 cout << "[MultiVertexFitter] will revert to 1.0" << endl;
00191 weight=1.0;
00192 };
00193 if ( weight < 0.0 )
00194 {
00195 cout << "[MultiVertexFitter] error weight " << weight << " < 0.0 given."
00196 << endl;
00197 cout << "[MultiVertexFitter] will revert to 0.0" << endl;
00198 weight=0.0;
00199 };
00200 theWeights[track->first][snr]=weight;
00201 theTracks.push_back ( track->first );
00202 };
00203 };
00204
00205
00206
00207
00208 sort ( theTracks.begin(), theTracks.end() );
00209 for ( vector< TransientTrack >::iterator i=theTracks.begin();
00210 i<theTracks.end() ; ++i )
00211 {
00212 if ( i != theTracks.begin() )
00213 {
00214 if ( (*i) == ( *(i-1) ) )
00215 {
00216 theTracks.erase ( i );
00217 };
00218 };
00219 };
00220 }
00221
00222 vector < CachingVertex<5> > MultiVertexFitter::vertices (
00223 const vector < TransientVertex > & vtces,
00224 const vector < TransientTrack > & primaries )
00225 {
00226
00227 if ( vtces.size() < 1 )
00228 {
00229 return vector < CachingVertex<5> > ();
00230 };
00231 vector < vector < TrackAndWeight > > bundles;
00232 for ( vector< TransientVertex >::const_iterator vtx=vtces.begin();
00233 vtx!=vtces.end() ; ++vtx )
00234 {
00235 vector < TransientTrack > trks = vtx->originalTracks();
00236 vector < TrackAndWeight > tnws;
00237 for ( vector< TransientTrack >::const_iterator trk=trks.begin();
00238 trk!=trks.end() ; ++trk )
00239 {
00240 float w = vtx->trackWeight ( *trk );
00241 if ( w > 1e-5 )
00242 {
00243 TrackAndWeight tmp ( *trk, w );
00244 tnws.push_back ( tmp );
00245 };
00246 };
00247 bundles.push_back ( tnws );
00248 };
00249 return vertices ( bundles, primaries );
00250 }
00251
00252 vector < CachingVertex<5> > MultiVertexFitter::vertices (
00253 const vector < CachingVertex<5> > & initials,
00254 const vector < TransientTrack > & primaries )
00255 {
00256 clear();
00257 createPrimaries ( primaries );
00258
00259 if ( initials.size() < 1 ) return initials;
00260 for ( vector< CachingVertex<5> >::const_iterator vtx=initials.begin();
00261 vtx!=initials.end() ; ++vtx )
00262 {
00263 int snr = seedNr();
00264 theVertexStates.push_back ( pair < int, CachingVertex<5> >
00265 ( snr, *vtx ) );
00266 TransientVertex rvtx = *vtx;
00267 const vector < TransientTrack > & trks = rvtx.originalTracks();
00268 for ( vector< TransientTrack >::const_iterator trk=trks.begin();
00269 trk!=trks.end() ; ++trk )
00270 {
00271 if ( !(thePrimaries.count (*trk )) )
00272 {
00273
00274 theTracks.push_back ( *trk );
00275 } else {
00276
00277 }
00278 cout << "[MultiVertexFitter] error! track weight currently set to one"
00279 << " FIXME!!!" << endl;
00280 theWeights[*trk][snr]=1.0;
00281 };
00282 };
00283 #ifdef MVFHarvestingDebug
00284 for ( vector< CachingVertex<5> >::const_iterator i=theVertexStates.begin();
00285 i!=theVertexStates.end() ; ++i )
00286 PrimitivesHarvester::file()->save(*i);
00287 #endif
00288 return fit();
00289 }
00290
00291 vector < CachingVertex<5> > MultiVertexFitter::vertices (
00292 const vector < vector < TransientTrack > > & tracks,
00293 const vector < TransientTrack > & primaries )
00294 {
00295 clear();
00296 createPrimaries ( primaries );
00297
00298 for ( vector< vector < TransientTrack > >::const_iterator cluster=
00299 tracks.begin(); cluster!=tracks.end() ; ++cluster )
00300 {
00301 createSeed ( *cluster );
00302 };
00303 if ( verbose() )
00304 {
00305 printSeeds();
00306 };
00307 #ifdef MVFHarvestingDebug
00308 for ( vector< CachingVertex<5> >::const_iterator i=theVertexStates.begin();
00309 i!=theVertexStates.end() ; ++i )
00310 PrimitivesHarvester::file()->save(*i);
00311 #endif
00312 return fit();
00313 }
00314
00315 vector < CachingVertex<5> > MultiVertexFitter::vertices (
00316 const vector < vector < TrackAndWeight > > & tracks,
00317 const vector < TransientTrack > & primaries )
00318 {
00319 clear();
00320 createPrimaries ( primaries );
00321
00322 for ( vector< vector < TrackAndWeight > >::const_iterator cluster=
00323 tracks.begin(); cluster!=tracks.end() ; ++cluster )
00324 {
00325 createSeed ( *cluster );
00326 };
00327 if ( verbose() )
00328 {
00329 printSeeds();
00330 };
00331
00332 return fit();
00333 }
00334
00335 MultiVertexFitter::MultiVertexFitter( const AnnealingSchedule & ann,
00336 const LinearizationPointFinder & seeder,
00337 float revive_below ) :
00338 theVertexStateNr ( 0 ), theReviveBelow ( revive_below ),
00339 theAssComp ( ann.clone() ), theSeeder ( seeder.clone() )
00340 {}
00341
00342 MultiVertexFitter::MultiVertexFitter( const MultiVertexFitter & o ) :
00343 theVertexStateNr ( o.theVertexStateNr ), theReviveBelow ( o.theReviveBelow ),
00344 theAssComp ( o.theAssComp->clone() ), theSeeder ( o.theSeeder->clone() )
00345 {}
00346
00347 MultiVertexFitter::~MultiVertexFitter()
00348 {
00349 delete theAssComp;
00350 delete theSeeder;
00351 }
00352
00353 void MultiVertexFitter::updateWeights()
00354 {
00355 theWeights.clear();
00356 if ( verbose() & 4 )
00357 {
00358 cout << "[MultiVertexFitter] Start weight update." << endl;
00359 };
00360
00361 KalmanVertexTrackCompatibilityEstimator<5> theComp;
00362
00366 for ( set < TransientTrack >::const_iterator trk=thePrimaries.begin();
00367 trk!=thePrimaries.end() ; ++trk )
00368 {
00369 int seednr = theVertexStates[0].first;
00370 CachingVertex<5> seed = theVertexStates[0].second;
00371 double chi2 = theComp.estimate ( seed, theCache.linTrack ( seed.position(), *trk ) );
00372 double weight = theAssComp->phi ( chi2 );
00373 theWeights[*trk][seednr]= weight;
00374 }
00375
00379 for ( vector< TransientTrack >::const_iterator trk=theTracks.begin();
00380 trk!=theTracks.end() ; ++trk )
00381 {
00382 double tot_weight=theAssComp->phi ( theAssComp->cutoff() * theAssComp->cutoff() );
00383
00384 for ( vector < pair < int, CachingVertex<5> > >::const_iterator
00385 seed=theVertexStates.begin(); seed!=theVertexStates.end(); ++seed )
00386 {
00387
00388 double chi2 = theComp.estimate ( seed->second, theCache.linTrack ( seed->second.position(),
00389 *trk ) );
00390 double weight = theAssComp->phi ( chi2 );
00391 tot_weight+=weight;
00392 theWeights[*trk][seed->first]=weight;
00393
00394
00395 };
00396
00397
00398
00399 if ( tot_weight > 0.0 )
00400 {
00401 for ( vector < pair < int, CachingVertex<5> > >::const_iterator
00402 seed=theVertexStates.begin();
00403 seed!=theVertexStates.end(); ++seed )
00404 {
00405 double normedweight=theWeights[*trk][seed->first]/tot_weight;
00406 if ( normedweight > 1.0 )
00407 {
00408 cout << "[MultiVertexFitter] he? w["
00409 << "," << seed->second.position() << "] = " << normedweight
00410 << " totw=" << tot_weight << endl;
00411 normedweight=1.0;
00412 };
00413 if ( normedweight < 0.0 )
00414 {
00415 cout << "[MultiVertexFitter] he? weight=" << normedweight
00416 << " totw=" << tot_weight << endl;
00417 normedweight=0.0;
00418 };
00419 theWeights[*trk][seed->first]=normedweight;
00420 };
00421 } else {
00422
00423 cout << "[MultiVertexFitter] track found with no assignment - ";
00424 cout << "will assign uniformly." << endl;
00425 float w = .5 / (float) theVertexStates.size();
00426 for ( vector < pair < int, CachingVertex<5> > >::const_iterator seed=theVertexStates.begin();
00427 seed!=theVertexStates.end(); ++seed )
00428 {
00429 theWeights[*trk][seed->first]=w;
00430 };
00431 };
00432 };
00433 if ( verbose() & 2 ) printWeights();
00434 }
00435
00436 bool MultiVertexFitter::updateSeeds()
00437 {
00438 double max_disp=0.;
00439
00440
00441
00442
00443 vector < pair < int, CachingVertex<5> > > newSeeds;
00444
00445 for ( vector< pair < int, CachingVertex<5> > >::const_iterator seed=theVertexStates.begin();
00446 seed!=theVertexStates.end() ; ++seed )
00447 {
00448
00449
00450
00451 int snr = seed->first;
00452 VertexState realseed ( seed->second.position(), seed->second.error() );
00453
00454 double totweight=0.;
00455 for ( vector< TransientTrack >::const_iterator track=theTracks.begin();
00456 track!=theTracks.end() ; ++track )
00457 {
00458 totweight+=theWeights[*track][snr];
00459 };
00460
00461
00462 int nr_good_trks=0;
00463
00464
00465
00466
00467 if ( discardLightWeights() )
00468 {
00469 for ( vector< TransientTrack >::const_iterator track=theTracks.begin();
00470 track!=theTracks.end() ; ++track )
00471 {
00472 if ( theWeights[*track][snr] > totweight * minWeightFraction() )
00473 {
00474 nr_good_trks++;
00475 };
00476 };
00477 };
00478
00479 vector<RefCountedVertexTrack> newTracks;
00480 for ( vector< TransientTrack >::const_iterator track=theTracks.begin();
00481 track!=theTracks.end() ; ++track )
00482 {
00483 double weight = validWeight ( theWeights[*track][snr] );
00484
00485
00486
00487
00488
00489
00490 if ( !discardLightWeights() || weight > minWeightFraction() * totweight
00491 || nr_good_trks < 2 )
00492 {
00493
00494
00495
00496
00497 RefCountedLinearizedTrackState lTrData =
00498 theCache.linTrack ( seed->second.position(), *track );
00499
00500 VertexTrackFactory<5> vTrackFactory;
00501 RefCountedVertexTrack vTrData = vTrackFactory.vertexTrack(
00502 lTrData, realseed, weight );
00503 newTracks.push_back ( vTrData );
00504 };
00505 };
00506
00507 for ( set< TransientTrack >::const_iterator track=thePrimaries.begin();
00508 track!=thePrimaries.end() ; ++track )
00509 {
00510 double weight = validWeight ( theWeights[*track][snr] );
00511
00512 RefCountedLinearizedTrackState lTrData =
00513 theCache.linTrack ( seed->second.position(), *track );
00514
00515 VertexTrackFactory<5> vTrackFactory;
00516 RefCountedVertexTrack vTrData = vTrackFactory.vertexTrack(
00517 lTrData, realseed, weight );
00518 newTracks.push_back ( vTrData );
00519
00520 };
00521
00522 try {
00523 if ( newTracks.size() < 2 )
00524 {
00525 throw VertexException("less than two tracks in vector" );
00526 };
00527
00528 if ( verbose() )
00529 {
00530 cout << "[MultiVertexFitter] now fitting with Kalman: ";
00531 for ( vector< RefCountedVertexTrack >::const_iterator i=newTracks.begin();
00532 i!=newTracks.end() ; ++i )
00533 {
00534 cout << (**i).weight() << " ";
00535 };
00536 cout << endl;
00537 };
00538
00539 if ( newTracks.size() > 1 )
00540 {
00541 KalmanVertexFitter fitter;
00542
00543 CachingVertex<5> newVertex = fitter.vertex ( newTracks );
00544 int snr = seedNr();
00545 double disp = ( newVertex.position() - seed->second.position() ).mag();
00546 if ( disp > max_disp ) max_disp = disp;
00547 newSeeds.push_back (
00548 pair < int, CachingVertex<5> > ( snr, newVertex ) );
00549 };
00550 } catch ( exception & e )
00551 {
00552 cout << "[MultiVertexFitter] exception: " << e.what() << endl;
00553 }
00554 };
00555
00556
00557 theVertexStates.clear();
00558 theWeights.clear();
00559 theVertexStates=newSeeds;
00560 #ifdef MVFHarvestingDebug
00561 for ( vector< CachingVertex<5> >::const_iterator i=theVertexStates.begin();
00562 i!=theVertexStates.end() ; ++i )
00563 PrimitivesHarvester::file()->save(*i);
00564 #endif
00565 updateWeights();
00566
00567 static const double disp_limit = 1e-4;
00568
00569
00570 if ( verbose() & 2 )
00571 {
00572 printSeeds();
00573 cout << "[MultiVertexFitter] max displacement in this iteration: "
00574 << max_disp << endl;
00575 };
00576 if ( max_disp < disp_limit ) return false;
00577 return true;
00578 }
00579
00580
00581 vector < CachingVertex<5> > MultiVertexFitter::fit()
00582 {
00583 if ( verbose() & 2 ) printWeights();
00584 int ctr=1;
00585 static const int ctr_max = 50;
00586
00587 while ( updateSeeds() || !(theAssComp->isAnnealed()) )
00588 {
00589 if ( ++ctr >= ctr_max ) break;
00590 theAssComp->anneal();
00591
00592 resetSeedNr();
00593 };
00594
00595 if ( verbose() )
00596 {
00597 cout << "[MultiVertexFitter] number of iterations: " << ctr << endl;
00598 cout << "[MultiVertexFitter] remaining seeds: "
00599 << theVertexStates.size() << endl;
00600 printWeights();
00601 };
00602
00603 vector < CachingVertex<5> > ret;
00604 for ( vector< pair < int, CachingVertex<5> > >::const_iterator
00605 i=theVertexStates.begin(); i!=theVertexStates.end() ; ++i )
00606 {
00607 ret.push_back ( i->second );
00608 };
00609
00610 return ret;
00611 }
00612
00613 void MultiVertexFitter::printWeights ( const reco::TransientTrack & t ) const
00614 {
00615
00616 for ( vector < pair < int, CachingVertex<5> > >::const_iterator seed=theVertexStates.begin();
00617 seed!=theVertexStates.end(); ++seed )
00618 {
00619 cout << " -- Vertex[" << seed->first << "] with " << setw(12)
00620 << setprecision(3) << theWeights[t][seed->first];
00621 };
00622 cout << endl;
00623 }
00624
00625 void MultiVertexFitter::printWeights() const
00626 {
00627 cout << endl << "Weight table: " << endl << "=================" << endl;
00628 for ( set < TransientTrack >::const_iterator trk=thePrimaries.begin();
00629 trk!=thePrimaries.end() ; ++trk )
00630 {
00631 printWeights ( *trk );
00632 };
00633 for ( vector< TransientTrack >::const_iterator trk=theTracks.begin();
00634 trk!=theTracks.end() ; ++trk )
00635 {
00636 printWeights ( *trk );
00637 };
00638 }
00639
00640 void MultiVertexFitter::printSeeds() const
00641 {
00642 cout << endl << "Seed table: " << endl << "=====================" << endl;
00643
00644
00645
00646
00647
00648
00649
00650 }
00651
00652 void MultiVertexFitter::lostVertexClaimer()
00653 {
00654 if ( !theReviveBelow < 0. ) return;
00655
00656
00657
00658 bool has_revived = false;
00659
00660 for ( vector< pair < int, CachingVertex<5> > >::const_iterator i=theVertexStates.begin();
00661 i!=theVertexStates.end() ; ++i )
00662 {
00663 double totweight=0.;
00664 for ( vector< TransientTrack >::const_iterator trk=theTracks.begin();
00665 trk!=theTracks.end() ; ++trk )
00666 {
00667 totweight+=theWeights[*trk][i->first];
00668 };
00669
00670
00671
00672
00673
00674 if ( totweight < theReviveBelow && totweight > 0.0 )
00675 {
00676 cout << "[MultiVertexFitter] now trying to revive vertex"
00677 << " revive_below=" << theReviveBelow << endl;
00678 has_revived=true;
00679 for ( vector< TransientTrack >::const_iterator trk=theTracks.begin();
00680 trk!=theTracks.end() ; ++trk )
00681 {
00682 theWeights[*trk][i->first]/=totweight;
00683 };
00684 };
00685 };
00686 if ( has_revived && verbose() ) printWeights();
00687 }