CMS 3D CMS Logo

ProcMLP.cc

Go to the documentation of this file.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
00009 
00010 #include <xercesc/dom/DOM.hpp>
00011 
00012 #include "FWCore/Utilities/interface/Exception.h"
00013 
00014 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00015 
00016 #include "PhysicsTools/MVATrainer/interface/XMLDocument.h"
00017 #include "PhysicsTools/MVATrainer/interface/XMLSimpleStr.h"
00018 #include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
00019 #include "PhysicsTools/MVATrainer/interface/SourceVariable.h"
00020 #include "PhysicsTools/MVATrainer/interface/TrainProcessor.h"
00021 
00022 XERCES_CPP_NAMESPACE_USE
00023 
00024 using namespace PhysicsTools;
00025 
00026 namespace { // anonymous
00027 
00028 class ProcMLP : public TrainProcessor {
00029     public:
00030         typedef TrainProcessor::Registry<ProcMLP>::Type Registry;
00031 
00032         ProcMLP(const char *name, const AtomicId *id,
00033                 MVATrainer *trainer);
00034         virtual ~ProcMLP();
00035 
00036         virtual void configure(DOMElement *elem);
00037         virtual Calibration::VarProcessor *getCalibration() const;
00038 
00039         virtual void trainBegin();
00040         virtual void trainData(const std::vector<double> *values,
00041                                bool target, double weight);
00042         virtual void trainEnd();
00043 
00044         virtual bool load();
00045         virtual void cleanup();
00046 
00047     private:
00048         enum Iteration {
00049                 ITER_TRAIN,
00050                 ITER_DONE
00051         } iteration;
00052 
00053         std::string             layout;
00054         unsigned int            steps;
00055         unsigned int            count, row;
00056         std::vector<double>     vars;
00057         std::vector<double>     targets;
00058         bool                    needCleanup;
00059 };
00060 
00061 static ProcMLP::Registry registry("ProcMLP");
00062 
00063 ProcMLP::ProcMLP(const char *name, const AtomicId *id,
00064                  MVATrainer *trainer) :
00065         TrainProcessor(name, id, trainer),
00066         iteration(ITER_TRAIN),
00067         count(0),
00068         needCleanup(false)
00069 {
00070 }
00071 
00072 ProcMLP::~ProcMLP()
00073 {
00074 }
00075 
00076 void ProcMLP::configure(DOMElement *elem)
00077 {
00078         std::vector<SourceVariable*> inputs = getInputs().get();
00079 
00080         DOMNode *node = elem->getFirstChild();
00081         while(node && node->getNodeType() != DOMNode::ELEMENT_NODE)
00082                 node = node->getNextSibling();
00083 
00084         if (!node)
00085                 throw cms::Exception("ProcMLP")
00086                         << "Expected MLP config in config section."
00087                         << std::endl;
00088 
00089         if (std::strcmp(XMLSimpleStr(node->getNodeName()), "config") != 0)
00090                 throw cms::Exception("ProcMLP")
00091                                 << "Expected config tag in config section."
00092                                 << std::endl;
00093 
00094         elem = static_cast<DOMElement*>(node);
00095 
00096         steps = XMLDocument::readAttribute<unsigned int>(elem, "steps");
00097 
00098         layout = (const char*)XMLSimpleStr(node->getTextContent());
00099 
00100         node = node->getNextSibling();
00101         while(node && node->getNodeType() != DOMNode::ELEMENT_NODE)
00102                 node = node->getNextSibling();
00103 
00104         if (node)
00105                 throw cms::Exception("ProcMLP")
00106                         << "Superfluous tags in config section."
00107                         << std::endl;
00108 
00109         vars.resize(getInputs().size());
00110         targets.resize(getOutputs().size());
00111 }
00112 
00113 bool ProcMLP::load()
00114 {
00115         bool ok = false;
00116         /* test for weights file */ {
00117                 std::ifstream in(trainer->trainFileName(this, "txt").c_str());
00118                 ok = in.good();
00119         }
00120 
00121         if (!ok)
00122                 return false;
00123 
00124         iteration = ITER_DONE;
00125         trained = true;
00126         return true;
00127 }
00128 
00129 Calibration::VarProcessor *ProcMLP::getCalibration() const
00130 {
00131         Calibration::ProcMLP *calib = new Calibration::ProcMLP;
00132 
00133         std::string fileName = trainer->trainFileName(this, "txt");
00134         std::ifstream in(fileName.c_str(), std::ios::binary | std::ios::in);
00135         if (!in.good())
00136                 throw cms::Exception("ProcMLP")
00137                         << "Weights file " << fileName
00138                         << "cannot be opened for reading." << std::endl;
00139 
00140         char linebuf[128];
00141         linebuf[127] = 0;
00142         in.getline(linebuf, 127);
00143         if (std::strncmp(linebuf, "# network structure ", 20) != 0)
00144                 throw cms::Exception("ProcMLP")
00145                         << "Weights file " << fileName
00146                         << "has invalid header." << std::endl;
00147 
00148         std::istringstream is(linebuf + 20);
00149         std::vector<unsigned int> layout;
00150         for(;;) {
00151                 unsigned int layer = 0;
00152                 is >> layer;
00153                 if (!layer)
00154                         break;
00155                 layout.push_back(layer);
00156         }
00157 
00158         if (layout.size() < 2 || layout.front() != getInputs().size()
00159             || layout.back() != 1)
00160                 throw cms::Exception("ProcMLP")
00161                         << "Weights file " << fileName
00162                         << "network layout does not match." << std::endl;
00163 
00164         in.getline(linebuf, 127);
00165 
00166         for(unsigned int layer = 1; layer < layout.size(); layer++) {
00167                 Calibration::ProcMLP::Layer layerConf;
00168 
00169                 for(unsigned int i = 0; i < layout[layer]; i++) {
00170                         Calibration::ProcMLP::Neuron neuron;
00171 
00172                         for(unsigned int j = 0; j <= layout[layer - 1]; j++) {
00173                                 in.getline(linebuf, 127);
00174                                 std::istringstream ss(linebuf);
00175                                 double weight;
00176                                 ss >> weight;
00177 
00178                                 if (j == 0)
00179                                         neuron.first = weight;
00180                                 else
00181                                         neuron.second.push_back(weight);
00182                         }
00183                         layerConf.first.push_back(neuron);
00184                 }
00185                 layerConf.second = layer < layout.size() - 1;
00186 
00187                 calib->layers.push_back(layerConf);
00188         }
00189 
00190         in.close();
00191 
00192         return calib;
00193 }
00194 
00195 void ProcMLP::trainBegin()
00196 {
00197         switch(iteration) {
00198             case ITER_TRAIN:
00199                 throw cms::Exception("ProcMLP")
00200                         << "Actual training for ProcMLP not provided"
00201                            "inside CMSSW. Please provide network weights"
00202                            "file in mlpfit format." << std::endl;
00203                 break;
00204             default:
00205                 /* shut up */;
00206         }
00207 }
00208 
00209 void ProcMLP::trainData(const std::vector<double> *values,
00210                         bool target, double weight)
00211 {
00212 }
00213 
00214 void ProcMLP::trainEnd()
00215 {
00216         switch(iteration) {
00217             default:
00218                 /* shut up */;
00219         }
00220 }
00221 
00222 void ProcMLP::cleanup()
00223 {
00224         if (!needCleanup)
00225                 return;
00226 
00227         std::remove(trainer->trainFileName(this, "txt").c_str());
00228 }
00229 
00230 } // anonymous namespace

Generated on Tue Jun 9 17:41:28 2009 for CMSSW by  doxygen 1.5.4