21 sym_ = Symbol::Load(symbol_file);
29 std::map<std::string, NDArray> paramters;
30 NDArray::Load(param_file,
nullptr, ¶mters);
31 for (
const auto&
k : paramters) {
32 if (
k.first.substr(0, 4) ==
"aux:") {
33 auto name =
k.first.substr(4,
k.first.size() - 4);
36 if (
k.first.substr(0, 4) ==
"arg:") {
37 auto name =
k.first.substr(4,
k.first.size() - 4);
49 : sym_(
block.symbol()), arg_map_(
block.arg_map()), aux_map_(
block.aux_map()) {}
52 : sym_(
block.symbol(output_node)), arg_map_(
block.arg_map()), aux_map_(
block.aux_map()) {}
57 const std::vector<std::vector<mx_uint> >& input_shapes) {
67 const std::vector<float>&
Predictor::predict(
const std::vector<std::vector<mx_float> >& input_data) {
82 exec_->Forward(
false);
96 const auto arg_name_list =
sym_.ListArguments();
97 std::vector<std::vector<mx_uint> > in_shapes, aux_shapes, out_shapes;
98 std::map<std::string, std::vector<mx_uint> > arg_shapes;
100 for (
const auto& arg_name : arg_name_list) {
101 auto iter =
arg_map_.find(arg_name);
103 arg_shapes[arg_name] = iter->second.GetShape();
106 sym_.InferShape(arg_shapes, &in_shapes, &aux_shapes, &out_shapes);
109 std::vector<NDArray> arg_arrays;
110 for (
size_t i = 0;
i < in_shapes.size(); ++
i) {
111 const auto& shape = in_shapes[
i];
112 const auto& arg_name = arg_name_list[
i];
113 auto iter_arg =
arg_map_.find(arg_name);
115 arg_arrays.push_back(iter_arg->second);
117 arg_arrays.push_back(NDArray(shape,
context_,
false));
120 std::vector<NDArray> grad_arrays(arg_arrays.size());
121 std::vector<OpReqType> grad_reqs(arg_arrays.size(), kNullOp);
124 std::vector<NDArray> aux_arrays;
125 const auto aux_name_list =
sym_.ListAuxiliaryStates();
126 for (
size_t i = 0;
i < aux_shapes.size(); ++
i) {
127 const auto& shape = aux_shapes[
i];
128 const auto& aux_name = aux_name_list[
i];
129 auto iter_aux =
aux_map_.find(aux_name);
131 aux_arrays.push_back(iter_aux->second);
133 aux_arrays.push_back(NDArray(shape,
context_,
false));
138 exec_.reset(
new Executor(
sym_,
context_, arg_arrays, grad_arrays, grad_reqs, aux_arrays));