00001 #include "RecoVertex/MultiVertexFit/interface/MultiVertexFitter.h"
00002
00003 #include <map>
00004 #include <algorithm>
00005 #include <iomanip>
00006
00007
00008 #include "RecoVertex/KalmanVertexFit/interface/KalmanVertexFitter.h"
00009 #include "RecoVertex/VertexTools/interface/LinearizedTrackStateFactory.h"
00010 #include "RecoVertex/VertexTools/interface/VertexTrackFactory.h"
00011 #include "RecoVertex/VertexPrimitives/interface/VertexState.h"
00012 #include "RecoVertex/VertexPrimitives/interface/VertexException.h"
00013 #include "RecoVertex/KalmanVertexFit/interface/KalmanVertexTrackCompatibilityEstimator.h"
00014
00015
00016 #ifdef MVFHarvestingDebug
00017 #include "Vertex/VertexSimpleVis/interface/PrimitivesHarvester.h"
00018 #endif
00019
00020 using namespace std;
00021 using namespace reco;
00022
00023 namespace
00024 {
00025 typedef MultiVertexFitter::TrackAndWeight TrackAndWeight;
00026 typedef MultiVertexFitter::TrackAndSeedToWeightMap TrackAndSeedToWeightMap;
00027 typedef MultiVertexFitter::SeedToWeightMap SeedToWeightMap;
00028 typedef CachingVertex<5>::RefCountedVertexTrack RefCountedVertexTrack;
00029
00030 int verbose()
00031 {
00032 static const int ret = 0;
00033
00034 return ret;
00035 }
00036
00037 double minWeightFraction()
00038 {
00039
00040
00041
00042
00043 static const float ret = 1e-6;
00044
00045 return ret;
00046 }
00047
00048 bool discardLightWeights()
00049 {
00050 static const bool ret = true;
00051
00052 return ret;
00053 }
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065 CachingVertex<5> createSeedFromLinPt ( const GlobalPoint & gp )
00066 {
00067 return CachingVertex<5> ( gp, GlobalError (),
00068 vector<RefCountedVertexTrack> (), 0.0 );
00069 }
00070
00071 double validWeight ( double weight )
00072 {
00073 if ( weight > 1.0 )
00074 {
00075 cout << "[MultiVertexFitter] weight=" << weight << "??" << endl;
00076 return 1.0;
00077 };
00078
00079 if ( weight < 0.0 )
00080 {
00081 cout << "[MultiVertexFitter] weight=" << weight << "??" << endl;
00082 return 0.0;
00083 };
00084 return weight;
00085 }
00086 }
00087
00088 void MultiVertexFitter::clear()
00089 {
00090 theAssComp->resetAnnealing();
00091 theTracks.clear();
00092 thePrimaries.clear();
00093 theVertexStates.clear();
00094 theWeights.clear();
00095 theCache.clear();
00096 }
00097
00098
00099
00100
00101 void MultiVertexFitter::createSeed( const vector < TransientTrack > & tracks )
00102 {
00103 if ( tracks.size() > 1 )
00104 {
00105 CachingVertex<5> vtx = createSeedFromLinPt (
00106 theSeeder->getLinearizationPoint ( tracks ) );
00107 int snr= seedNr();
00108 theVertexStates.push_back ( pair < int, CachingVertex<5> > ( snr, vtx ) );
00109 for ( vector< TransientTrack >::const_iterator track=tracks.begin();
00110 track!=tracks.end() ; ++track )
00111 {
00112 theWeights[*track][snr]=1.;
00113 theTracks.push_back ( *track );
00114 };
00115 };
00116 }
00117
00118 void MultiVertexFitter::createPrimaries ( const std::vector < reco::TransientTrack > &tracks )
00119 {
00120
00121 for ( vector< reco::TransientTrack >::const_iterator i=tracks.begin();
00122 i!=tracks.end() ; ++i )
00123 {
00124 thePrimaries.insert ( *i );
00125
00126 }
00127
00128 }
00129
00130 int MultiVertexFitter::seedNr()
00131 {
00132 return theVertexStateNr++;
00133 }
00134
00135 void MultiVertexFitter::resetSeedNr()
00136 {
00137 theVertexStateNr=0;
00138 }
00139
00140 void MultiVertexFitter::createSeed( const vector < TrackAndWeight > & tracks )
00141 {
00142
00143 vector < RefCountedVertexTrack> newTracks;
00144
00145 for ( vector< TrackAndWeight >::const_iterator track=tracks.begin();
00146 track!=tracks.end() ; ++track )
00147 {
00148 double weight = validWeight ( track->second );
00149 const GlobalPoint & pos = track->first.impactPointState().globalPosition();
00150 GlobalError err;
00151 VertexState realseed ( pos, err );
00152
00153 RefCountedLinearizedTrackState lTrData =
00154 theCache.linTrack ( pos, track->first );
00155
00156 VertexTrackFactory<5> vTrackFactory;
00157 RefCountedVertexTrack vTrData = vTrackFactory.vertexTrack(
00158 lTrData, realseed, weight );
00159 newTracks.push_back ( vTrData );
00160 };
00161
00162 if ( newTracks.size() > 1 )
00163 {
00164 CachingVertex<5> vtx = KalmanVertexFitter().vertex ( newTracks );
00165 int snr = seedNr();
00166 theVertexStates.push_back ( pair < int, CachingVertex<5> > ( snr, vtx ) );
00167
00168
00169
00170 for ( vector< TrackAndWeight >::const_iterator track=tracks.begin();
00171 track!=tracks.end() ; ++track )
00172 {
00173 if ( thePrimaries.count ( track->first ) )
00174 {
00175
00176
00177
00178
00179
00180
00181 theWeights[track->first][theVertexStates[0].first]=track->second;
00182 continue;
00183 };
00184 float weight = track->second;
00185 if ( weight > 1.0 )
00186 {
00187 cout << "[MultiVertexFitter] error weight " << weight << " > 1.0 given."
00188 << endl;
00189 cout << "[MultiVertexFitter] will revert to 1.0" << endl;
00190 weight=1.0;
00191 };
00192 if ( weight < 0.0 )
00193 {
00194 cout << "[MultiVertexFitter] error weight " << weight << " < 0.0 given."
00195 << endl;
00196 cout << "[MultiVertexFitter] will revert to 0.0" << endl;
00197 weight=0.0;
00198 };
00199 theWeights[track->first][snr]=weight;
00200 theTracks.push_back ( track->first );
00201 };
00202 };
00203
00204
00205
00206
00207 sort ( theTracks.begin(), theTracks.end() );
00208 for ( vector< TransientTrack >::iterator i=theTracks.begin();
00209 i<theTracks.end() ; ++i )
00210 {
00211 if ( i != theTracks.begin() )
00212 {
00213 if ( (*i) == ( *(i-1) ) )
00214 {
00215 theTracks.erase ( i );
00216 };
00217 };
00218 };
00219 }
00220
00221 vector < CachingVertex<5> > MultiVertexFitter::vertices (
00222 const vector < TransientVertex > & vtces,
00223 const vector < TransientTrack > & primaries )
00224 {
00225
00226 if ( vtces.size() < 1 )
00227 {
00228 return vector < CachingVertex<5> > ();
00229 };
00230 vector < vector < TrackAndWeight > > bundles;
00231 for ( vector< TransientVertex >::const_iterator vtx=vtces.begin();
00232 vtx!=vtces.end() ; ++vtx )
00233 {
00234 vector < TransientTrack > trks = vtx->originalTracks();
00235 vector < TrackAndWeight > tnws;
00236 for ( vector< TransientTrack >::const_iterator trk=trks.begin();
00237 trk!=trks.end() ; ++trk )
00238 {
00239 float w = vtx->trackWeight ( *trk );
00240 if ( w > 1e-5 )
00241 {
00242 TrackAndWeight tmp ( *trk, w );
00243 tnws.push_back ( tmp );
00244 };
00245 };
00246 bundles.push_back ( tnws );
00247 };
00248 return vertices ( bundles, primaries );
00249 }
00250
00251 vector < CachingVertex<5> > MultiVertexFitter::vertices (
00252 const vector < CachingVertex<5> > & initials,
00253 const vector < TransientTrack > & primaries )
00254 {
00255 clear();
00256 createPrimaries ( primaries );
00257
00258 if ( initials.size() < 1 ) return initials;
00259 for ( vector< CachingVertex<5> >::const_iterator vtx=initials.begin();
00260 vtx!=initials.end() ; ++vtx )
00261 {
00262 int snr = seedNr();
00263 theVertexStates.push_back ( pair < int, CachingVertex<5> >
00264 ( snr, *vtx ) );
00265 TransientVertex rvtx = *vtx;
00266 const vector < TransientTrack > & trks = rvtx.originalTracks();
00267 for ( vector< TransientTrack >::const_iterator trk=trks.begin();
00268 trk!=trks.end() ; ++trk )
00269 {
00270 if ( !(thePrimaries.count (*trk )) )
00271 {
00272
00273 theTracks.push_back ( *trk );
00274 } else {
00275
00276 }
00277 cout << "[MultiVertexFitter] error! track weight currently set to one"
00278 << " FIXME!!!" << endl;
00279 theWeights[*trk][snr]=1.0;
00280 };
00281 };
00282 #ifdef MVFHarvestingDebug
00283 for ( vector< CachingVertex<5> >::const_iterator i=theVertexStates.begin();
00284 i!=theVertexStates.end() ; ++i )
00285 PrimitivesHarvester::file()->save(*i);
00286 #endif
00287 return fit();
00288 }
00289
00290 vector < CachingVertex<5> > MultiVertexFitter::vertices (
00291 const vector < vector < TransientTrack > > & tracks,
00292 const vector < TransientTrack > & primaries )
00293 {
00294 clear();
00295 createPrimaries ( primaries );
00296
00297 for ( vector< vector < TransientTrack > >::const_iterator cluster=
00298 tracks.begin(); cluster!=tracks.end() ; ++cluster )
00299 {
00300 createSeed ( *cluster );
00301 };
00302 if ( verbose() )
00303 {
00304 printSeeds();
00305 };
00306 #ifdef MVFHarvestingDebug
00307 for ( vector< CachingVertex<5> >::const_iterator i=theVertexStates.begin();
00308 i!=theVertexStates.end() ; ++i )
00309 PrimitivesHarvester::file()->save(*i);
00310 #endif
00311 return fit();
00312 }
00313
00314 vector < CachingVertex<5> > MultiVertexFitter::vertices (
00315 const vector < vector < TrackAndWeight > > & tracks,
00316 const vector < TransientTrack > & primaries )
00317 {
00318 clear();
00319 createPrimaries ( primaries );
00320
00321 for ( vector< vector < TrackAndWeight > >::const_iterator cluster=
00322 tracks.begin(); cluster!=tracks.end() ; ++cluster )
00323 {
00324 createSeed ( *cluster );
00325 };
00326 if ( verbose() )
00327 {
00328 printSeeds();
00329 };
00330
00331 return fit();
00332 }
00333
00334 MultiVertexFitter::MultiVertexFitter( const AnnealingSchedule & ann,
00335 const LinearizationPointFinder & seeder,
00336 float revive_below ) :
00337 theVertexStateNr ( 0 ), theReviveBelow ( revive_below ),
00338 theAssComp ( ann.clone() ), theSeeder ( seeder.clone() )
00339 {}
00340
00341 MultiVertexFitter::MultiVertexFitter( const MultiVertexFitter & o ) :
00342 theVertexStateNr ( o.theVertexStateNr ), theReviveBelow ( o.theReviveBelow ),
00343 theAssComp ( o.theAssComp->clone() ), theSeeder ( o.theSeeder->clone() )
00344 {}
00345
00346 MultiVertexFitter::~MultiVertexFitter()
00347 {
00348 delete theAssComp;
00349 delete theSeeder;
00350 }
00351
00352 void MultiVertexFitter::updateWeights()
00353 {
00354 theWeights.clear();
00355 if ( verbose() & 4 )
00356 {
00357 cout << "[MultiVertexFitter] Start weight update." << endl;
00358 };
00359
00360 KalmanVertexTrackCompatibilityEstimator<5> theComp;
00361
00365 for ( set < TransientTrack >::const_iterator trk=thePrimaries.begin();
00366 trk!=thePrimaries.end() ; ++trk )
00367 {
00368 int seednr = theVertexStates[0].first;
00369 CachingVertex<5> seed = theVertexStates[0].second;
00370 pair<bool, double> result = theComp.estimate ( seed, theCache.linTrack ( seed.position(), *trk ) );
00371 double weight = 0.;
00372 if (result.first) weight = theAssComp->phi ( result.second );
00373 theWeights[*trk][seednr]= weight;
00374 }
00375
00379 for ( vector< TransientTrack >::const_iterator trk=theTracks.begin();
00380 trk!=theTracks.end() ; ++trk )
00381 {
00382 double tot_weight=theAssComp->phi ( theAssComp->cutoff() * theAssComp->cutoff() );
00383
00384 for ( vector < pair < int, CachingVertex<5> > >::const_iterator
00385 seed=theVertexStates.begin(); seed!=theVertexStates.end(); ++seed )
00386 {
00387
00388 pair<bool, double> result = theComp.estimate ( seed->second, theCache.linTrack ( seed->second.position(),
00389 *trk ) );
00390 double weight = 0.;
00391 if (result.first) weight = theAssComp->phi ( result.second );
00392 tot_weight+=weight;
00393 theWeights[*trk][seed->first]=weight;
00394
00395
00396 };
00397
00398
00399
00400 if ( tot_weight > 0.0 )
00401 {
00402 for ( vector < pair < int, CachingVertex<5> > >::const_iterator
00403 seed=theVertexStates.begin();
00404 seed!=theVertexStates.end(); ++seed )
00405 {
00406 double normedweight=theWeights[*trk][seed->first]/tot_weight;
00407 if ( normedweight > 1.0 )
00408 {
00409 cout << "[MultiVertexFitter] he? w["
00410 << "," << seed->second.position() << "] = " << normedweight
00411 << " totw=" << tot_weight << endl;
00412 normedweight=1.0;
00413 };
00414 if ( normedweight < 0.0 )
00415 {
00416 cout << "[MultiVertexFitter] he? weight=" << normedweight
00417 << " totw=" << tot_weight << endl;
00418 normedweight=0.0;
00419 };
00420 theWeights[*trk][seed->first]=normedweight;
00421 };
00422 } else {
00423
00424 cout << "[MultiVertexFitter] track found with no assignment - ";
00425 cout << "will assign uniformly." << endl;
00426 float w = .5 / (float) theVertexStates.size();
00427 for ( vector < pair < int, CachingVertex<5> > >::const_iterator seed=theVertexStates.begin();
00428 seed!=theVertexStates.end(); ++seed )
00429 {
00430 theWeights[*trk][seed->first]=w;
00431 };
00432 };
00433 };
00434 if ( verbose() & 2 ) printWeights();
00435 }
00436
00437 bool MultiVertexFitter::updateSeeds()
00438 {
00439 double max_disp=0.;
00440
00441
00442
00443
00444 vector < pair < int, CachingVertex<5> > > newSeeds;
00445
00446 for ( vector< pair < int, CachingVertex<5> > >::const_iterator seed=theVertexStates.begin();
00447 seed!=theVertexStates.end() ; ++seed )
00448 {
00449
00450
00451
00452 int snr = seed->first;
00453 VertexState realseed ( seed->second.position(), seed->second.error() );
00454
00455 double totweight=0.;
00456 for ( vector< TransientTrack >::const_iterator track=theTracks.begin();
00457 track!=theTracks.end() ; ++track )
00458 {
00459 totweight+=theWeights[*track][snr];
00460 };
00461
00462
00463 int nr_good_trks=0;
00464
00465
00466
00467
00468 if ( discardLightWeights() )
00469 {
00470 for ( vector< TransientTrack >::const_iterator track=theTracks.begin();
00471 track!=theTracks.end() ; ++track )
00472 {
00473 if ( theWeights[*track][snr] > totweight * minWeightFraction() )
00474 {
00475 nr_good_trks++;
00476 };
00477 };
00478 };
00479
00480 vector<RefCountedVertexTrack> newTracks;
00481 for ( vector< TransientTrack >::const_iterator track=theTracks.begin();
00482 track!=theTracks.end() ; ++track )
00483 {
00484 double weight = validWeight ( theWeights[*track][snr] );
00485
00486
00487
00488
00489
00490
00491 if ( !discardLightWeights() || weight > minWeightFraction() * totweight
00492 || nr_good_trks < 2 )
00493 {
00494
00495
00496
00497
00498 RefCountedLinearizedTrackState lTrData =
00499 theCache.linTrack ( seed->second.position(), *track );
00500
00501 VertexTrackFactory<5> vTrackFactory;
00502 RefCountedVertexTrack vTrData = vTrackFactory.vertexTrack(
00503 lTrData, realseed, weight );
00504 newTracks.push_back ( vTrData );
00505 };
00506 };
00507
00508 for ( set< TransientTrack >::const_iterator track=thePrimaries.begin();
00509 track!=thePrimaries.end() ; ++track )
00510 {
00511 double weight = validWeight ( theWeights[*track][snr] );
00512
00513 RefCountedLinearizedTrackState lTrData =
00514 theCache.linTrack ( seed->second.position(), *track );
00515
00516 VertexTrackFactory<5> vTrackFactory;
00517 RefCountedVertexTrack vTrData = vTrackFactory.vertexTrack(
00518 lTrData, realseed, weight );
00519 newTracks.push_back ( vTrData );
00520
00521 };
00522
00523 try {
00524 if ( newTracks.size() < 2 )
00525 {
00526 throw VertexException("less than two tracks in vector" );
00527 };
00528
00529 if ( verbose() )
00530 {
00531 cout << "[MultiVertexFitter] now fitting with Kalman: ";
00532 for ( vector< RefCountedVertexTrack >::const_iterator i=newTracks.begin();
00533 i!=newTracks.end() ; ++i )
00534 {
00535 cout << (**i).weight() << " ";
00536 };
00537 cout << endl;
00538 };
00539
00540 if ( newTracks.size() > 1 )
00541 {
00542 KalmanVertexFitter fitter;
00543
00544 CachingVertex<5> newVertex = fitter.vertex ( newTracks );
00545 int snr = seedNr();
00546 double disp = ( newVertex.position() - seed->second.position() ).mag();
00547 if ( disp > max_disp ) max_disp = disp;
00548 newSeeds.push_back (
00549 pair < int, CachingVertex<5> > ( snr, newVertex ) );
00550 };
00551 } catch ( exception & e )
00552 {
00553 cout << "[MultiVertexFitter] exception: " << e.what() << endl;
00554 }
00555 };
00556
00557
00558 theVertexStates.clear();
00559 theWeights.clear();
00560 theVertexStates=newSeeds;
00561 #ifdef MVFHarvestingDebug
00562 for ( vector< CachingVertex<5> >::const_iterator i=theVertexStates.begin();
00563 i!=theVertexStates.end() ; ++i )
00564 PrimitivesHarvester::file()->save(*i);
00565 #endif
00566 updateWeights();
00567
00568 static const double disp_limit = 1e-4;
00569
00570
00571 if ( verbose() & 2 )
00572 {
00573 printSeeds();
00574 cout << "[MultiVertexFitter] max displacement in this iteration: "
00575 << max_disp << endl;
00576 };
00577 if ( max_disp < disp_limit ) return false;
00578 return true;
00579 }
00580
00581
00582 vector < CachingVertex<5> > MultiVertexFitter::fit()
00583 {
00584 if ( verbose() & 2 ) printWeights();
00585 int ctr=1;
00586 static const int ctr_max = 50;
00587
00588 while ( updateSeeds() || !(theAssComp->isAnnealed()) )
00589 {
00590 if ( ++ctr >= ctr_max ) break;
00591 theAssComp->anneal();
00592
00593 resetSeedNr();
00594 };
00595
00596 if ( verbose() )
00597 {
00598 cout << "[MultiVertexFitter] number of iterations: " << ctr << endl;
00599 cout << "[MultiVertexFitter] remaining seeds: "
00600 << theVertexStates.size() << endl;
00601 printWeights();
00602 };
00603
00604 vector < CachingVertex<5> > ret;
00605 for ( vector< pair < int, CachingVertex<5> > >::const_iterator
00606 i=theVertexStates.begin(); i!=theVertexStates.end() ; ++i )
00607 {
00608 ret.push_back ( i->second );
00609 };
00610
00611 return ret;
00612 }
00613
00614 void MultiVertexFitter::printWeights ( const reco::TransientTrack & t ) const
00615 {
00616
00617 for ( vector < pair < int, CachingVertex<5> > >::const_iterator seed=theVertexStates.begin();
00618 seed!=theVertexStates.end(); ++seed )
00619 {
00620 cout << " -- Vertex[" << seed->first << "] with " << setw(12)
00621 << setprecision(3) << theWeights[t][seed->first];
00622 };
00623 cout << endl;
00624 }
00625
00626 void MultiVertexFitter::printWeights() const
00627 {
00628 cout << endl << "Weight table: " << endl << "=================" << endl;
00629 for ( set < TransientTrack >::const_iterator trk=thePrimaries.begin();
00630 trk!=thePrimaries.end() ; ++trk )
00631 {
00632 printWeights ( *trk );
00633 };
00634 for ( vector< TransientTrack >::const_iterator trk=theTracks.begin();
00635 trk!=theTracks.end() ; ++trk )
00636 {
00637 printWeights ( *trk );
00638 };
00639 }
00640
00641 void MultiVertexFitter::printSeeds() const
00642 {
00643 cout << endl << "Seed table: " << endl << "=====================" << endl;
00644
00645
00646
00647
00648
00649
00650
00651 }
00652
00653 void MultiVertexFitter::lostVertexClaimer()
00654 {
00655 if ( !theReviveBelow < 0. ) return;
00656
00657
00658
00659 bool has_revived = false;
00660
00661 for ( vector< pair < int, CachingVertex<5> > >::const_iterator i=theVertexStates.begin();
00662 i!=theVertexStates.end() ; ++i )
00663 {
00664 double totweight=0.;
00665 for ( vector< TransientTrack >::const_iterator trk=theTracks.begin();
00666 trk!=theTracks.end() ; ++trk )
00667 {
00668 totweight+=theWeights[*trk][i->first];
00669 };
00670
00671
00672
00673
00674
00675 if ( totweight < theReviveBelow && totweight > 0.0 )
00676 {
00677 cout << "[MultiVertexFitter] now trying to revive vertex"
00678 << " revive_below=" << theReviveBelow << endl;
00679 has_revived=true;
00680 for ( vector< TransientTrack >::const_iterator trk=theTracks.begin();
00681 trk!=theTracks.end() ; ++trk )
00682 {
00683 theWeights[*trk][i->first]/=totweight;
00684 };
00685 };
00686 };
00687 if ( has_revived && verbose() ) printWeights();
00688 }