00001 #include <algorithm>
00002 #include <iostream>
00003 #include <sstream>
00004 #include <fstream>
00005 #include <cstddef>
00006 #include <cstring>
00007 #include <vector>
00008 #include <memory>
00009
00010 #include <xercesc/dom/DOM.hpp>
00011
00012 #include "FWCore/Utilities/interface/Exception.h"
00013
00014 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00015
00016 #include "PhysicsTools/MVATrainer/interface/XMLDocument.h"
00017 #include "PhysicsTools/MVATrainer/interface/XMLSimpleStr.h"
00018 #include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
00019 #include "PhysicsTools/MVATrainer/interface/SourceVariable.h"
00020 #include "PhysicsTools/MVATrainer/interface/TrainProcessor.h"
00021
00022 XERCES_CPP_NAMESPACE_USE
00023
00024 using namespace PhysicsTools;
00025
00026 namespace {
00027
00028 class ProcMLP : public TrainProcessor {
00029 public:
00030 typedef TrainProcessor::Registry<ProcMLP>::Type Registry;
00031
00032 ProcMLP(const char *name, const AtomicId *id,
00033 MVATrainer *trainer);
00034 virtual ~ProcMLP();
00035
00036 virtual void configure(DOMElement *elem);
00037 virtual Calibration::VarProcessor *getCalibration() const;
00038
00039 virtual void trainBegin();
00040 virtual void trainData(const std::vector<double> *values,
00041 bool target, double weight);
00042 virtual void trainEnd();
00043
00044 virtual bool load();
00045 virtual void cleanup();
00046
00047 private:
00048 enum Iteration {
00049 ITER_TRAIN,
00050 ITER_DONE
00051 } iteration;
00052
00053 std::string layout;
00054 unsigned int steps;
00055 unsigned int count, row;
00056 std::vector<double> vars;
00057 std::vector<double> targets;
00058 bool needCleanup;
00059 };
00060
00061 static ProcMLP::Registry registry("ProcMLP");
00062
00063 ProcMLP::ProcMLP(const char *name, const AtomicId *id,
00064 MVATrainer *trainer) :
00065 TrainProcessor(name, id, trainer),
00066 iteration(ITER_TRAIN),
00067 count(0),
00068 needCleanup(false)
00069 {
00070 }
00071
00072 ProcMLP::~ProcMLP()
00073 {
00074 }
00075
00076 void ProcMLP::configure(DOMElement *elem)
00077 {
00078 std::vector<SourceVariable*> inputs = getInputs().get();
00079
00080 DOMNode *node = elem->getFirstChild();
00081 while(node && node->getNodeType() != DOMNode::ELEMENT_NODE)
00082 node = node->getNextSibling();
00083
00084 if (!node)
00085 throw cms::Exception("ProcMLP")
00086 << "Expected MLP config in config section."
00087 << std::endl;
00088
00089 if (std::strcmp(XMLSimpleStr(node->getNodeName()), "config") != 0)
00090 throw cms::Exception("ProcMLP")
00091 << "Expected config tag in config section."
00092 << std::endl;
00093
00094 elem = static_cast<DOMElement*>(node);
00095
00096 steps = XMLDocument::readAttribute<unsigned int>(elem, "steps");
00097
00098 layout = (const char*)XMLSimpleStr(node->getTextContent());
00099
00100 node = node->getNextSibling();
00101 while(node && node->getNodeType() != DOMNode::ELEMENT_NODE)
00102 node = node->getNextSibling();
00103
00104 if (node)
00105 throw cms::Exception("ProcMLP")
00106 << "Superfluous tags in config section."
00107 << std::endl;
00108
00109 vars.resize(getInputs().size());
00110 targets.resize(getOutputs().size());
00111 }
00112
00113 bool ProcMLP::load()
00114 {
00115 bool ok = false;
00116 {
00117 std::ifstream in(trainer->trainFileName(this, "txt").c_str());
00118 ok = in.good();
00119 }
00120
00121 if (!ok)
00122 return false;
00123
00124 iteration = ITER_DONE;
00125 trained = true;
00126 return true;
00127 }
00128
00129 Calibration::VarProcessor *ProcMLP::getCalibration() const
00130 {
00131 Calibration::ProcMLP *calib = new Calibration::ProcMLP;
00132
00133 std::string fileName = trainer->trainFileName(this, "txt");
00134 std::ifstream in(fileName.c_str(), std::ios::binary | std::ios::in);
00135 if (!in.good())
00136 throw cms::Exception("ProcMLP")
00137 << "Weights file " << fileName
00138 << "cannot be opened for reading." << std::endl;
00139
00140 char linebuf[128];
00141 linebuf[127] = 0;
00142 in.getline(linebuf, 127);
00143 if (std::strncmp(linebuf, "# network structure ", 20) != 0)
00144 throw cms::Exception("ProcMLP")
00145 << "Weights file " << fileName
00146 << "has invalid header." << std::endl;
00147
00148 std::istringstream is(linebuf + 20);
00149 std::vector<unsigned int> layout;
00150 for(;;) {
00151 unsigned int layer = 0;
00152 is >> layer;
00153 if (!layer)
00154 break;
00155 layout.push_back(layer);
00156 }
00157
00158 if (layout.size() < 2 || layout.front() != getInputs().size()
00159 || layout.back() != 1)
00160 throw cms::Exception("ProcMLP")
00161 << "Weights file " << fileName
00162 << "network layout does not match." << std::endl;
00163
00164 in.getline(linebuf, 127);
00165
00166 for(unsigned int layer = 1; layer < layout.size(); layer++) {
00167 Calibration::ProcMLP::Layer layerConf;
00168
00169 for(unsigned int i = 0; i < layout[layer]; i++) {
00170 Calibration::ProcMLP::Neuron neuron;
00171
00172 for(unsigned int j = 0; j <= layout[layer - 1]; j++) {
00173 in.getline(linebuf, 127);
00174 std::istringstream ss(linebuf);
00175 double weight;
00176 ss >> weight;
00177
00178 if (j == 0)
00179 neuron.first = weight;
00180 else
00181 neuron.second.push_back(weight);
00182 }
00183 layerConf.first.push_back(neuron);
00184 }
00185 layerConf.second = layer < layout.size() - 1;
00186
00187 calib->layers.push_back(layerConf);
00188 }
00189
00190 in.close();
00191
00192 return calib;
00193 }
00194
00195 void ProcMLP::trainBegin()
00196 {
00197 switch(iteration) {
00198 case ITER_TRAIN:
00199 throw cms::Exception("ProcMLP")
00200 << "Actual training for ProcMLP not provided"
00201 "inside CMSSW. Please provide network weights"
00202 "file in mlpfit format." << std::endl;
00203 break;
00204 default:
00205 ;
00206 }
00207 }
00208
00209 void ProcMLP::trainData(const std::vector<double> *values,
00210 bool target, double weight)
00211 {
00212 }
00213
00214 void ProcMLP::trainEnd()
00215 {
00216 switch(iteration) {
00217 default:
00218 ;
00219 }
00220 }
00221
00222 void ProcMLP::cleanup()
00223 {
00224 if (!needCleanup)
00225 return;
00226
00227 std::remove(trainer->trainFileName(this, "txt").c_str());
00228 }
00229
00230 }