00001
00002
00003
00004
00005
#include "PhysicsTools/StatPatternRecognition/interface/SprExperiment.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprTrainedRBF.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprData.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprEmptyFilter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsReader.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprAbsWriter.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprDataFeeder.hh"
#include "PhysicsTools/StatPatternRecognition/interface/SprRWFactory.hh"

#include <unistd.h>
#include <stdio.h>
#include <cstdlib>
#include <iostream>
#include <vector>
#include <string>
#include <memory>
00021
00022 using namespace std;
00023
00024
/*
  Print the command-line synopsis followed by the list of supported
  options to standard output.

  prog -- program name shown in the synopsis line (normally argv[0]).
*/
void help(const char* prog)
{
  // One entry per output line after the synopsis, printed in order.
  static const char* const optionText[] = {
    "\t Options: ",
    "\t-h --- help ",
    "\t-o output Tuple file ",
    "\t-a input ascii file mode (see SprSimpleReader.hh) ",
    "\t-A save output data in ascii instead of Root "
  };
  const size_t nLines = sizeof(optionText) / sizeof(optionText[0]);

  cout << "Usage: " << prog << " training_data_file " << " net_configuration_file " << endl;
  for( size_t k = 0; k < nLines; ++k )
    cout << optionText[k] << endl;
}
00036
00037
00038 int main(int argc, char ** argv)
00039 {
00040
00041 if( argc < 2 ) {
00042 help(argv[0]);
00043 return 1;
00044 }
00045
00046
00047 string tupleFile;
00048 int readMode = 0;
00049 SprRWFactory::DataType writeMode = SprRWFactory::Root;
00050
00051
00052 int c;
00053 extern char* optarg;
00054
00055 while( (c = getopt(argc,argv,"ho:a:A")) != EOF ) {
00056 switch( c )
00057 {
00058 case 'h' :
00059 help(argv[0]);
00060 return 1;
00061 case 'o' :
00062 tupleFile = optarg;
00063 break;
00064 case 'a' :
00065 readMode = (optarg==0 ? 0 : atoi(optarg));
00066 break;
00067 case 'A' :
00068 writeMode = SprRWFactory::Ascii;
00069 break;
00070 }
00071 }
00072
00073
00074 string trFile = argv[argc-2];
00075 string netFile = argv[argc-1];
00076 if( trFile.empty() ) {
00077 cerr << "No training file is specified." << endl;
00078 return 1;
00079 }
00080 if( netFile.empty() ) {
00081 cerr << "No net file is specified." << endl;
00082 return 1;
00083 }
00084
00085
00086 SprRWFactory::DataType inputType
00087 = ( readMode==0 ? SprRWFactory::Root : SprRWFactory::Ascii );
00088 auto_ptr<SprAbsReader> reader(SprRWFactory::makeReader(inputType,readMode));
00089 auto_ptr<SprAbsFilter> filter(reader->read(trFile.c_str()));
00090 if( filter.get() == 0 ) {
00091 cerr << "Unable to read data from file " << trFile.c_str() << endl;
00092 return 2;
00093 }
00094 vector<string> vars;
00095 filter->vars(vars);
00096 cout << "Read data from file " << trFile.c_str()
00097 << " for variables";
00098 for( int i=0;i<vars.size();i++ )
00099 cout << " \"" << vars[i].c_str() << "\"";
00100 cout << endl;
00101 cout << "Total number of points read: " << filter->size() << endl;
00102 cout << "Points in class 0: " << filter->ptsInClass(0)
00103 << " 1: " << filter->ptsInClass(1) << endl;
00104
00105
00106 SprTrainedRBF net;
00107 if( !net.readNet(netFile.c_str()) ) {
00108 cerr << "Unable to read net file " << netFile.c_str() << endl;
00109 return 3;
00110 }
00111 else {
00112 cout << "Read net configuration file:" << endl;
00113 net.print(cout);
00114 }
00115
00116
00117 if( tupleFile.empty() ) return 0;
00118
00119
00120 auto_ptr<SprAbsWriter> tuple(SprRWFactory::makeWriter(writeMode,"training"));
00121 if( !tuple->init(tupleFile.c_str()) ) {
00122 cerr << "Unable to open output file " << tupleFile.c_str() << endl;
00123 return 4;
00124 }
00125
00126
00127 SprDataFeeder feeder(filter.get(),tuple.get());
00128 feeder.addClassifier(&net,"rbf");
00129 if( !feeder.feed(1000) ) {
00130 cerr << "Cannot feed data into file " << tupleFile.c_str() << endl;
00131 return 5;
00132 }
00133
00134
00135 return 0;
00136 }