Skip to content

Commit

Permalink
Refactor MXNetCppPredictor.
Browse files Browse the repository at this point in the history
Re-bind executor every time for thread-safety.
  • Loading branch information
hqucms committed Sep 11, 2018
1 parent 830e1c3 commit 0966b24
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 17 deletions.
14 changes: 12 additions & 2 deletions PhysicsTools/MXNet/interface/MXNetCppPredictor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class Block {
virtual ~Block();

const Symbol& symbol() const { return sym_; }
Symbol symbol(const std::string &output_node) const { return sym_.GetInternals()[output_node]; }
const std::map<std::string, NDArray>& arg_map() const { return arg_map_; }
const std::map<std::string, NDArray>& aux_map() const { return aux_map_; }

Expand All @@ -50,16 +51,18 @@ class MXNetCppPredictor {
public:
MXNetCppPredictor();
MXNetCppPredictor(const Block &block);
MXNetCppPredictor(const Block &block, const std::string &output_node);
virtual ~MXNetCppPredictor();

void set_input_shapes(const std::vector<std::string>& input_names, const std::vector<std::vector<mx_uint>>& input_shapes);
void set_output_node_name(const std::string& output_node_name);
const std::vector<float>& predict(const std::vector<std::vector<mx_float>>& input_data);

private:
void bind_executor();
static std::mutex mutex_;

void infer_shapes();
void bind_executor();

// context
static const Context context_;
// executor
Expand All @@ -74,6 +77,13 @@ class MXNetCppPredictor {
std::vector<float> pred_;
// names of the input nodes
std::vector<std::string> input_names_;

// internal states
std::vector<NDArray> arg_arrays;
std::vector<NDArray> grad_arrays;
std::vector<OpReqType> grad_reqs;
std::vector<NDArray> aux_arrays;

};

} /* namespace cpp */
Expand Down
30 changes: 15 additions & 15 deletions PhysicsTools/MXNet/src/MXNetCppPredictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ MXNetCppPredictor::MXNetCppPredictor() {
MXNetCppPredictor::MXNetCppPredictor(const Block& block) : sym_(block.symbol()), arg_map_(block.arg_map()), aux_map_(block.aux_map()) {
}

MXNetCppPredictor::MXNetCppPredictor(const Block &block, const std::string &output_node) : sym_(block.symbol(output_node)), arg_map_(block.arg_map()), aux_map_(block.aux_map()) {
}

MXNetCppPredictor::~MXNetCppPredictor() {
}

Expand All @@ -63,21 +66,16 @@ void MXNetCppPredictor::set_input_shapes(const std::vector<std::string>& input_n
NDArray nd(input_shapes.at(i), context_, false);
arg_map_[name] = nd;
}
}

void MXNetCppPredictor::set_output_node_name(const std::string& output_node_name) {
if (!output_node_name.empty()){
sym_ = sym_.GetInternals()[output_node_name];
}
// infer parameter shapes from input shapes
infer_shapes();
}

const std::vector<float>& MXNetCppPredictor::predict(const std::vector<std::vector<mx_float> >& input_data) {
assert(input_names_.size() == input_data.size());

try {
// create the executor (if not done yet)
if (!exec_) { bind_executor(); }
assert(exec_);
// bind executor
bind_executor();
// set the inputs
for (unsigned i=0; i<input_names_.size(); ++i){
const auto& name = input_names_.at(i);
Expand All @@ -93,7 +91,7 @@ const std::vector<float>& MXNetCppPredictor::predict(const std::vector<std::vect
}
}

void MXNetCppPredictor::bind_executor() {
void MXNetCppPredictor::infer_shapes() {
// acquire lock
std::lock_guard<std::mutex> lock(mutex_);

Expand All @@ -111,7 +109,7 @@ void MXNetCppPredictor::bind_executor() {
sym_.InferShape(arg_shapes, &in_shapes, &aux_shapes, &out_shapes);

// init argument arrays
std::vector<NDArray> arg_arrays;
arg_arrays.clear();
for (size_t i = 0; i < in_shapes.size(); ++i) {
const auto &shape = in_shapes[i];
const auto &arg_name = arg_name_list[i];
Expand All @@ -122,11 +120,11 @@ void MXNetCppPredictor::bind_executor() {
arg_arrays.push_back(NDArray(shape, context_, false));
}
}
std::vector<NDArray> grad_arrays(arg_arrays.size());
std::vector<OpReqType> grad_reqs(arg_arrays.size(), kNullOp);
grad_arrays = std::vector<NDArray>(arg_arrays.size());
grad_reqs = std::vector<OpReqType>(arg_arrays.size(), kNullOp);

// init auxiliary array
std::vector<NDArray> aux_arrays;
aux_arrays.clear();
const auto aux_name_list = sym_.ListAuxiliaryStates();
for (size_t i = 0; i < aux_shapes.size(); ++i) {
const auto &shape = aux_shapes[i];
Expand All @@ -139,9 +137,11 @@ void MXNetCppPredictor::bind_executor() {
}
}

}

void MXNetCppPredictor::bind_executor() {
// bind executor
exec_.reset(new Executor(sym_, context_, arg_arrays, grad_arrays, grad_reqs, aux_arrays));

}

}
Expand Down

0 comments on commit 0966b24

Please sign in to comment.