CMS 3D CMS Logo

Classes | Public Member Functions | Static Public Attributes | Private Member Functions | Private Attributes

PhysicsTools::MVATrainer Class Reference

#include <MVATrainer.h>

List of all members.

Classes

struct  CalibratedProcessor

Public Member Functions

TrainerMonitoring::Module * bookMonitor (const std::string &name)
void doneTraining (Calibration::MVAComputer *trainCalibration) const
Calibration::MVAComputer * getCalibration () const
const std::string & getName () const
Calibration::MVAComputer * getTrainCalibration () const
void loadState ()
 MVATrainer (const std::string &fileName, bool useXSLT=false, const char *styleSheet=0)
void saveState ()
void setAutoSave (bool autoSave)
void setCleanup (bool cleanup)
void setCrossValidation (double split)
void setMonitoring (bool monitoring)
void setRandomSeed (UInt_t seed)
std::string trainFileName (const TrainProcessor *proc, const std::string &ext, const std::string &arg="") const
 ~MVATrainer ()

Static Public Attributes

static const AtomicId kTargetId
static const AtomicId kWeightId

Private Member Functions

void connectProcessors (Calibration::MVAComputer *calib, const std::vector< CalibratedProcessor > &procs, bool withTarget) const
SourceVariable * createVariable (Source *source, AtomicId name, Variable::Flags flags)
void fillInputVars (SourceVariableSet &vars, XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
void fillOutputVars (SourceVariableSet &vars, Source *source, XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *xml)
std::vector< AtomicId > findFinalProcessors () const
void findUntrainedComputers (std::vector< AtomicId > &compute, std::vector< AtomicId > &train) const
SourceVariable * getVariable (AtomicId source, AtomicId name) const
void makeProcessor (XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *elem, AtomicId id, const char *name)
Calibration::MVAComputer * makeTrainCalibration (const AtomicId *compute, const AtomicId *train) const

Private Attributes

double crossValidation
bool doAutoSave
bool doCleanup
bool doMonitoring
Source * input
std::auto_ptr< TrainerMonitoring > monitoring
std::string name
TrainProcessor * output
std::vector< AtomicId > processors
UInt_t randomSeed
std::map< AtomicId, Source * > sources
std::string trainFileMask
std::vector< SourceVariable * > variables
std::auto_ptr< XMLDocument > xml

Detailed Description

Definition at line 27 of file MVATrainer.h.


Constructor & Destructor Documentation

PhysicsTools::MVATrainer::MVATrainer ( const std::string &  fileName,
bool  useXSLT = false,
const char *  styleSheet = 0 
)

Definition at line 422 of file MVATrainer.cc.

References createVariable(), PhysicsTools::escape(), Exception, fillInputVars(), fillOutputVars(), PhysicsTools::Variable::FLAG_NONE, PhysicsTools::Variable::FLAG_OPTIONAL, edm::FileInPath::fullPath(), PhysicsTools::Source::getInputs(), input, PhysicsTools::SourceVariableSet::kTarget, kTargetId, PhysicsTools::SourceVariableSet::kWeight, kWeightId, makeProcessor(), name, python::Node::node, output, popcon_last_value_cfg::Source, sources, trainFileMask, xml, and zero.

                                :
        input(0), output(0), name("MVATrainer"),
        doAutoSave(true), doCleanup(false),
        doMonitoring(false), randomSeed(65539), crossValidation(0.0)
{
        // Parses the trainer description in 'fileName'.  The root node must
        // be <MVATrainer>; its element children are consumed, in this order,
        // by the state machine below:
        //   <general>    exactly one, holding <option name="..."> tags
        //                ("id" sets the trainer name, "trainfiles" the
        //                train file mask)
        //   <input>      exactly one, with an "id" attribute, declaring the
        //                input variables
        //   <processor>  zero or more, with "id" and "name" attributes
        //   <output>     exactly one, selecting the final output variable
        // Any deviation throws cms::Exception("MVATrainer").

        // Optionally pre-process the file with xsltproc.  When no stylesheet
        // is given, the package's default MVATrainer.xsl is used.
        if (useXSLT) {
                std::string sheet;
                if (!styleSheet)
                        sheet = edm::FileInPath(
                                "PhysicsTools/MVATrainer/data/MVATrainer.xsl")
                                .fullPath();
                else
                        sheet = styleSheet;

                // Both paths are passed through escape() because they are
                // spliced into a command string (presumably for shell safety).
                std::string preproc = "xsltproc --xinclude " + escape(sheet) +
                                      " " + escape(fileName);
                xml.reset(new XMLDocument(fileName, preproc));
        } else
                xml.reset(new XMLDocument(fileName));

        DOMNode *node = xml->getRootNode();

        if (std::strcmp(XMLSimpleStr(node->getNodeName()), "MVATrainer") != 0)
                throw cms::Exception("MVATrainer")
                        << "Invalid XML root node." << std::endl;

        // Which top-level tag is expected next.
        enum State {
                STATE_GENERAL,  // waiting for <general>
                STATE_FIRST,    // waiting for <input>
                STATE_MIDDLE,   // reading <processor>s until <output>
                STATE_LAST      // <output> seen, nothing else allowed
        } state = STATE_GENERAL;

        for(node = node->getFirstChild();
            node; node = node->getNextSibling()) {
                if (node->getNodeType() != DOMNode::ELEMENT_NODE)
                        continue;

                // NB: this local shadows the 'name' member; the member is
                // always addressed explicitly as this->name below.
                std::string name = XMLSimpleStr(node->getNodeName());
                DOMElement *elem = static_cast<DOMElement*>(node);

                switch(state) {
                    case STATE_GENERAL: {
                        if (name != "general")
                                throw cms::Exception("MVATrainer")
                                        << "Expected general config as first "
                                           "tag." << std::endl;

                        for(DOMNode *subNode = elem->getFirstChild();
                            subNode; subNode = subNode->getNextSibling()) {
                                if (subNode->getNodeType() !=
                                    DOMNode::ELEMENT_NODE)
                                        continue;

                                if (std::strcmp(XMLSimpleStr(
                                        subNode->getNodeName()), "option") != 0)
                                        throw cms::Exception("MVATrainer")
                                                << "Expected option tag."
                                                << std::endl;

                                // 'elem' and 'name' are reused here for the
                                // <option> being read; both are recomputed
                                // at the top of the next outer iteration.
                                elem = static_cast<DOMElement*>(subNode);
                                name = XMLDocument::readAttribute<std::string>(
                                                                elem, "name");
                                std::string content = XMLSimpleStr(
                                                elem->getTextContent());

                                if (name == "id")
                                        this->name = content;
                                else if (name == "trainfiles")
                                        trainFileMask = content;
                                else
                                        throw cms::Exception("MVATrainer")
                                                << "Unknown option \""
                                                << name << "\"." << std::endl;
                        }

                        state = STATE_FIRST;
                    }   break;
                    case STATE_FIRST: {
                        if (name != "input")
                                throw cms::Exception("MVATrainer")
                                        << "Expected input config as second "
                                           "tag." << std::endl;

                        AtomicId id = XMLDocument::readAttribute<std::string>(
                                                                elem, "id");
                        input = new Source(id, true);
                        // The target and weight variables are always the
                        // first two outputs of the input source.
                        input->getOutputs().append(
                                createVariable(input, kTargetId,
                                               Variable::FLAG_NONE),
                                SourceVariableSet::kTarget);
                        input->getOutputs().append(
                                createVariable(input, kWeightId,
                                               Variable::FLAG_OPTIONAL),
                                SourceVariableSet::kWeight);
                        sources.insert(std::make_pair(id, input));
                        fillOutputVars(input->getOutputs(), input, elem);

                        state = STATE_MIDDLE;
                    }   break;
                    case STATE_MIDDLE: {
                        if (name == "output") {
                                // Default-constructed id, passed by address
                                // (presumably as an empty, terminated id list).
                                AtomicId zero;
                                output = new TrainProcessor("output",
                                                            &zero, this);
                                fillInputVars(output->getInputs(), elem);
                                state = STATE_LAST;
                                continue;
                        } else if (name != "processor")
                                throw cms::Exception("MVATrainer")
                                        << "Unexpected tag after input "
                                           "config." << std::endl;

                        AtomicId id = XMLDocument::readAttribute<std::string>(
                                                                elem, "id");
                        std::string name =
                                XMLDocument::readAttribute<std::string>(
                                        elem, "name");

                        makeProcessor(elem, id, name.c_str());
                    }   break;
                    case STATE_LAST:
                        throw cms::Exception("MVATrainer")
                                << "Unexpected tag found after output."
                                << std::endl;
                        break;
                }
        }

        // Reject documents that end before every mandatory section was seen.
        // In particular, an empty <MVATrainer/> must not leave 'input' and
        // 'output' null, because later methods dereference both.
        if (state == STATE_GENERAL)
                throw cms::Exception("MVATrainer")
                        << "Expected general config." << std::endl;
        else if (state == STATE_FIRST)
                throw cms::Exception("MVATrainer")
                        << "Expected input variable config." << std::endl;
        else if (state == STATE_MIDDLE)
                throw cms::Exception("MVATrainer")
                        << "Expected output variable config." << std::endl;

        // Default train file mask derived from the trainer name.
        if (trainFileMask.empty())
                trainFileMask = this->name + "_%s%s.%s";
}
PhysicsTools::MVATrainer::~MVATrainer ( )

Definition at line 563 of file MVATrainer.cc.

References PhysicsTools::TrainProcessor::cleanup(), doCleanup, monitoring, output, proc, sources, and variables.

{
        // Write out whatever monitoring modules were booked during training.
        if (monitoring.get())
                monitoring->write();

        // Free every registered source (the input source and all
        // processors).  When cleanup was requested, each training processor
        // first gets its cleanup() call.
        std::map<AtomicId, Source*>::const_iterator entry = sources.begin();
        while (entry != sources.end()) {
                Source *owned = entry->second;
                if (doCleanup) {
                        TrainProcessor *trainProc =
                                        dynamic_cast<TrainProcessor*>(owned);
                        if (trainProc)
                                trainProc->cleanup();
                }
                delete owned;
                ++entry;
        }

        // The constructor never inserts 'output' into 'sources', so it is
        // released on its own.
        delete output;

        // All SourceVariables created through createVariable() are owned here.
        std::for_each(variables.begin(), variables.end(),
                      deleter<SourceVariable>());
}

Member Function Documentation

TrainerMonitoring::Module * PhysicsTools::MVATrainer::bookMonitor ( const std::string &  name)

Definition at line 710 of file MVATrainer.cc.

References doMonitoring, convertXMLtoSQLite_cfg::fileName, monitoring, PhysicsTools::stdStringPrintf(), and trainFileMask.

Referenced by PhysicsTools::TrainProcessor::doTrainBegin().

{
        // With monitoring switched off, callers get a null module.
        if (!doMonitoring)
                return 0;

        // The monitoring file is created lazily on the first booking. Its
        // name comes from the train file mask with the "monitoring" tag and
        // a "root" extension.
        if (!monitoring.get()) {
                std::string monitorFile = stdStringPrintf(
                        trainFileMask.c_str(), "monitoring", "", "root");
                monitoring.reset(new TrainerMonitoring(monitorFile));
        }

        return monitoring->book(name);
}
void PhysicsTools::MVATrainer::connectProcessors ( Calibration::MVAComputer *  calib,
const std::vector< CalibratedProcessor > &  procs,
bool  withTarget 
) const [private]

Definition at line 885 of file MVATrainer.cc.

References PhysicsTools::Calibration::MVAComputer::addProcessor(), calib, lhef::cc::convert(), Exception, PhysicsTools::SourceVariableSet::get(), PhysicsTools::Variable::getFlags(), PhysicsTools::Source::getInputs(), PhysicsTools::Variable::getName(), PhysicsTools::Source::getOutputs(), i, input, PhysicsTools::Calibration::MVAComputer::inputSet, prof2calltree::last, Association::map, PhysicsTools::Calibration::Variable::name, output, PhysicsTools::Calibration::MVAComputer::output, findQualityFiles::size, PhysicsTools::SourceVariableSet::size(), and variables.

Referenced by getCalibration(), and makeTrainCalibration().

{
        // Wires the trainer's input variables and the given processors into
        // 'calib'.  Every variable gets a sequential index ('size'); each
        // processor's inputs are encoded as a bit set over those indices;
        // calib->output finally points at the variable feeding the output
        // node.  When 'withTarget' is false, the target and weight
        // variables are left out.
        std::map<SourceVariable*, unsigned int> vars;
        unsigned int size = 0;

        // Only a training calibration (MVATrainerComputer) records the
        // per-variable flags.
        MVATrainerComputer *trainCalib =
                        dynamic_cast<MVATrainerComputer*>(calib);

        // Relies on the input source's variables being the first entries of
        // 'variables'.  The constructor creates target, weight, then the
        // remaining input variables before any processor is built.
        for(unsigned int i = 0;
            i < input->getOutputs().size(true); i++) {
                // Indices 0 and 1 are the target and weight variables.
                if (i < 2 && !withTarget)
                        continue;

                SourceVariable *var = variables[i];
                vars[var] = size++;

                Calibration::Variable calibVar;
                calibVar.name = (const char*)var->getName();
                calib->inputSet.push_back(calibVar);
                if (trainCalib)
                        trainCalib->addFlag(var->getFlags());
        }

        for(std::vector<CalibratedProcessor>::const_iterator iter =
                                procs.begin(); iter != procs.end(); iter++) {
                // Interceptors also request magic inputs (get(true)) and are
                // not added as processors (see below).
                bool isInterceptor = dynamic_cast<BaseInterceptor*>(
                                                        iter->calib) != 0;

                BitSet inputSet(size);

                // Inputs must reference variable indices in non-decreasing
                // order; 'last' tracks the previous index seen.
                unsigned int last = 0;
                std::vector<SourceVariable*> inoutVars;
                if (iter->processor)
                        inoutVars = iter->processor->getInputs().get(
                                                                isInterceptor);
                for(std::vector<SourceVariable*>::const_iterator iter2 =
                        inoutVars.begin(); iter2 != inoutVars.end(); iter2++) {
                        std::map<SourceVariable*,
                                 unsigned int>::const_iterator pos =
                                                        vars.find(*iter2);

                        assert(pos != vars.end());

                        if (pos->second < last)
                                throw cms::Exception("MVATrainer")
                                        << "Input variables not declared "
                                           "in order of appearance in \""
                                        << (const char*)iter->processor->getName()
                                        << "\"." << std::endl;

                        inputSet[last = pos->second] = true;
                }

                assert(!isInterceptor || withTarget);

                iter->calib->inputVars = Calibration::convert(inputSet);

                // Provisional output: the index the next produced variable
                // will get.  It is overridden after the loop once the real
                // output variable is located.
                calib->output = size;

                // An interceptor only reserves one variable slot.
                if (isInterceptor) {
                        size++;
                        continue;
                }

                calib->addProcessor(iter->calib);

                // NOTE(review): dereferences iter->processor without the null
                // check used above.  This assumes only interceptors can have
                // a null processor; confirm.
                inoutVars = iter->processor->getOutputs().get();
                // NB: this 'iter' shadows the outer processor iterator.
                for(std::vector<SourceVariable*>::const_iterator iter =
                        inoutVars.begin(); iter != inoutVars.end(); iter++) {

                        vars[*iter] = size++;
                }
        }

        if (output->getInputs().size() != 1)
                throw cms::Exception("MVATrainer")
                        << "Exactly one output variable has to be specified."
                        << std::endl;

        // Point the calibration's output at the variable feeding the output
        // node, if it was produced by one of the connected processors.
        SourceVariable *outVar = output->getInputs().get()[0];
        std::map<SourceVariable*, unsigned int>::const_iterator pos =
                                                        vars.find(outVar);
        if (pos != vars.end())
                calib->output = pos->second;
}
SourceVariable * PhysicsTools::MVATrainer::createVariable ( Source *  source,
AtomicId  name,
Variable::Flags  flags 
) [private]

Definition at line 734 of file MVATrainer.cc.

References PhysicsTools::Source::getName(), getVariable(), name, and variables.

Referenced by fillOutputVars(), and MVATrainer().

{
        // A null result tells the caller that 'source' already exposes a
        // variable called 'name'; nothing new is created in that case.
        if (getVariable(source->getName(), name))
                return 0;

        // The trainer keeps ownership; the destructor frees 'variables'.
        SourceVariable *created = new SourceVariable(source, name, flags);
        variables.push_back(created);
        return created;
}
void PhysicsTools::MVATrainer::doneTraining ( Calibration::MVAComputer *  trainCalibration) const

Definition at line 1102 of file MVATrainer.cc.

References calib, and Exception.

{
        // Only calibrations produced for training (MVATrainerComputer) can
        // be finalised; anything else is a caller error.
        MVATrainerComputer *trainer =
                        dynamic_cast<MVATrainerComputer*>(trainCalibration);
        if (trainer) {
                trainer->done();
                return;
        }

        throw cms::Exception("MVATrainer")
                << "Invalid training calibration passed to "
                   "doneTraining()" << std::endl;
}
void PhysicsTools::MVATrainer::fillInputVars ( SourceVariableSet &  vars,
XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *  xml 
) [private]

Definition at line 746 of file MVATrainer.cc.

References PhysicsTools::SourceVariableSet::append(), Exception, spr::find(), PhysicsTools::Source::getOutput(), getVariable(), input, PhysicsTools::SourceVariableSet::kRegular, PhysicsTools::SourceVariableSet::kTarget, kTargetId, PhysicsTools::SourceVariableSet::kWeight, kWeightId, n, name, python::Node::node, LaserTracksInput_cfi::source, filterCSVwithJSON::target, tmp, and variables.

Referenced by makeProcessor(), and MVATrainer().

{
        // Resolves the <var source="..." name="..."> children of 'xml' to
        // already-declared variables and appends them to 'vars'.  Exactly one
        // target and one weight variable end up in the set.  Explicit
        // target="true" / weight="true" attributes take precedence; otherwise
        // the input source's kTargetId / kWeightId variables are used.
        std::vector<SourceVariable*> tmp;   // variables in XML order
        SourceVariable *target = 0;
        SourceVariable *weight = 0;

        for(DOMNode *node = xml->getFirstChild(); node;
            node = node->getNextSibling()) {
                if (node->getNodeType() != DOMNode::ELEMENT_NODE)
                        continue;

                if (std::strcmp(XMLSimpleStr(node->getNodeName()), "var") != 0)
                        throw cms::Exception("MVATrainer")
                                << "Invalid input variable node." << std::endl;

                DOMElement *elem = static_cast<DOMElement*>(node);

                AtomicId source = XMLDocument::readAttribute<std::string>(
                                                        elem, "source");
                AtomicId name = XMLDocument::readAttribute<std::string>(
                                                        elem, "name");

                // Inputs may only reference variables that already exist.
                SourceVariable *var = getVariable(source, name);
                if (!var)
                        throw cms::Exception("MVATrainer")
                                << "Input variable " << (const char*)source
                                << ":" << (const char*)name
                                << " not found." << std::endl;

                if (XMLDocument::readAttribute<bool>(elem, "target", false)) {
                        if (target)
                                throw cms::Exception("MVATrainer")
                                        << "Target variable defined twice"
                                        << std::endl;
                        target = var;
                }
                if (XMLDocument::readAttribute<bool>(elem, "weight", false)) {
                        if (weight)
                                throw cms::Exception("MVATrainer")
                                        << "Weight variable defined twice"
                                        << std::endl;
                        weight = var;
                }

                tmp.push_back(var);
        }

        // Fill in the default weight.  It goes to position 1 when the target
        // is the input's default target variable, else to position 0.  When
        // no target was given, the default target is then prepended, giving
        // [target, weight, ...].
        // NOTE(review): the position-1 case assumes an explicitly listed
        // default target sits at tmp[0]; confirm.
        if (!weight) {
                weight = input->getOutput(kWeightId);
                assert(weight);
                tmp.insert(tmp.begin() +
                                (target == input->getOutput(kTargetId)),
                           1, weight);
        }
        if (!target) {
                target = input->getOutput(kTargetId);
                assert(target);
                tmp.insert(tmp.begin(), 1, target);
        }

        // Append in global declaration order ('variables').  Each variable's
        // index in 'tmp' is passed as the third append() argument
        // (presumably its ordering slot within the set; confirm).
        unsigned int n = 0;
        for(std::vector<SourceVariable*>::const_iterator iter = variables.begin();
            iter != variables.end(); iter++) {
                std::vector<SourceVariable*>::const_iterator pos =
                        std::find(tmp.begin(), tmp.end(), *iter);
                if (pos == tmp.end())
                        continue;

                SourceVariableSet::Magic magic;
                if (*iter == target)
                        magic = SourceVariableSet::kTarget;
                else if (*iter == weight)
                        magic = SourceVariableSet::kWeight;
                else
                        magic = SourceVariableSet::kRegular;

                if (vars.append(*iter, magic, pos - tmp.begin())) {
                        AtomicId source = (*iter)->getSource()->getName();
                        AtomicId name = (*iter)->getName();
                        throw cms::Exception("MVATrainer")
                                << "Input variable " << (const char*)source
                                << ":" << (const char*)name
                                << " defined twice." << std::endl;
                }

                n++;
        }

        // Every entry of 'tmp' must have been matched exactly once.
        // NOTE(review): a variable listed twice in the XML makes tmp.size()
        // exceed n.  That trips this assert instead of the "defined twice"
        // exception above.
        assert(tmp.size() == n);
}
void PhysicsTools::MVATrainer::fillOutputVars ( SourceVariableSet &  vars,
Source *  source,
XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *  xml 
) [private]

Definition at line 838 of file MVATrainer.cc.

References PhysicsTools::SourceVariableSet::append(), createVariable(), Exception, PhysicsTools::Variable::FLAG_MULTIPLE, PhysicsTools::Variable::FLAG_NONE, PhysicsTools::Variable::FLAG_OPTIONAL, PhysicsTools::Source::getName(), PhysicsTools::isMagic(), name, and python::Node::node.

Referenced by makeProcessor(), and MVATrainer().

{
        // Declares one new variable on 'source' for every <var name="...">
        // child of 'xml' and appends it to 'vars'.  The "optional" and
        // "multiple" attributes both default to true when absent.
        for(DOMNode *child = xml->getFirstChild(); child;
            child = child->getNextSibling()) {
                // Text and comment nodes between tags are ignored.
                if (child->getNodeType() != DOMNode::ELEMENT_NODE)
                        continue;

                if (std::strcmp(XMLSimpleStr(child->getNodeName()), "var") != 0)
                        throw cms::Exception("MVATrainer")
                                << "Invalid output variable node."
                                << std::endl;

                DOMElement *varElem = static_cast<DOMElement*>(child);

                AtomicId varName = XMLDocument::readAttribute<std::string>(
                                                        varElem, "name");
                if (!varName)
                        throw cms::Exception("MVATrainer")
                                << "Output variable tag missing name."
                                << std::endl;
                // Magic names (target/weight) are reserved.
                if (isMagic(varName))
                        throw cms::Exception("MVATrainer")
                                << "Cannot use magic variable names in output."
                                << std::endl;

                int flagBits = Variable::FLAG_NONE;
                if (XMLDocument::readAttribute<bool>(varElem, "optional", true))
                        flagBits |= Variable::FLAG_OPTIONAL;
                if (XMLDocument::readAttribute<bool>(varElem, "multiple", true))
                        flagBits |= Variable::FLAG_MULTIPLE;

                // createVariable() returns null for an existing name;
                // append() reports true for a duplicate in 'vars'.
                SourceVariable *created = createVariable(
                        source, varName,
                        static_cast<PhysicsTools::Variable::Flags>(flagBits));
                bool duplicate = !created || vars.append(created);
                if (duplicate)
                        throw cms::Exception("MVATrainer")
                                << "Output variable "
                                << (const char*)source->getName()
                                << ":" << (const char*)varName
                                << " defined twice." << std::endl;
        }
}
std::vector< AtomicId > PhysicsTools::MVATrainer::findFinalProcessors ( ) const [private]

Definition at line 1115 of file MVATrainer.cc.

References generateEDF::done, PhysicsTools::SourceVariableSet::get(), PhysicsTools::Source::inputs, output, query::result, LaserTracksInput_cfi::source, and sources.

Referenced by getCalibration().

{
        // Walk upstream from the output node, collecting in 'reached' every
        // source that (transitively) provides one of its inputs.
        std::set<Source*> pending;
        pending.insert(output);

        std::set<Source*> reached;
        while(!pending.empty()) {
                std::set<Source*>::iterator first = pending.begin();
                Source *current = *first;
                pending.erase(first);

                std::vector<SourceVariable*> deps = current->inputs.get();
                for(std::vector<SourceVariable*>::const_iterator dep =
                                deps.begin(); dep != deps.end(); ++dep) {
                        Source *provider = (*dep)->getSource();
                        // Queue each provider only the first time it is seen.
                        if (reached.insert(provider).second)
                                pending.insert(provider);
                }
        }

        // Report the reached processors in their declaration order.
        std::vector<AtomicId> result;
        for(std::vector<AtomicId>::const_iterator id = processors.begin();
            id != processors.end(); ++id) {
                std::map<AtomicId, Source*>::const_iterator found =
                                                        sources.find(*id);
                if (found == sources.end())
                        continue;
                if (reached.count(found->second))
                        result.push_back(*id);
        }

        return result;
}
void PhysicsTools::MVATrainer::findUntrainedComputers ( std::vector< AtomicId > &  compute,
std::vector< AtomicId > &  train 
) const [private]

Definition at line 1195 of file MVATrainer.cc.

References doMonitoring, PhysicsTools::SourceVariableSet::get(), PhysicsTools::Source::getInputs(), input, PhysicsTools::Source::isTrained(), PhysicsTools::kOutputId, output, proc, and sources.

Referenced by getTrainCalibration().

{
        // Splits the declared processors into two id lists:
        //   compute - already trained, with every input coming from a
        //             trained source
        //   train   - not yet trained, but with every input coming from a
        //             trained source
        // Processors with an untrained upstream source appear in neither.
        // With monitoring enabled, kOutputId is appended to 'train' when the
        // output node is untrained and fed by a trained source.
        compute.clear();
        train.clear();

        // The input source provides the raw data and counts as trained.
        std::set<Source*> trainedSources;
        trainedSources.insert(input);

        // Single pass in declaration order: a processor can only see an
        // upstream processor as trained if that one was declared earlier.
        for(std::vector<AtomicId>::const_iterator iter =
                processors.begin(); iter != processors.end(); iter++) {
                std::map<AtomicId, Source*>::const_iterator pos =
                                                        sources.find(*iter);
                assert(pos != sources.end());
                TrainProcessor *proc =
                                dynamic_cast<TrainProcessor*>(pos->second);
                assert(proc);

                bool trainedDeps = true;
                std::vector<SourceVariable*> inputVars =
                                        proc->getInputs().get();
                for(std::vector<SourceVariable*>::const_iterator iter2 =
                        inputVars.begin(); iter2 != inputVars.end(); iter2++) {
                        if (trainedSources.find((*iter2)->getSource())
                            == trainedSources.end()) {
                                trainedDeps = false;
                                break;
                        }
                }

                if (!trainedDeps)
                        continue;

                if (proc->isTrained()) {
                        trainedSources.insert(proc);
                        compute.push_back(proc->getName());
                } else
                        train.push_back(proc->getName());
        }

        // Guard against an output node without regular inputs.  Indexing
        // get()[0] would then be undefined behaviour; connectProcessors()
        // rejects that configuration with a proper exception instead.
        if (doMonitoring && !output->isTrained()) {
                std::vector<SourceVariable*> outVars =
                                        output->getInputs().get();
                if (!outVars.empty() &&
                    trainedSources.find(outVars[0]->getSource())
                                                != trainedSources.end())
                        train.push_back(kOutputId);
        }
}
Calibration::MVAComputer * PhysicsTools::MVATrainer::getCalibration ( ) const

Definition at line 1146 of file MVATrainer.cc.

References begin, calib, connectProcessors(), end, spr::find(), findFinalProcessors(), PhysicsTools::TrainProcessor::getCalibration(), proc, processors, LaserTracksInput_cfi::source, and sources.

Referenced by PhysicsTools::MVATrainerContainerLooperImpl< Record_t >::produce(), and PhysicsTools::TreeTrainer::train().

{
        // Builds the final (non-training) calibration from the processors
        // that feed the output node (see findFinalProcessors()).  Returns a
        // null pointer if any of them is not trained yet.  Otherwise the
        // caller takes ownership of the returned object.
        // NB: this local shadows the 'processors' member, which is accessed
        // as this->processors below.
        std::vector<CalibratedProcessor> processors;

        std::auto_ptr<Calibration::MVAComputer> calib(
                                                new Calibration::MVAComputer);

        std::vector<AtomicId> used = findFinalProcessors();
        for(std::vector<AtomicId>::const_iterator iter = used.begin();
            iter != used.end(); iter++) {
                std::map<AtomicId, Source*>::const_iterator pos =
                                                        sources.find(*iter);
                assert(pos != sources.end());
                TrainProcessor *source =
                                dynamic_cast<TrainProcessor*>(pos->second);
                assert(source);
                if (!source->isTrained())
                        return 0;

                // Processors without a calibration are simply left out.
                Calibration::VarProcessor *proc = source->getCalibration();
                if (!proc)
                        continue;

                // A foreach processor's nProcs counts the processors that
                // follow it in declaration order.  Recount it so only those
                // that also appear in 'used' (from this one on) are included.
                Calibration::ProcForeach *foreach =
                                dynamic_cast<Calibration::ProcForeach*>(proc);
                if (foreach) {
                        std::vector<AtomicId>::const_iterator begin =
                                std::find(this->processors.begin(),
                                          this->processors.end(), *iter);
                        // NOTE(review): this needs nProcs + 2 entries from
                        // 'begin' onwards, while the range below only uses
                        // nProcs + 1 (the foreach itself plus its nProcs
                        // successors).  It may be stricter than necessary by
                        // one; confirm intent.
                        assert(this->processors.end() - begin >
                               (int)(foreach->nProcs + 1));
                        ++begin;
                        std::vector<AtomicId>::const_iterator end =
                                                begin + foreach->nProcs;
                        foreach->nProcs = 0;
                        for(std::vector<AtomicId>::const_iterator iter2 =
                                        iter; iter2 != used.end(); ++iter2)
                                if (std::find(begin, end, *iter2) != end)
                                        foreach->nProcs++;
                }

                processors.push_back(CalibratedProcessor(source, proc));
        }

        // Final calibrations are wired without the target/weight variables.
        connectProcessors(calib.get(), processors, false);

        return calib.release();
}
const std::string& PhysicsTools::MVATrainer::getName ( void  ) const [inline]

Definition at line 53 of file MVATrainer.h.

References name.

{
        // "MVATrainer" unless overridden by the <general> "id" option.
        return this->name;
}
Calibration::MVAComputer * PhysicsTools::MVATrainer::getTrainCalibration ( ) const

Definition at line 1241 of file MVATrainer.cc.

References bookConverter::compute(), findUntrainedComputers(), and makeTrainCalibration().

Referenced by PhysicsTools::TreeTrainer::iteration().

{
        std::vector<AtomicId> toCompute;
        std::vector<AtomicId> toTrain;
        findUntrainedComputers(toCompute, toTrain);

        // Nothing left to train: signal completion with a null calibration.
        if (toTrain.empty())
                return 0;

        // The id lists are handed over as bare pointers, so each one is
        // terminated with a 0 id.
        toCompute.push_back(0);
        toTrain.push_back(0);

        return makeTrainCalibration(&toCompute[0], &toTrain[0]);
}
SourceVariable * PhysicsTools::MVATrainer::getVariable ( AtomicId  source,
AtomicId  name 
) const [private]

Definition at line 725 of file MVATrainer.cc.

References sources.

Referenced by createVariable(), and fillInputVars().

{
        // Unknown sources yield no variable; otherwise defer to the source's
        // own lookup of its output 'name'.
        std::map<AtomicId, Source*>::const_iterator found =
                                                        sources.find(source);
        return found == sources.end() ? 0 : found->second->getOutput(name);
}
void PhysicsTools::MVATrainer::loadState ( )

Definition at line 583 of file MVATrainer.cc.

References LaserTracksInput_cfi::source, and sources.

{
        // Let every declared processor try to restore its saved state.
        // Each one that succeeds is logged.
        for(std::vector<AtomicId>::const_iterator id = processors.begin();
            id != processors.end(); ++id) {
                std::map<AtomicId, Source*>::const_iterator entry =
                                                        sources.find(*id);
                assert(entry != sources.end());

                TrainProcessor *proc =
                                dynamic_cast<TrainProcessor*>(entry->second);
                assert(proc);

                if (!proc->load())
                        continue;

                edm::LogInfo("MVATrainer")
                        << proc->getId() << " configuration for \""
                        << (const char*)proc->getName()
                        << "\" loaded from file.";
        }
}
void PhysicsTools::MVATrainer::makeProcessor ( XERCES_CPP_NAMESPACE_QUALIFIER DOMElement *  elem,
AtomicId  id,
const char *  name 
) [private]

Definition at line 622 of file MVATrainer.cc.

References PhysicsTools::ProcessRegistry< Base_t, CalibBase_t, Parent_t >::Factory::create(), Exception, fillInputVars(), fillOutputVars(), python::Node::node, proc, sources, and GlobalPosition_Frontier_DevDB_cff::tag.

Referenced by MVATrainer().

{
        // Parse one processor element. Its element children must be
        // <input>, <config> and <output> in that order, optionally followed
        // by <data>. The trainer plugin "name" is instantiated for "id", its
        // input and output variables are filled from the XML, it is
        // configured from <config>, and it is registered in "sources" and
        // "processors".
        // Throws cms::Exception on a malformed child sequence, an unknown
        // trainer plugin or a duplicate processor id.
        DOMElement *xmlInput = 0;
        DOMElement *xmlConfig = 0;
        DOMElement *xmlOutput = 0;
        DOMElement *xmlData = 0;        // optional; not used further below

        // Expected child tags, terminated by a null tag. Each entry points
        // at the local variable that receives the matching element.
        // This table must NOT be static: it holds the addresses of the
        // automatic variables above. A static table is initialized only on
        // the first call, so later calls would write through pointers into
        // the first call's dead stack frame and leave their own locals 0.
        struct NameExpect {
                const char      *tag;
                bool            mandatory;
                DOMElement      **elem;
        } const expect[] = {
                { "input",      true,   &xmlInput },
                { "config",     true,   &xmlConfig },
                { "output",     true,   &xmlOutput },
                { "data",       false,  &xmlData },
                { 0,            false,  0 }
        };

        const NameExpect *cur = expect;
        for(DOMNode *node = elem->getFirstChild();
            node; node = node->getNextSibling()) {
                if (node->getNodeType() != DOMNode::ELEMENT_NODE)
                        continue;

                std::string tag = XMLSimpleStr(node->getNodeName());
                // Named "child" so it does not shadow the "elem" parameter.
                DOMElement *child = static_cast<DOMElement*>(node);

                if (!cur->tag)
                        throw cms::Exception("MVATrainer")
                                << "Superfluous tag " << tag
                                << " encountered in processor." << std::endl;
                else if (tag != cur->tag && cur->mandatory)
                        throw cms::Exception("MVATrainer")
                                << "Expected tag " << cur->tag << ", got "
                                << tag << " instead in processor."
                                << std::endl;
                else if (tag != cur->tag) {
                        // NOTE(review): the current (unexpected) element is
                        // skipped here without being matched against the
                        // next expected tag; confirm that silently ignoring
                        // it is intended.
                        cur++;
                        continue;
                }
                *(cur++)->elem = child;
        }

        // Trailing optional tags may be absent; any remaining mandatory tag
        // is an error.
        while(cur->tag && !cur->mandatory)
                cur++;
        if (cur->tag)
                throw cms::Exception("MVATrainer")
                        << "Unexpected end of processor configuration, "
                        << "expected tag " << cur->tag << "." << std::endl;

        std::auto_ptr<TrainProcessor> proc(
                                TrainProcessor::create(name, &id, this));
        if (!proc.get())
                throw cms::Exception("MVATrainer")
                        << "Variable processor trainer " << name
                        << " could not be instantiated. Most likely because"
                           " the trainer plugin for \"" << name << "\""
                           " does not exist." << std::endl;

        if (sources.find(id) != sources.end())
                throw cms::Exception("MVATrainer")
                        << "Duplicate variable processor id "
                        << (const char*)id << "."
                        << std::endl;

        fillInputVars(proc->getInputs(), xmlInput);
        fillOutputVars(proc->getOutputs(), proc.get(), xmlOutput);

        edm::LogInfo("MVATrainer")
                << "Configuring " << (const char*)proc->getId()
                << " \"" << (const char*)proc->getName() << "\".";
        proc->configure(xmlConfig);

        // Ownership passes from the auto_ptr to the "sources" map.
        sources.insert(std::make_pair(id, proc.release()));
        processors.push_back(id);
}
Calibration::MVAComputer * PhysicsTools::MVATrainer::makeTrainCalibration ( const AtomicId compute,
const AtomicId train 
) const [private]

Definition at line 974 of file MVATrainer.cc.

References calib, bookConverter::compute(), connectProcessors(), crossValidation, doAutoSave, generateEDF::done, spr::find(), PhysicsTools::TrainProcessor::getCalibration(), i, interceptors, PhysicsTools::kOutputId, AlignmentProducer_cff::looper, n, PhysicsTools::Calibration::ProcForeach::nProcs, output, proc, processors, randomSeed, LaserTracksInput_cfi::source, and sources.

Referenced by getTrainCalibration().

{
        // Build the MVAComputer used while training. Each already-trained
        // processor listed in "compute" contributes its calibration. Each
        // processor listed in "train" gets a TrainInterceptor. Both arrays
        // are terminated by a null AtomicId.
        std::map<AtomicId, TrainInterceptor*> interceptors;
        std::vector<MVATrainerComputer::Interceptor> baseInterceptors;
        std::vector<CalibratedProcessor> processors;

        // Slot 0 is always an InitInterceptor.
        BaseInterceptor *interceptor = new InitInterceptor;
        baseInterceptors.push_back(std::make_pair(0, interceptor));
        processors.push_back(CalibratedProcessor(0, interceptor));

        // One interceptor per processor still to be trained. kOutputId
        // refers to the "output" trainer rather than an entry in "sources".
        for(const AtomicId *iter = train; *iter; iter++) {
                TrainProcessor *source;
                if (*iter == kOutputId)
                        source = output;
                else {
                        std::map<AtomicId, Source*>::const_iterator pos =
                                                        sources.find(*iter);
                        assert(pos != sources.end());
                        source = dynamic_cast<TrainProcessor*>(pos->second);
                }
                assert(source);

                interceptors[*iter] = new TrainInterceptor(source);
        }

        auto_cleaner<Calibration::VarProcessor> autoClean;

        // Ids already emitted as part of a ProcForeach body. The outer loop
        // skips each of them once when it reaches it.
        std::set<AtomicId> done;
        for(const AtomicId *iter = compute; *iter; iter++) {
                if (done.erase(*iter))
                        continue;

                std::map<AtomicId, Source*>::const_iterator pos =
                                                        sources.find(*iter);
                assert(pos != sources.end());
                TrainProcessor *source =
                                dynamic_cast<TrainProcessor*>(pos->second);
                assert(source);
                assert(source->isTrained());

                Calibration::VarProcessor *proc = source->getCalibration();
                if (!proc)
                        continue;

                autoClean.add(proc);
                processors.push_back(CalibratedProcessor(source, proc));

                Calibration::ProcForeach *looper =
                                dynamic_cast<Calibration::ProcForeach*>(proc);
                if (looper) {
                        // The loop body is the nProcs processors that follow
                        // the looper in the configured processor order. They
                        // are emitted right after it, and nProcs is rewritten
                        // to the number actually emitted.
                        std::vector<AtomicId>::const_iterator pos2 =
                                std::find(this->processors.begin(),
                                          this->processors.end(), *iter);
                        assert(pos2 != this->processors.end());
                        ++pos2;
                        unsigned int n = 0;
                        for(int i = 0; i < (int)looper->nProcs; ++i, ++pos2) {
                                assert(pos2 != this->processors.end());

                                const AtomicId *iter2 = compute;
                                while(*iter2) {
                                        if (*iter2 == *pos2)
                                                break;
                                        iter2++;
                                }

                                if (*iter2) {
                                        done.insert(*iter2);
                                        pos = sources.find(*iter2);
                                        assert(pos != sources.end());
                                        TrainProcessor *source =
                                                dynamic_cast<TrainProcessor*>(
                                                                pos->second);
                                        assert(source);
                                        assert(source->isTrained());

                                        proc = source->getCalibration();
                                        // Count only processors that are
                                        // actually appended. One without a
                                        // calibration adds nothing to the
                                        // loop body.
                                        if (proc) {
                                                n++;
                                                autoClean.add(proc);
                                                processors.push_back(
                                                        CalibratedProcessor(
                                                                source, proc));
                                        }
                                }

                                std::map<AtomicId, TrainInterceptor*>::iterator
                                                pos3 = interceptors.find(*pos2);
                                if (pos3 != interceptors.end()) {
                                        n++;
                                        baseInterceptors.push_back(
                                                std::make_pair(processors.size(),
                                                               pos3->second));
                                        processors.push_back(
                                                CalibratedProcessor(
                                                        pos3->second->getProcessor(),
                                                        pos3->second));
                                        interceptors.erase(pos3);
                                }
                        }

                        looper->nProcs = n;
                        // Empty loop body: drop the looper itself. Only
                        // "processors" received an entry for it, so nothing
                        // is removed from baseInterceptors. Popping there
                        // would discard an unrelated interceptor.
                        if (!n)
                                processors.pop_back();
                }
        }

        // Trainers not consumed inside a ProcForeach body go at the end.
        for(std::map<AtomicId, TrainInterceptor*>::const_iterator iter =
                interceptors.begin(); iter != interceptors.end(); ++iter) {

                TrainProcessor *proc = iter->second->getProcessor();
                baseInterceptors.push_back(std::make_pair(processors.size(),
                                                          iter->second));
                processors.push_back(CalibratedProcessor(proc, iter->second));
        }

        std::auto_ptr<Calibration::MVAComputer> calib(
                new MVATrainerComputer(baseInterceptors, doAutoSave,
                                       randomSeed, crossValidation));

        connectProcessors(calib.get(), processors, true);

        return calib.release();
}
void PhysicsTools::MVATrainer::saveState ( )

Definition at line 603 of file MVATrainer.cc.

References doCleanup, LaserTracksInput_cfi::source, and sources.

{
        // Explicitly saving state switches doCleanup off.
        // NOTE(review): presumably this keeps the saved files from being
        // removed by ~MVATrainer; confirm.
        doCleanup = false;

        std::vector<AtomicId>::const_iterator it = this->processors.begin();
        for(; it != this->processors.end(); ++it) {
                std::map<AtomicId, Source*>::const_iterator entry =
                                                        sources.find(*it);
                // Every id in "processors" must be registered in "sources"
                // as a TrainProcessor.
                assert(entry != sources.end());
                TrainProcessor *trainer =
                                dynamic_cast<TrainProcessor*>(entry->second);
                assert(trainer);

                // Only processors that have finished training are saved.
                if (!trainer->isTrained())
                        continue;
                trainer->save();
        }
}
void PhysicsTools::MVATrainer::setAutoSave ( bool  autoSave) [inline]

Definition at line 33 of file MVATrainer.h.

References doAutoSave.

{
        // Stored flag; makeTrainCalibration() hands it to MVATrainerComputer.
        this->doAutoSave = autoSave;
}
void PhysicsTools::MVATrainer::setCleanup ( bool  cleanup) [inline]

Definition at line 34 of file MVATrainer.h.

References edm::cleanup(), and doCleanup.

void PhysicsTools::MVATrainer::setCrossValidation ( double  split) [inline]

Definition at line 37 of file MVATrainer.h.

References crossValidation, and PhysicsTools::split().

Referenced by PhysicsTools::TreeTrainer::train().

void PhysicsTools::MVATrainer::setMonitoring ( bool  monitoring) [inline]

Definition at line 35 of file MVATrainer.h.

References doMonitoring, and monitoring.

Referenced by PhysicsTools::TreeTrainer::train().

void PhysicsTools::MVATrainer::setRandomSeed ( UInt_t  seed) [inline]

Definition at line 36 of file MVATrainer.h.

References randomSeed.

{
        // Stored seed; makeTrainCalibration() hands it to MVATrainerComputer.
        this->randomSeed = seed;
}
std::string PhysicsTools::MVATrainer::trainFileName ( const TrainProcessor proc,
const std::string &  ext,
const std::string &  arg = "" 
) const

Definition at line 700 of file MVATrainer.cc.

References cmsCodeRulesChecker::arg, PhysicsTools::Source::getName(), PhysicsTools::stdStringPrintf(), and trainFileMask.

{
        // A non-empty "arg" becomes a "_<arg>" suffix; otherwise no suffix.
        std::string suffix;
        if (!arg.empty())
                suffix = "_" + arg;

        // trainFileMask is formatted with the processor name, the suffix
        // and the extension, in that order.
        return stdStringPrintf(trainFileMask.c_str(),
                               (const char*)proc->getName(),
                               suffix.c_str(),
                               ext.c_str());
}

Member Data Documentation

double PhysicsTools::MVATrainer::crossValidation [private]

Definition at line 115 of file MVATrainer.h.

Referenced by makeTrainCalibration(), and setCrossValidation().

bool PhysicsTools::MVATrainer::doAutoSave [private]

Definition at line 110 of file MVATrainer.h.

Referenced by makeTrainCalibration(), and setAutoSave().

bool PhysicsTools::MVATrainer::doCleanup [private]

Definition at line 111 of file MVATrainer.h.

Referenced by saveState(), setCleanup(), and ~MVATrainer().

bool PhysicsTools::MVATrainer::doMonitoring [private]

Definition at line 112 of file MVATrainer.h.

Referenced by bookMonitor(), findUntrainedComputers(), and setMonitoring().

Definition at line 103 of file MVATrainer.h.

Referenced by connectProcessors(), fillInputVars(), findUntrainedComputers(), and MVATrainer().

Definition at line 106 of file MVATrainer.h.

Referenced by bookMonitor(), setMonitoring(), and ~MVATrainer().

std::string PhysicsTools::MVATrainer::name [private]

Definition at line 109 of file MVATrainer.h.

Referenced by createVariable(), fillInputVars(), fillOutputVars(), getName(), and MVATrainer().

Definition at line 102 of file MVATrainer.h.

Referenced by getCalibration(), and makeTrainCalibration().

Definition at line 114 of file MVATrainer.h.

Referenced by makeTrainCalibration(), and setRandomSeed().

Definition at line 108 of file MVATrainer.h.

Referenced by bookMonitor(), MVATrainer(), and trainFileName().

Definition at line 101 of file MVATrainer.h.

Referenced by connectProcessors(), createVariable(), fillInputVars(), and ~MVATrainer().

std::auto_ptr<XMLDocument> PhysicsTools::MVATrainer::xml [private]

Definition at line 107 of file MVATrainer.h.

Referenced by MVATrainer().