23 sym_ = Symbol::Load(symbol_file);
32 std::map<std::string, NDArray> paramters;
33 NDArray::Load(param_file,
nullptr, ¶mters);
34 for (
const auto &
k : paramters) {
35 if (
k.first.substr(0, 4) ==
"aux:") {
36 auto name =
k.first.substr(4,
k.first.size() - 4);
39 if (
k.first.substr(0, 4) ==
"arg:") {
40 auto name =
k.first.substr(4,
k.first.size() - 4);
64 assert(input_names.size() == input_shapes.size());
67 for (
unsigned i=0;
i<input_names_.size(); ++
i){
68 const auto&
name = input_names_[
i];
73 const std::vector<float>&
Predictor::predict(
const std::vector<std::vector<mx_float> >& input_data) {
86 exec_->Forward(
false);
90 }
catch(
const dmlc::Error &
e){
91 throw cms::Exception(
"RuntimeError") << e.what() << MXGetLastError();
100 const auto arg_name_list =
sym_.ListArguments();
101 std::vector<std::vector<mx_uint> > in_shapes, aux_shapes, out_shapes;
102 std::map<std::string, std::vector<mx_uint> > arg_shapes;
104 for (
const auto &arg_name : arg_name_list) {
105 auto iter =
arg_map_.find(arg_name);
107 arg_shapes[arg_name] = iter->second.GetShape();
110 sym_.InferShape(arg_shapes, &in_shapes, &aux_shapes, &out_shapes);
113 std::vector<NDArray> arg_arrays;
114 for (
size_t i = 0;
i < in_shapes.size(); ++
i) {
115 const auto &shape = in_shapes[
i];
116 const auto &arg_name = arg_name_list[
i];
117 auto iter_arg =
arg_map_.find(arg_name);
119 arg_arrays.push_back(iter_arg->second);
121 arg_arrays.push_back(NDArray(shape,
context_,
false));
124 std::vector<NDArray> grad_arrays(arg_arrays.size());
125 std::vector<OpReqType> grad_reqs(arg_arrays.size(), kNullOp);
128 std::vector<NDArray> aux_arrays;
129 const auto aux_name_list =
sym_.ListAuxiliaryStates();
130 for (
size_t i = 0;
i < aux_shapes.size(); ++
i) {
131 const auto &shape = aux_shapes[
i];
132 const auto &aux_name = aux_name_list[
i];
133 auto iter_aux =
aux_map_.find(aux_name);
135 aux_arrays.push_back(iter_aux->second);
137 aux_arrays.push_back(NDArray(shape,
context_,
false));
142 exec_.reset(
new Executor(
sym_,
context_, arg_arrays, grad_arrays, grad_reqs, aux_arrays));
const std::map< std::string, NDArray > & aux_map() const
std::map< std::string, NDArray > aux_map_
static boost::mutex mutex
std::map< std::string, NDArray > arg_map_
void load_parameters(const std::string ¶m_file)
std::map< std::string, NDArray > aux_map_
std::vector< float > pred_
const std::map< std::string, NDArray > & arg_map() const
const Symbol & symbol() const
static const Context context_
std::map< std::string, NDArray > arg_map_
std::unique_ptr< Executor > exec_
const std::vector< float > & predict(const std::vector< std::vector< mx_float >> &input_data)
void set_input_shapes(const std::vector< std::string > &input_names, const std::vector< std::vector< mx_uint >> &input_shapes)
std::vector< std::string > input_names_