CMS 3D CMS Logo

PhysicsTools::MVATrainer Class Reference

#include <PhysicsTools/MVATrainer/interface/MVATrainer.h>

List of all members.

Public Member Functions

TrainerMonitoring::Module * bookMonitor (const std::string &name)
void doneTraining (Calibration::MVAComputer *trainCalibration) const
Calibration::MVAComputer * getCalibration () const
const std::string & getName () const
Calibration::MVAComputer * getTrainCalibration () const
void loadState ()
 MVATrainer (const std::string &fileName)
void saveState ()
void setAutoSave (bool autoSave)
void setCleanup (bool cleanup)
void setCrossValidation (double split)
void setMonitoring (bool monitoring)
void setRandomSeed (UInt_t seed)
std::string trainFileName (const TrainProcessor *proc, const std::string &ext, const std::string &arg="") const
 ~MVATrainer ()

Static Public Attributes

static const AtomicId kTargetId
static const AtomicId kWeightId

Private Member Functions

void connectProcessors (Calibration::MVAComputer *calib, const std::vector< CalibratedProcessor > &procs, bool withTarget) const
SourceVariable * createVariable (Source *source, AtomicId name, Variable::Flags flags)
void fillInputVars (SourceVariableSet &vars, XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
void fillOutputVars (SourceVariableSet &vars, Source *source, XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
std::vector< AtomicId > findFinalProcessors () const
void findUntrainedComputers (std::vector< AtomicId > &compute, std::vector< AtomicId > &train) const
SourceVariable * getVariable (AtomicId source, AtomicId name) const
void makeProcessor (XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *elem, AtomicId id, const char *name)
Calibration::MVAComputer * makeTrainCalibration (const AtomicId *compute, const AtomicId *train) const

Private Attributes

double crossValidation
bool doAutoSave
bool doCleanup
bool doMonitoring
Source * input
std::auto_ptr< TrainerMonitoring > monitoring
std::string name
TrainProcessor * output
std::vector< AtomicId > processors
UInt_t randomSeed
std::map< AtomicId, Source * > sources
std::string trainFileMask
std::vector< SourceVariable * > variables
std::auto_ptr< XMLDocument > xml

Classes

struct  CalibratedProcessor


Detailed Description

Definition at line 27 of file MVATrainer.h.


Constructor & Destructor Documentation

PhysicsTools::MVATrainer::MVATrainer ( const std::string &  fileName  ) 

Definition at line 404 of file MVATrainer.cc.

References createVariable(), lat::endl(), Exception, fillInputVars(), fillOutputVars(), PhysicsTools::Variable::FLAG_NONE, PhysicsTools::Variable::FLAG_OPTIONAL, PhysicsTools::Source::getInputs(), input, PhysicsTools::SourceVariableSet::kTarget, kTargetId, PhysicsTools::SourceVariableSet::kWeight, kWeightId, makeProcessor(), name, prof2calltree::node, output, sources, state, trainFileMask, and xml.

// Constructs the trainer from the XML description in `fileName`.
//
// The root node must be named "MVATrainer". Its element children are consumed
// by a small state machine that expects, in this order: one "general" tag,
// one "input" tag, any number of "processor" tags and finally one "output"
// tag. Every violation of that order throws cms::Exception("MVATrainer").
00404                                                 :
00405         input(0), output(0), name("MVATrainer"),
00406         doAutoSave(true), doCleanup(false), doMonitoring(false),
00407         randomSeed(65539), crossValidation(0.0)
00408 {
00409         xml = std::auto_ptr<XMLDocument>(new XMLDocument(fileName));
00410 
00411         DOMNode *node = xml->getRootNode();
00412 
00413         if (std::strcmp(XMLSimpleStr(node->getNodeName()), "MVATrainer") != 0)
00414                 throw cms::Exception("MVATrainer")
00415                         << "Invalid XML root node." << std::endl;
00416 
              // Parser state: which kind of tag is expected next.
00417         enum State {
00418                 STATE_GENERAL,
00419                 STATE_FIRST,
00420                 STATE_MIDDLE,
00421                 STATE_LAST
00422         } state = STATE_GENERAL;
00423 
00424         for(node = node->getFirstChild();
00425             node; node = node->getNextSibling()) {
                      // Only element nodes carry configuration; everything else is skipped.
00426                 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00427                         continue;
00428 
                      // Local `name` (the tag name) shadows the member; the member is
                      // reached through this->name below.
00429                 std::string name = XMLSimpleStr(node->getNodeName());
00430                 DOMElement *elem = static_cast<DOMElement*>(node);
00431 
00432                 switch(state) {
00433                     case STATE_GENERAL: {
00434                         if (name != "general")
00435                                 throw cms::Exception("MVATrainer")
00436                                         << "Expected general config as first "
00437                                            "tag." << std::endl;
00438 
                              // Every element child of "general" must be an "option" tag.
00439                         for(DOMNode *subNode = elem->getFirstChild();
00440                             subNode; subNode = subNode->getNextSibling()) {
00441                                 if (subNode->getNodeType() !=
00442                                     DOMNode::ELEMENT_NODE)
00443                                         continue;
00444 
00445                                 if (std::strcmp(XMLSimpleStr(
00446                                         subNode->getNodeName()), "option") != 0)
00447                                         throw cms::Exception("MVATrainer")
00448                                                 << "Expected option tag."
00449                                                 << std::endl;
00450 
                                      // `elem` and `name` are reused for the option element and
                                      // its "name" attribute; neither is needed for the "general"
                                      // tag itself after this loop.
00451                                 elem = static_cast<DOMElement*>(subNode);
00452                                 name = XMLDocument::readAttribute<std::string>(
00453                                                                 elem, "name");
00454                                 std::string content = XMLSimpleStr(
00455                                                 elem->getTextContent());
00456 
                                      // Recognised options: "id" sets the trainer name,
                                      // "trainfiles" sets the training file mask.
00457                                 if (name == "id")
00458                                         this->name = content;
00459                                 else if (name == "trainfiles")
00460                                         trainFileMask = content;
00461                                 else
00462                                         throw cms::Exception("MVATrainer")
00463                                                 << "Unknown option \""
00464                                                 << name << "\"." << std::endl;
00465                         }
00466 
00467                         state = STATE_FIRST;
00468                     }   break;
00469                     case STATE_FIRST: {
00470                         if (name != "input")
00471                                 throw cms::Exception("MVATrainer")
00472                                         << "Expected input config as second "
00473                                            "tag." << std::endl;
00474 
00475                         AtomicId id = XMLDocument::readAttribute<std::string>(
00476                                                                 elem, "id");
                              // The input source gets the target and weight variables
                              // (weight flagged optional) before the variables declared
                              // in the XML are added by fillOutputVars().
00477                         input = new Source(id, true);
00478                         input->getOutputs().append(
00479                                 createVariable(input, kTargetId,
00480                                                Variable::FLAG_NONE),
00481                                 SourceVariableSet::kTarget);
00482                         input->getOutputs().append(
00483                                 createVariable(input, kWeightId,
00484                                                Variable::FLAG_OPTIONAL),
00485                                 SourceVariableSet::kWeight);
00486                         sources.insert(std::make_pair(id, input));
00487                         fillOutputVars(input->getOutputs(), input, elem);
00488 
00489                         state = STATE_MIDDLE;
00490                     }   break;
00491                     case STATE_MIDDLE: {
                              // An "output" tag ends the processor list; `continue` skips
                              // the processor handling below and moves to the next node.
00492                         if (name == "output") {
00493                                 AtomicId zero;
00494                                 output = new TrainProcessor("output",
00495                                                             &zero, this);
00496                                 fillInputVars(output->getInputs(), elem);
00497                                 state = STATE_LAST;
00498                                 continue;
00499                         } else if (name != "processor")
00500                                 throw cms::Exception("MVATrainer")
00501                                         << "Unexpected tag after input "
00502                                            "config." << std::endl;
00503 
00504                         AtomicId id = XMLDocument::readAttribute<std::string>(
00505                                                                 elem, "id");
                              // This `name` (the processor's "name" attribute) shadows the
                              // tag name declared at the top of the loop body.
00506                         std::string name =
00507                                 XMLDocument::readAttribute<std::string>(
00508                                         elem, "name");
00509 
00510                         makeProcessor(elem, id, name.c_str());
00511                     }   break;
00512                     case STATE_LAST:
                              // Nothing may follow the "output" tag.
00513                         throw cms::Exception("MVATrainer")
00514                                 << "Unexpected tag found after output."
00515                                 << std::endl;
00516                         break;
00517                 }
00518         }
00519 
              // STATE_FIRST here means no "input" tag was seen, STATE_MIDDLE means no
              // "output" tag was seen.
              // NOTE(review): a document with no element children at all leaves the
              // state at STATE_GENERAL, which is not rejected by these checks.
00520         if (state == STATE_FIRST)
00521                 throw cms::Exception("MVATrainer")
00522                         << "Expected input variable config." << std::endl;
00523         else if (state == STATE_MIDDLE)
00524                 throw cms::Exception("MVATrainer")
00525                         << "Expected output variable config." << std::endl;
00526 
              // Without a "trainfiles" option the mask defaults to "<name>_%s%s.%s".
00527         if (trainFileMask.empty())
00528                 trainFileMask = this->name + "_%s%s.%s";
00529 }

PhysicsTools::MVATrainer::~MVATrainer (  ) 

Definition at line 531 of file MVATrainer.cc.

References PhysicsTools::TrainProcessor::cleanup(), doCleanup, iter, monitoring, output, proc, sources, and variables.

// Destructor: writes out the monitoring object if one was created, then
// releases every object owned through raw pointers (all entries of `sources`,
// `output`, and all entries of `variables`).
00532 {
00533         if (monitoring.get())
00534                 monitoring->write();
00535 
00536         for(std::map<AtomicId, Source*>::const_iterator iter = sources.begin();
00537             iter != sources.end(); iter++) {
                      // Entries that are not TrainProcessors yield a null `proc`
                      // and are only deleted.
00538                 TrainProcessor *proc =
00539                                 dynamic_cast<TrainProcessor*>(iter->second);
00540 
                      // cleanup() is only invoked when cleanup was requested (doCleanup).
00541                 if (proc && doCleanup)
00542                         proc->cleanup();
00543 
00544                 delete iter->second;
00545         }
00546         delete output;
00547         std::for_each(variables.begin(), variables.end(),
00548                       deleter<SourceVariable>());
00549 }


Member Function Documentation

TrainerMonitoring::Module * PhysicsTools::MVATrainer::bookMonitor ( const std::string &  name  ) 

Definition at line 678 of file MVATrainer.cc.

References doMonitoring, aod_PYTHIA_cfg::fileName, monitoring, PhysicsTools::stdStringPrintf(), and trainFileMask.

Referenced by PhysicsTools::TrainProcessor::doTrainBegin().

// Books a monitoring module called `name`.
// Returns 0 when monitoring is disabled (doMonitoring is false); otherwise
// returns whatever monitoring->book(name) returns.
00679 {
00680         if (!doMonitoring)
00681                 return 0;
00682 
              // The TrainerMonitoring object is created on first use. Its file name is
              // trainFileMask formatted with the arguments "monitoring", "" and "root".
00683         if (!monitoring.get()) {
00684                 std::string fileName = 
00685                         stdStringPrintf(trainFileMask.c_str(),
00686                                         "monitoring", "", "root");
00687                 monitoring.reset(new TrainerMonitoring(fileName));
00688         }
00689 
00690         return monitoring->book(name);
00691 }

void PhysicsTools::MVATrainer::connectProcessors ( Calibration::MVAComputer *  calib,
const std::vector< CalibratedProcessor > &  procs,
bool  withTarget 
) const [private]

Definition at line 853 of file MVATrainer.cc.

References PhysicsTools::Calibration::MVAComputer::addProcessor(), convert(), Exception, PhysicsTools::SourceVariableSet::get(), PhysicsTools::Variable::getFlags(), PhysicsTools::Source::getInputs(), PhysicsTools::Variable::getName(), PhysicsTools::Source::getOutputs(), i, input, PhysicsTools::Calibration::MVAComputer::inputSet, iter, prof2calltree::last, python::multivaluedict::map(), PhysicsTools::Calibration::Variable::name, output, PhysicsTools::Calibration::MVAComputer::output, PhysicsTools::SourceVariableSet::size(), size, variables, and vars.

Referenced by getCalibration(), and makeTrainCalibration().

// Wires the input variables and the given calibrated processors into `calib`.
//
// Each SourceVariable is given a running index (`size` is the next free
// index). For every entry of `procs` the set of indices it reads is stored in
// its calibration object's inputVars, and the variables it produces receive
// the following indices. Finally calib->output is set to the index of the
// single output variable.
//
// Throws cms::Exception("MVATrainer") if a processor lists its inputs out of
// index order, or if `output` does not have exactly one input variable.
00856 {
00857         std::map<SourceVariable*, unsigned int> vars;
00858         unsigned int size = 0;
00859 
              // Non-null only when `calib` actually is an MVATrainerComputer.
00860         MVATrainerComputer *trainCalib =
00861                         dynamic_cast<MVATrainerComputer*>(calib);
00862 
00863         for(unsigned int i = 0;
00864             i < input->getOutputs().size(true); i++) {
                      // Without target, the first two entries are skipped.
                      // NOTE(review): presumably these are the target and weight
                      // variables the constructor creates first - confirm.
00865                 if (i < 2 && !withTarget)
00866                         continue;
00867 
00868                 SourceVariable *var = variables[i];
00869                 vars[var] = size++;
00870 
00871                 Calibration::Variable calibVar;
00872                 calibVar.name = (const char*)var->getName();
00873                 calib->inputSet.push_back(calibVar);
00874                 if (trainCalib)
00875                         trainCalib->addFlag(var->getFlags());
00876         }
00877 
00878         for(std::vector<CalibratedProcessor>::const_iterator iter =
00879                                 procs.begin(); iter != procs.end(); iter++) {
00880                 bool isInterceptor = dynamic_cast<BaseInterceptor*>(
00881                                                         iter->calib) != 0;
00882 
                      // One bit per variable index assigned so far.
00883                 BitSet inputSet(size);
00884 
00885                 unsigned int last = 0;
00886                 std::vector<SourceVariable*> inoutVars;
                      // An entry may carry a null processor; it then has no inputs.
00887                 if (iter->processor)
00888                         inoutVars = iter->processor->getInputs().get(
00889                                                                 isInterceptor);
00890                 for(std::vector<SourceVariable*>::const_iterator iter2 =
00891                         inoutVars.begin(); iter2 != inoutVars.end(); iter2++) {
00892                         std::map<SourceVariable*,
00893                                  unsigned int>::const_iterator pos =
00894                                                         vars.find(*iter2);
00895 
00896                         assert(pos != vars.end());
00897 
                              // Inputs must appear with non-decreasing variable index.
00898                         if (pos->second < last)
00899                                 throw cms::Exception("MVATrainer")
00900                                         << "Input variables not declared "
00901                                            "in order of appearance in \""
00902                                         << (const char*)iter->processor->getName()
00903                                         << "\"." << std::endl;
00904 
00905                         inputSet[last = pos->second] = true;
00906                 }
00907 
                      // Interceptors are only expected when the target is included.
00908                 assert(!isInterceptor || withTarget);
00909 
00910                 iter->calib->inputVars = Calibration::convert(inputSet);
00911 
                      // Provisional output index: the next free index. Overwritten at
                      // the end of the function if the output variable is found in `vars`.
00912                 calib->output = size;
00913 
                      // An interceptor takes up one index but is not passed to
                      // addProcessor() and contributes no output variables here.
00914                 if (isInterceptor) {
00915                         size++;
00916                         continue;
00917                 }
00918 
00919                 calib->addProcessor(iter->calib);
00920 
                      // The inner `iter` below shadows the outer loop iterator.
00921                 inoutVars = iter->processor->getOutputs().get();
00922                 for(std::vector<SourceVariable*>::const_iterator iter =
00923                         inoutVars.begin(); iter != inoutVars.end(); iter++) {
00924 
00925                         vars[*iter] = size++;
00926                 }
00927         }
00928 
00929         if (output->getInputs().size() != 1)
00930                 throw cms::Exception("MVATrainer")
00931                         << "Exactly one output variable has to be specified."
00932                         << std::endl;
00933 
              // If the output variable was never indexed, the provisional value set
              // inside the loop above is left in place.
00934         SourceVariable *outVar = output->getInputs().get()[0];
00935         std::map<SourceVariable*, unsigned int>::const_iterator pos =
00936                                                         vars.find(outVar);
00937         if (pos != vars.end())
00938                 calib->output = pos->second;
00939 }

SourceVariable * PhysicsTools::MVATrainer::createVariable ( Source *  source,
AtomicId  name,
Variable::Flags  flags 
) [private]

Definition at line 702 of file MVATrainer.cc.

References PhysicsTools::Source::getName(), getVariable(), and variables.

Referenced by fillOutputVars(), and MVATrainer().

// Creates a new SourceVariable `name` with `flags` for `source` and records it
// in `variables`.
// Returns 0 (and creates nothing) when getVariable() already finds a variable
// of that name for the source; callers treat 0 as "defined twice".
00704 {
00705         SourceVariable *var = getVariable(source->getName(), name);
00706         if (var)
00707                 return 0;
00708 
00709         var = new SourceVariable(source, name, flags);
00710         variables.push_back(var);
00711         return var;
00712 }

void PhysicsTools::MVATrainer::doneTraining ( Calibration::MVAComputer *  trainCalibration  )  const

Definition at line 1070 of file MVATrainer.cc.

References calib, and Exception.

// Signals the end of a training pass on `trainCalibration`.
// The argument must be an MVATrainerComputer; any other object (or a null
// pointer) makes the cast fail and throws cms::Exception("MVATrainer").
01071 {
01072         MVATrainerComputer *calib =
01073                         dynamic_cast<MVATrainerComputer*>(trainCalibration);
01074 
01075         if (!calib)
01076                 throw cms::Exception("MVATrainer")
01077                         << "Invalid training calibration passed to "
01078                            "doneTraining()" << std::endl;
01079 
01080         calib->done();
01081 }

void PhysicsTools::MVATrainer::fillInputVars ( SourceVariableSet &  vars,
XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *  xml 
) [private]

Definition at line 714 of file MVATrainer.cc.

References PhysicsTools::SourceVariableSet::append(), Exception, find(), PhysicsTools::Source::getOutput(), getVariable(), input, iter, PhysicsTools::SourceVariableSet::kRegular, PhysicsTools::SourceVariableSet::kTarget, kTargetId, PhysicsTools::SourceVariableSet::kWeight, kWeightId, prof2calltree::node, source, tmp, and variables.

Referenced by MVATrainer().

// Fills `vars` with the input variables listed as "var" children of `xml`.
//
// Each "var" element names an existing variable by its "source" and "name"
// attributes and may mark it as target and/or weight. If no target or weight
// is marked, the corresponding variable of the input source is added.
//
// Throws cms::Exception("MVATrainer") for a child that is not "var", an
// unknown variable, a second target or weight, or a variable that
// vars.append() reports as already present.
00716 {
00717         std::vector<SourceVariable*> tmp;
00718         SourceVariable *target = 0;
00719         SourceVariable *weight = 0;
00720 
00721         for(DOMNode *node = xml->getFirstChild(); node;
00722             node = node->getNextSibling()) {
00723                 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00724                         continue;
00725 
00726                 if (std::strcmp(XMLSimpleStr(node->getNodeName()), "var") != 0)
00727                         throw cms::Exception("MVATrainer")
00728                                 << "Invalid input variable node." << std::endl;
00729 
00730                 DOMElement *elem = static_cast<DOMElement*>(node);
00731 
00732                 AtomicId source = XMLDocument::readAttribute<std::string>(
00733                                                         elem, "source");
00734                 AtomicId name = XMLDocument::readAttribute<std::string>(
00735                                                         elem, "name");
00736 
00737                 SourceVariable *var = getVariable(source, name);
00738                 if (!var)
00739                         throw cms::Exception("MVATrainer")
00740                                 << "Input variable " << (const char*)source
00741                                 << ":" << (const char*)name
00742                                 << " not found." << std::endl;
00743 
                      // At most one variable may be marked as target, and at most one
                      // as weight.
00744                 if (XMLDocument::readAttribute<bool>(elem, "target", false)) {
00745                         if (target)
00746                                 throw cms::Exception("MVATrainer")
00747                                         << "Target variable defined twice"
00748                                         << std::endl;
00749                         target = var;
00750                 }
00751                 if (XMLDocument::readAttribute<bool>(elem, "weight", false)) {
00752                         if (weight)
00753                                 throw cms::Exception("MVATrainer")
00754                                         << "Weight variable defined twice"
00755                                         << std::endl;
00756                         weight = var;
00757                 }
00758 
                      // `tmp` keeps the variables in the order they appear in the XML.
00759                 tmp.push_back(var);
00760         }
00761 
              // No explicit weight: insert the input source's weight variable. The
              // bool expression is used as an offset: position 1 if the explicit
              // target is the input source's own target variable, position 0 otherwise.
00762         if (!weight) {
00763                 weight = input->getOutput(kWeightId);
00764                 assert(weight);
00765                 tmp.insert(tmp.begin() +
00766                                 (target == input->getOutput(kTargetId)),
00767                            1, weight);
00768         }
              // No explicit target: put the input source's target variable at the front.
00769         if (!target) {
00770                 target = input->getOutput(kTargetId);
00771                 assert(target);
00772                 tmp.insert(tmp.begin(), 1, target);
00773         }
00774 
              // Walk all known variables in `variables` order and append those present
              // in `tmp`, passing their position in `tmp` as the third argument.
00775         unsigned int n = 0;
00776         for(std::vector<SourceVariable*>::const_iterator iter = variables.begin();
00777             iter != variables.end(); iter++) {
00778                 std::vector<SourceVariable*>::const_iterator pos =
00779                         std::find(tmp.begin(), tmp.end(), *iter);
00780                 if (pos == tmp.end())
00781                         continue;
00782 
00783                 SourceVariableSet::Magic magic;
00784                 if (*iter == target)
00785                         magic = SourceVariableSet::kTarget;
00786                 else if (*iter == weight)
00787                         magic = SourceVariableSet::kWeight;
00788                 else
00789                         magic = SourceVariableSet::kRegular;
00790 
                      // A true return value from append() is reported as a duplicate.
00791                 if (vars.append(*iter, magic, pos - tmp.begin())) {
00792                         AtomicId source = (*iter)->getSource()->getName();
00793                         AtomicId name = (*iter)->getName();
00794                         throw cms::Exception("MVATrainer")
00795                                 << "Input variable " << (const char*)source
00796                                 << ":" << (const char*)name
00797                                 << " defined twice." << std::endl;
00798                 }
00799 
00800                 n++;
00801         }
00802 
              // Every entry of `tmp` must have been appended exactly once.
              // NOTE(review): a variable listed twice in the XML sits twice in `tmp`
              // but is visited once above, so it would seem to trip this assert
              // rather than the "defined twice" exception - confirm.
00803         assert(tmp.size() == n);
00804 }

void PhysicsTools::MVATrainer::fillOutputVars ( SourceVariableSet &  vars,
Source *  source,
XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *  xml 
) [private]

Definition at line 806 of file MVATrainer.cc.

References PhysicsTools::SourceVariableSet::append(), createVariable(), Exception, PhysicsTools::Variable::FLAG_MULTIPLE, PhysicsTools::Variable::FLAG_NONE, PhysicsTools::Variable::FLAG_OPTIONAL, PhysicsTools::Source::getName(), PhysicsTools::isMagic(), and prof2calltree::node.

Referenced by MVATrainer().

// Creates the output variables of `source` from the "var" children of `xml`
// and appends them to `vars`.
//
// Throws cms::Exception("MVATrainer") for a child that is not "var", a missing
// name, a name for which isMagic() is true, or a variable that already exists
// (createVariable() returns 0, or vars.append() returns true).
00808 {
00809         for(DOMNode *node = xml->getFirstChild(); node;
00810             node = node->getNextSibling()) {
00811                 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00812                         continue;
00813 
00814                 if (std::strcmp(XMLSimpleStr(node->getNodeName()), "var") != 0)
00815                         throw cms::Exception("MVATrainer")
00816                                 << "Invalid output variable node."
00817                                 << std::endl;
00818 
00819                 DOMElement *elem = static_cast<DOMElement*>(node);
00820 
00821                 AtomicId name = XMLDocument::readAttribute<std::string>(
00822                                                         elem, "name");
00823                 if (!name)
00824                         throw cms::Exception("MVATrainer")
00825                                 << "Output variable tag missing name."
00826                                 << std::endl;
00827                 if (isMagic(name))
00828                         throw cms::Exception("MVATrainer")
00829                                 << "Cannot use magic variable names in output."
00830                                 << std::endl;
00831 
00832                 Variable::Flags flags = Variable::FLAG_NONE;
00833 
                      // NOTE(review): the third readAttribute argument is presumably the
                      // value used when the attribute is absent; if so, both flags are
                      // set unless explicitly disabled in the XML - confirm.
00834                 if (XMLDocument::readAttribute<bool>(elem, "optional", true))
00835                         flags = (PhysicsTools::Variable::Flags)
00836                                 (flags | Variable::FLAG_OPTIONAL);
00837 
00838                 if (XMLDocument::readAttribute<bool>(elem, "multiple", true))
00839                         flags = (PhysicsTools::Variable::Flags)
00840                                 (flags | Variable::FLAG_MULTIPLE);
00841 
                      // createVariable() returns 0 if the source already has a variable
                      // of this name.
00842                 SourceVariable *var = createVariable(source, name, flags);
00843                 if (!var || vars.append(var))
00844                         throw cms::Exception("MVATrainer")
00845                                 << "Output variable "
00846                                 << (const char*)source->getName()
00847                                 << ":" << (const char*)name
00848                                 << " defined twice." << std::endl;
00849         }
00850 }

std::vector< AtomicId > PhysicsTools::MVATrainer::findFinalProcessors (  )  const [private]

Definition at line 1083 of file MVATrainer.cc.

References PhysicsTools::SourceVariableSet::get(), PhysicsTools::Source::inputs, iter, output, source, and sources.

Referenced by getCalibration().

// Returns the ids of those processors that the output depends on, directly
// or through other sources, in the order they appear in `processors`.
01084 {
              // Work list of sources whose inputs still have to be followed,
              // starting from the output.
01085         std::set<Source*> toCheck;
01086         toCheck.insert(output);
01087 
              // `done` collects every source reached through an input variable.
              // `output` itself is only added if some variable refers back to it.
01088         std::set<Source*> done;
01089         while(!toCheck.empty()) {
01090                 Source *source = *toCheck.begin();
01091                 toCheck.erase(toCheck.begin());
01092 
01093                 std::vector<SourceVariable*> inputs = source->inputs.get();
01094                 for(std::vector<SourceVariable*>::const_iterator iter =
01095                                 inputs.begin(); iter != inputs.end(); ++iter) {
                              // `source` is reused for the producer of this input variable.
                              // A source is queued only the first time it is seen.
01096                         source = (*iter)->getSource();
01097                         if (done.insert(source).second)
01098                                 toCheck.insert(source);
01099                 }
01100         }
01101 
              // Keep the ids from `processors` whose source was reached above.
01102         std::vector<AtomicId> result;
01103         for(std::vector<AtomicId>::const_iterator iter = processors.begin();
01104             iter != processors.end(); ++iter) {
01105                 std::map<AtomicId, Source*>::const_iterator pos =
01106                                                         sources.find(*iter);
01107                 if (pos != sources.end() && done.count(pos->second))
01108                         result.push_back(*iter);
01109         }
01110 
01111         return result;
01112 }

void PhysicsTools::MVATrainer::findUntrainedComputers ( std::vector< AtomicId > &  compute,
std::vector< AtomicId > &  train 
) const [private]

Definition at line 1163 of file MVATrainer.cc.

References doMonitoring, PhysicsTools::SourceVariableSet::get(), PhysicsTools::Source::getInputs(), input, PhysicsTools::Source::isTrained(), iter, PhysicsTools::kOutputId, output, proc, and sources.

Referenced by getTrainCalibration().

// Splits the processors whose inputs all come from trained sources into two
// lists: `compute` receives those that are already trained, `train` those
// that are not. Both vectors are cleared first. Processors with an input from
// a source not (yet) in the trained set are left out of both lists.
01165 {
01166         compute.clear();
01167         train.clear();
01168 
              // The input source counts as trained from the start.
01169         std::set<Source*> trainedSources;
01170         trainedSources.insert(input);
01171 
01172         for(std::vector<AtomicId>::const_iterator iter =
01173                 processors.begin(); iter != processors.end(); iter++) {
01174                 std::map<AtomicId, Source*>::const_iterator pos =
01175                                                         sources.find(*iter);
01176                 assert(pos != sources.end());
01177                 TrainProcessor *proc =
01178                                 dynamic_cast<TrainProcessor*>(pos->second);
01179                 assert(proc);
01180 
                      // A processor qualifies only if every input variable is produced
                      // by a source already in `trainedSources`.
01181                 bool trainedDeps = true;
01182                 std::vector<SourceVariable*> inputVars =
01183                                         proc->getInputs().get();
01184                 for(std::vector<SourceVariable*>::const_iterator iter2 =
01185                         inputVars.begin(); iter2 != inputVars.end(); iter2++) {
01186                         if (trainedSources.find((*iter2)->getSource())
01187                             == trainedSources.end()) {
01188                                 trainedDeps = false;
01189                                 break;
01190                         }
01191                 }
01192 
01193                 if (!trainedDeps)
01194                         continue;
01195 
                      // Only trained processors are added to `trainedSources`, so
                      // later processors can depend on them but not on untrained ones.
01196                 if (proc->isTrained()) {
01197                         trainedSources.insert(proc);
01198                         compute.push_back(proc->getName());
01199                 } else
01200                         train.push_back(proc->getName());
01201         }
01202 
              // With monitoring enabled, the output is also scheduled for training
              // (as kOutputId) once the source of its first input is trained.
              // NOTE(review): get()[0] assumes the output has at least one input.
01203         if (doMonitoring && !output->isTrained() &&
01204             trainedSources.find(output->getInputs().get()[0]->getSource())
01205                                                 != trainedSources.end())
01206                 train.push_back(kOutputId);
01207 }

Calibration::MVAComputer * PhysicsTools::MVATrainer::getCalibration (  )  const

Definition at line 1114 of file MVATrainer.cc.

References begin, calib, connectProcessors(), end, find(), findFinalProcessors(), PhysicsTools::TrainProcessor::getCalibration(), int, iter, proc, source, and sources.

Referenced by PhysicsTools::MVATrainerContainerLooperImpl< Record_t >::produce(), and PhysicsTools::TreeTrainer::train().

// Builds the final calibration from all processors the output depends on.
// Returns 0 if any of those processors is not trained yet; otherwise returns
// a newly allocated Calibration::MVAComputer released from the auto_ptr, so
// ownership passes to the caller.
01115 {
              // This local shadows the member `processors`; the member is reached
              // through this->processors below.
01116         std::vector<CalibratedProcessor> processors;
01117 
01118         std::auto_ptr<Calibration::MVAComputer> calib(
01119                                                 new Calibration::MVAComputer);
01120 
01121         std::vector<AtomicId> used = findFinalProcessors();
01122         for(std::vector<AtomicId>::const_iterator iter = used.begin();
01123             iter != used.end(); iter++) {
01124                 std::map<AtomicId, Source*>::const_iterator pos =
01125                                                         sources.find(*iter);
01126                 assert(pos != sources.end());
01127                 TrainProcessor *source =
01128                                 dynamic_cast<TrainProcessor*>(pos->second);
01129                 assert(source);
                      // `calib` is freed by the auto_ptr on this early return.
01130                 if (!source->isTrained())
01131                         return 0;
01132 
                      // Processors without a calibration object are left out.
01133                 Calibration::VarProcessor *proc = source->getCalibration();
01134                 if (!proc)
01135                         continue;
01136 
                      // For a ProcForeach, nProcs is recomputed: [begin, end) is the run
                      // of nProcs ids following this processor in this->processors, and
                      // the new nProcs counts how many of those ids occur in `used`
                      // from the current position onward.
01137                 Calibration::ProcForeach *foreach =
01138                                 dynamic_cast<Calibration::ProcForeach*>(proc);
01139                 if (foreach) {
01140                         std::vector<AtomicId>::const_iterator begin =
01141                                 std::find(this->processors.begin(),
01142                                           this->processors.end(), *iter);
01143                         assert(this->processors.end() - begin >
01144                                (int)(foreach->nProcs + 1));
01145                         ++begin;
01146                         std::vector<AtomicId>::const_iterator end =
01147                                                 begin + foreach->nProcs;
01148                         foreach->nProcs = 0;
01149                         for(std::vector<AtomicId>::const_iterator iter2 =
01150                                         iter; iter2 != used.end(); ++iter2)
01151                                 if (std::find(begin, end, *iter2) != end)
01152                                         foreach->nProcs++;
01153                 }
01154 
01155                 processors.push_back(CalibratedProcessor(source, proc));
01156         }
01157 
              // Connect without target (withTarget == false).
01158         connectProcessors(calib.get(), processors, false);
01159 
01160         return calib.release();
01161 }

const std::string& PhysicsTools::MVATrainer::getName ( void   )  const [inline]

Definition at line 52 of file MVATrainer.h.

References name.

00052 { return name; } // Accessor: hands back the `name` data member by const reference.

Calibration::MVAComputer * PhysicsTools::MVATrainer::getTrainCalibration (  )  const

Definition at line 1209 of file MVATrainer.cc.

References bookConverter::compute(), findUntrainedComputers(), and makeTrainCalibration().

Referenced by PhysicsTools::TreeTrainer::iteration().

01210 {
              // Builds the calibration object for the next training pass, or
              // returns 0 when findUntrainedComputers() reports nothing to train.
01211         std::vector<AtomicId> compute, train;
01212         findUntrainedComputers(compute, train);
01213 
01214         if (train.empty())
01215                 return 0;
01216 
              // makeTrainCalibration() walks both arrays until it hits an element
              // that evaluates to false, so a 0 entry is appended as terminator.
01217         compute.push_back(0);
01218         train.push_back(0);
01219 
              // Pass raw pointers to the first elements; the vectors stay alive
              // for the duration of the call.
01220         return makeTrainCalibration(&compute.front(), &train.front());
01221 }

SourceVariable * PhysicsTools::MVATrainer::getVariable ( AtomicId  source,
AtomicId  name 
) const [private]

Definition at line 693 of file MVATrainer.cc.

References sources.

Referenced by createVariable(), and fillInputVars().

00694 {
              // Looks up the Source registered under `source` and asks it for
              // its output variable called `name`.
              // Returns 0 when no such source is registered; otherwise returns
              // whatever Source::getOutput(name) yields.
00695         std::map<AtomicId, Source*>::const_iterator pos = sources.find(source);
00696         if (pos == sources.end())
00697                 return 0;
00698 
00699         return pos->second->getOutput(name);
00700 }

void PhysicsTools::MVATrainer::loadState (  ) 

Definition at line 551 of file MVATrainer.cc.

References iter, source, and sources.

00552 {
              // Calls load() on every processor listed in this->processors and
              // logs an info message for each one whose load() returns true.
00553         for(std::vector<AtomicId>::const_iterator iter =
00554                                                 this->processors.begin();
00555             iter != this->processors.end(); iter++) {
                      // Every processor id is expected to be registered in
                      // `sources` as a TrainProcessor; both expectations are
                      // only checked through assert().
00556                 std::map<AtomicId, Source*>::const_iterator pos =
00557                                                         sources.find(*iter);
00558                 assert(pos != sources.end());
00559                 TrainProcessor *source =
00560                                 dynamic_cast<TrainProcessor*>(pos->second);
00561                 assert(source);
00562 
00563                 if (source->load())
00564                         edm::LogInfo("MVATrainer")
00565                                 << source->getId() << " configuration for \""
00566                                 << (const char*)source->getName()
00567                                 << "\" loaded from file.";
00568         }
00569 }

void PhysicsTools::MVATrainer::makeProcessor ( XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *  elem,
AtomicId  id,
const char *  name 
) [private]

Referenced by MVATrainer().

Calibration::MVAComputer * PhysicsTools::MVATrainer::makeTrainCalibration ( const AtomicId *  compute,
const AtomicId *  train 
) const [private]

Definition at line 942 of file MVATrainer.cc.

References calib, connectProcessors(), crossValidation, doAutoSave, find(), PhysicsTools::TrainProcessor::getCalibration(), i, int, iter, PhysicsTools::kOutputId, TtSemiLepJetCombMVATrainTreeSaver_Muons_cff::looper, PhysicsTools::Calibration::ProcForeach::nProcs, output, proc, randomSeed, source, and sources.

Referenced by getTrainCalibration().

00944 {
              // Assembles the MVATrainerComputer used for one training pass.
              //   compute - ids of already trained processors whose calibrations
              //             are chained in
              //   train   - ids of processors to be trained; each gets wrapped in
              //             a TrainInterceptor
              // Both arrays are read until an element evaluates to false.
              // The result is released from its auto_ptr, so ownership passes
              // to the caller.
00945         std::map<AtomicId, TrainInterceptor*> interceptors;
00946         std::vector<MVATrainerComputer::Interceptor> baseInterceptors;
              // Local list; shadows the member, which is reached below via
              // this->processors.
00947         std::vector<CalibratedProcessor> processors;
00948 
              // An InitInterceptor always occupies slot 0 of the processor chain.
00949         BaseInterceptor *interceptor = new InitInterceptor;
00950         baseInterceptors.push_back(std::make_pair(0, interceptor));
00951         processors.push_back(CalibratedProcessor(0, interceptor));
00952 
              // Create one TrainInterceptor per processor that is to be trained.
              // kOutputId refers to the dedicated `output` processor rather than
              // an entry of `sources`.
00953         for(const AtomicId *iter = train; *iter; iter++) {
00954                 TrainProcessor *source;
00955                 if (*iter == kOutputId)
00956                         source = output;
00957                 else {
00958                         std::map<AtomicId, Source*>::const_iterator pos =
00959                                                         sources.find(*iter);
00960                         assert(pos != sources.end());
00961                         source = dynamic_cast<TrainProcessor*>(pos->second);
00962                 }
00963                 assert(source);
00964 
00965                 interceptors[*iter] = new TrainInterceptor(source);
00966         }
00967 
              // Collects the calibration objects obtained from getCalibration()
              // below (presumably to free them when this scope ends - confirm
              // against auto_cleaner).
00968         auto_cleaner<Calibration::VarProcessor> autoClean;
00969 
              // Ids that were already emitted as members of a ProcForeach group;
              // erase() returns non-zero for those, so they are skipped once.
00970         std::set<AtomicId> done;
00971         for(const AtomicId *iter = compute; *iter; iter++) {
00972                 if (done.erase(*iter))
00973                         continue;
00974 
00975                 std::map<AtomicId, Source*>::const_iterator pos =
00976                                                         sources.find(*iter);
00977                 assert(pos != sources.end());
00978                 TrainProcessor *source =
00979                                 dynamic_cast<TrainProcessor*>(pos->second);
00980                 assert(source);
00981                 assert(source->isTrained());
00982 
                      // Processors without a calibration object contribute nothing.
00983                 Calibration::VarProcessor *proc = source->getCalibration();
00984                 if (!proc)
00985                         continue;
00986 
00987                 autoClean.add(proc);
00988                 processors.push_back(CalibratedProcessor(source, proc));
00989 
                      // A ProcForeach covers the nProcs entries that follow it in
                      // this->processors. Emit those that take part in this pass
                      // right behind it and recount how many actually remain.
00990                 Calibration::ProcForeach *looper =
00991                                 dynamic_cast<Calibration::ProcForeach*>(proc);
00992                 if (looper) {
00993                         std::vector<AtomicId>::const_iterator pos2 =
00994                                 std::find(this->processors.begin(),
00995                                           this->processors.end(), *iter);
00996                         assert(pos2 != this->processors.end());
00997                         ++pos2;
00998                         unsigned int n = 0;
00999                         for(int i = 0; i < (int)looper->nProcs; ++i, ++pos2) {
01000                                 assert(pos2 != this->processors.end());
01001 
                                      // Linear search for *pos2 in the compute
                                      // list; iter2 ends on the terminator when
                                      // the id is not part of it.
01002                                 const AtomicId *iter2 = compute;
01003                                 while(*iter2) {
01004                                         if (*iter2 == *pos2)
01005                                                 break;
01006                                         iter2++;
01007                                 }
01008 
01009                                 if (*iter2) {
                                              // Counted even if getCalibration()
                                              // returns null further down.
01010                                         n++;
01011                                         done.insert(*iter2);
01012                                         pos = sources.find(*iter2);
01013                                         assert(pos != sources.end());
                                              // Shadows the outer `source`.
01014                                         TrainProcessor *source =
01015                                                 dynamic_cast<TrainProcessor*>(
01016                                                                 pos->second);
01017                                         assert(source);
01018                                         assert(source->isTrained());
01019 
01020                                         proc = source->getCalibration();
01021                                         if (proc) {
01022                                                 autoClean.add(proc);
01023                                                 processors.push_back(
01024                                                         CalibratedProcessor(
01025                                                                 source, proc));
01026                                         }
01027                                 }
01028 
                                      // A member that is to be trained gets its
                                      // interceptor placed inside the group; it
                                      // is removed from the map so the trailing
                                      // loop does not append it a second time.
01029                                 std::map<AtomicId, TrainInterceptor*>::iterator
01030                                                 pos3 = interceptors.find(*pos2);
01031                                 if (pos3 != interceptors.end()) {
01032                                         n++;
01033                                         baseInterceptors.push_back(
01034                                                 std::make_pair(processors.size(),
01035                                                                pos3->second));
01036                                         processors.push_back(
01037                                                 CalibratedProcessor(
01038                                                         pos3->second->getProcessor(),
01039                                                         pos3->second));
01040                                         interceptors.erase(pos3);
01041                                 }
01042                         }
01043 
01044                         looper->nProcs = n;
                              // With no members left, the looper pushed above is
                              // the last element of `processors` and is dropped.
                              // NOTE(review): n == 0 also means no interceptor was
                              // registered for this looper, so the pop_back() on
                              // baseInterceptors removes an earlier entry (possibly
                              // the InitInterceptor) - looks suspicious, confirm.
01045                         if (!n) {
01046                                 baseInterceptors.pop_back();
01047                                 processors.pop_back();
01048                         }
01049                 }
01050         }
01051 
              // Interceptors not consumed by a ProcForeach group go to the end
              // of the chain, in map order.
01052         for(std::map<AtomicId, TrainInterceptor*>::const_iterator iter =
01053                 interceptors.begin(); iter != interceptors.end(); ++iter) {
01054 
01055                 TrainProcessor *proc = iter->second->getProcessor();
01056                 baseInterceptors.push_back(std::make_pair(processors.size(),
01057                                                           iter->second));
01058                 processors.push_back(CalibratedProcessor(proc, iter->second));
01059         }
01060 
01061         std::auto_ptr<Calibration::MVAComputer> calib(
01062                 new MVATrainerComputer(baseInterceptors, doAutoSave,
01063                                        randomSeed, crossValidation));
01064 
              // Third argument is `withTarget` in connectProcessors()'s signature.
01065         connectProcessors(calib.get(), processors, true);
01066 
01067         return calib.release();
01068 }

void PhysicsTools::MVATrainer::saveState (  ) 

Definition at line 571 of file MVATrainer.cc.

References doCleanup, iter, source, and sources.

00572 {
              // Calls save() on every processor in this->processors that reports
              // isTrained().
              // doCleanup is switched off first; it is also referenced by
              // ~MVATrainer() (presumably so that the saved state is not removed
              // on destruction - confirm against the destructor).
00573         doCleanup = false;
00574 
00575         for(std::vector<AtomicId>::const_iterator iter =
00576                                                 this->processors.begin();
00577             iter != this->processors.end(); iter++) {
                      // Every processor id is expected to be registered in
                      // `sources` as a TrainProcessor; only assert() checks this.
00578                 std::map<AtomicId, Source*>::const_iterator pos =
00579                                                         sources.find(*iter);
00580                 assert(pos != sources.end());
00581                 TrainProcessor *source =
00582                                 dynamic_cast<TrainProcessor*>(pos->second);
00583                 assert(source);
00584 
00585                 if (source->isTrained())
00586                         source->save();
00587         }
00588 }

void PhysicsTools::MVATrainer::setAutoSave ( bool  autoSave  )  [inline]

Definition at line 32 of file MVATrainer.h.

References doAutoSave.

00032 { doAutoSave = autoSave; } // Stored flag is handed to MVATrainerComputer in makeTrainCalibration().

void PhysicsTools::MVATrainer::setCleanup ( bool  cleanup  )  [inline]

Definition at line 33 of file MVATrainer.h.

References doCleanup.

00033 { doCleanup = cleanup; } // Stores the flag; note that saveState() resets it to false.

void PhysicsTools::MVATrainer::setCrossValidation ( double  split  )  [inline]

Definition at line 36 of file MVATrainer.h.

References crossValidation.

Referenced by PhysicsTools::TreeTrainer::train().

00036 { crossValidation = split; } // Stored value is handed to MVATrainerComputer in makeTrainCalibration().

void PhysicsTools::MVATrainer::setMonitoring ( bool  monitoring  )  [inline]

Definition at line 34 of file MVATrainer.h.

References doMonitoring.

Referenced by PhysicsTools::TreeTrainer::train().

00034 { doMonitoring = monitoring; } // The parameter shadows the `monitoring` data member; only doMonitoring is assigned.

void PhysicsTools::MVATrainer::setRandomSeed ( UInt_t  seed  )  [inline]

Definition at line 35 of file MVATrainer.h.

References randomSeed.

00035 { randomSeed = seed; } // Stored seed is handed to MVATrainerComputer in makeTrainCalibration().

std::string PhysicsTools::MVATrainer::trainFileName ( const TrainProcessor *  proc,
const std::string &  ext,
const std::string &  arg = "" 
) const

Definition at line 668 of file MVATrainer.cc.

References PhysicsTools::Source::getName(), PhysicsTools::stdStringPrintf(), and trainFileMask.

00671 {
              // Formats a file name from the trainFileMask pattern.
              // The three values substituted into the mask are, in order: the
              // processor's name, the optional argument and the extension.
              // A non-empty `arg` is prefixed with '_'; an empty one adds nothing.
00672         std::string arg_ = arg.size() > 0 ? ("_" + arg) : "";
00673         return stdStringPrintf(trainFileMask.c_str(),
00674                                (const char*)proc->getName(),
00675                                arg_.c_str(), ext.c_str());
00676 }


Member Data Documentation

double PhysicsTools::MVATrainer::crossValidation [private]

Definition at line 114 of file MVATrainer.h.

Referenced by makeTrainCalibration(), and setCrossValidation().

bool PhysicsTools::MVATrainer::doAutoSave [private]

Definition at line 109 of file MVATrainer.h.

Referenced by makeTrainCalibration(), and setAutoSave().

bool PhysicsTools::MVATrainer::doCleanup [private]

Definition at line 110 of file MVATrainer.h.

Referenced by saveState(), setCleanup(), and ~MVATrainer().

bool PhysicsTools::MVATrainer::doMonitoring [private]

Definition at line 111 of file MVATrainer.h.

Referenced by bookMonitor(), findUntrainedComputers(), and setMonitoring().

Source* PhysicsTools::MVATrainer::input [private]

Definition at line 102 of file MVATrainer.h.

Referenced by connectProcessors(), fillInputVars(), findUntrainedComputers(), and MVATrainer().

const AtomicId PhysicsTools::MVATrainer::kTargetId [static]

Definition at line 58 of file MVATrainer.h.

Referenced by PhysicsTools::TreeTrainer::addTree(), TtSemiLepJetCombMVATrainer::analyze(), evaluateTtSemiLepSignalSel(), fillInputVars(), PhysicsTools::isMagic(), and MVATrainer().

const AtomicId PhysicsTools::MVATrainer::kWeightId [static]

Definition at line 59 of file MVATrainer.h.

Referenced by PhysicsTools::TreeTrainer::addTree(), evaluateTtSemiLepSignalSel(), fillInputVars(), PhysicsTools::isMagic(), and MVATrainer().

std::auto_ptr<TrainerMonitoring> PhysicsTools::MVATrainer::monitoring [private]

Definition at line 105 of file MVATrainer.h.

Referenced by bookMonitor(), and ~MVATrainer().

std::string PhysicsTools::MVATrainer::name [private]

Definition at line 108 of file MVATrainer.h.

Referenced by getName(), and MVATrainer().

TrainProcessor* PhysicsTools::MVATrainer::output [private]

Definition at line 103 of file MVATrainer.h.

Referenced by connectProcessors(), findFinalProcessors(), findUntrainedComputers(), makeTrainCalibration(), MVATrainer(), and ~MVATrainer().

std::vector<AtomicId> PhysicsTools::MVATrainer::processors [private]

Definition at line 101 of file MVATrainer.h.

UInt_t PhysicsTools::MVATrainer::randomSeed [private]

Definition at line 113 of file MVATrainer.h.

Referenced by makeTrainCalibration(), and setRandomSeed().

std::map<AtomicId, Source*> PhysicsTools::MVATrainer::sources [private]

Definition at line 99 of file MVATrainer.h.

Referenced by findFinalProcessors(), findUntrainedComputers(), getCalibration(), getVariable(), loadState(), makeTrainCalibration(), MVATrainer(), saveState(), and ~MVATrainer().

std::string PhysicsTools::MVATrainer::trainFileMask [private]

Definition at line 107 of file MVATrainer.h.

Referenced by bookMonitor(), MVATrainer(), and trainFileName().

std::vector<SourceVariable*> PhysicsTools::MVATrainer::variables [private]

Definition at line 100 of file MVATrainer.h.

Referenced by connectProcessors(), createVariable(), fillInputVars(), and ~MVATrainer().

std::auto_ptr<XMLDocument> PhysicsTools::MVATrainer::xml [private]

Definition at line 106 of file MVATrainer.h.

Referenced by MVATrainer().


The documentation for this class was generated from the following files:
- MVATrainer.h
- MVATrainer.cc
Generated on Tue Jun 9 18:50:11 2009 for CMSSW by  doxygen 1.5.4