23 sym_ = Symbol::Load(symbol_file);
31 std::map<std::string, NDArray> paramters;
32 NDArray::Load(param_file,
nullptr, ¶mters);
33 for (
const auto&
k : paramters) {
34 if (
k.first.substr(0, 4) ==
"aux:") {
35 auto name =
k.first.substr(4,
k.first.size() - 4);
38 if (
k.first.substr(0, 4) ==
"arg:") {
39 auto name =
k.first.substr(4,
k.first.size() - 4);
51 : sym_(
block.symbol()), arg_map_(
block.arg_map()), aux_map_(
block.aux_map()) {}
54 : sym_(
block.symbol(output_node)), arg_map_(
block.arg_map()), aux_map_(
block.aux_map()) {}
59 const std::vector<std::vector<mx_uint> >& input_shapes) {
84 exec_->Forward(
false);
98 const auto arg_name_list =
sym_.ListArguments();
99 std::vector<std::vector<mx_uint> > in_shapes, aux_shapes, out_shapes;
100 std::map<std::string, std::vector<mx_uint> > arg_shapes;
102 for (
const auto& arg_name : arg_name_list) {
103 auto iter =
arg_map_.find(arg_name);
105 arg_shapes[arg_name] = iter->second.GetShape();
108 sym_.InferShape(arg_shapes, &in_shapes, &aux_shapes, &out_shapes);
111 std::vector<NDArray> arg_arrays;
112 for (
size_t i = 0;
i < in_shapes.size(); ++
i) {
113 const auto& shape = in_shapes[
i];
114 const auto& arg_name = arg_name_list[
i];
115 auto iter_arg =
arg_map_.find(arg_name);
117 arg_arrays.push_back(iter_arg->second);
119 arg_arrays.push_back(NDArray(shape,
context_,
false));
122 std::vector<NDArray> grad_arrays(arg_arrays.size());
123 std::vector<OpReqType> grad_reqs(arg_arrays.size(), kNullOp);
126 std::vector<NDArray> aux_arrays;
127 const auto aux_name_list =
sym_.ListAuxiliaryStates();
128 for (
size_t i = 0;
i < aux_shapes.size(); ++
i) {
129 const auto& shape = aux_shapes[
i];
130 const auto& aux_name = aux_name_list[
i];
131 auto iter_aux =
aux_map_.find(aux_name);
133 aux_arrays.push_back(iter_aux->second);
135 aux_arrays.push_back(NDArray(shape,
context_,
false));
140 exec_ = std::make_unique<Executor>(
sym_,
context_, arg_arrays, grad_arrays, grad_reqs, aux_arrays);