Skip to content

Commit

Permalink
Rename MXNet Predictor.
Browse files Browse the repository at this point in the history
  • Loading branch information
hqucms committed Sep 11, 2018
1 parent 3aae32e commit 2fcd0fb
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ class Block {

// Simple helper class to run prediction
// this cannot be shared between threads
class MXNetCppPredictor {
class Predictor {
public:
MXNetCppPredictor();
MXNetCppPredictor(const Block &block);
MXNetCppPredictor(const Block &block, const std::string &output_node);
virtual ~MXNetCppPredictor();
Predictor();
Predictor(const Block &block);
Predictor(const Block &block, const std::string &output_node);
virtual ~Predictor();

// set input array shapes
void set_input_shapes(const std::vector<std::string>& input_names, const std::vector<std::vector<mx_uint>>& input_shapes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
* Author: hqu
*/

#include "PhysicsTools/MXNet/interface/Predictor.h"

#include <cassert>
#include "FWCore/Utilities/interface/Exception.h"

#include "PhysicsTools/MXNet/interface/MXNetCppPredictor.h"

namespace mxnet {

Expand Down Expand Up @@ -42,33 +43,34 @@ void Block::load_parameters(const std::string& param_file) {
}
}

std::mutex MXNetCppPredictor::mutex_;
const Context MXNetCppPredictor::context_ = Context(DeviceType::kCPU, 0);
std::mutex Predictor::mutex_;
const Context Predictor::context_ = Context(DeviceType::kCPU, 0);

MXNetCppPredictor::MXNetCppPredictor() {
Predictor::Predictor() {
}

MXNetCppPredictor::MXNetCppPredictor(const Block& block) : sym_(block.symbol()), arg_map_(block.arg_map()), aux_map_(block.aux_map()) {
Predictor::Predictor(const Block& block)
: sym_(block.symbol()), arg_map_(block.arg_map()), aux_map_(block.aux_map()) {
}

MXNetCppPredictor::MXNetCppPredictor(const Block &block, const std::string &output_node) : sym_(block.symbol(output_node)), arg_map_(block.arg_map()), aux_map_(block.aux_map()) {
Predictor::Predictor(const Block &block, const std::string &output_node)
: sym_(block.symbol(output_node)), arg_map_(block.arg_map()), aux_map_(block.aux_map()) {
}

MXNetCppPredictor::~MXNetCppPredictor() {
Predictor::~Predictor() {
}

void MXNetCppPredictor::set_input_shapes(const std::vector<std::string>& input_names, const std::vector<std::vector<mx_uint> >& input_shapes) {
void Predictor::set_input_shapes(const std::vector<std::string>& input_names, const std::vector<std::vector<mx_uint> >& input_shapes) {
assert(input_names.size() == input_shapes.size());
input_names_ = input_names;
// init the input NDArrays and add them to the arg_map
for (unsigned i=0; i<input_names_.size(); ++i){
const auto& name = input_names_[i];
NDArray nd(input_shapes[i], context_, false);
arg_map_[name] = nd;
arg_map_.emplace(name, NDArray(input_shapes[i], context_, false));
}
}

const std::vector<float>& MXNetCppPredictor::predict(const std::vector<std::vector<mx_float> >& input_data) {
const std::vector<float>& Predictor::predict(const std::vector<std::vector<mx_float> >& input_data) {
assert(input_names_.size() == input_data.size());

try {
Expand All @@ -90,7 +92,7 @@ const std::vector<float>& MXNetCppPredictor::predict(const std::vector<std::vect
}
}

void MXNetCppPredictor::bind_executor() {
void Predictor::bind_executor() {
// acquire lock
std::lock_guard<std::mutex> lock(mutex_);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include <cppunit/extensions/HelperMacros.h>

#include "PhysicsTools/MXNet/interface/Predictor.h"
#include "FWCore/ParameterSet/interface/FileInPath.h"
#include "PhysicsTools/MXNet/interface/MXNetCppPredictor.h"

using namespace mxnet::cpp;

Expand Down Expand Up @@ -29,7 +29,7 @@ void testMXNetCppPredictor::checkAll()
CPPUNIT_ASSERT(block!=nullptr);

// create predictor
MXNetCppPredictor predictor(*block);
Predictor predictor(*block);

// set input shape
std::vector<std::string> input_names {"data"};
Expand Down
7 changes: 3 additions & 4 deletions RecoBTag/DeepBoostedJet/plugins/DeepBoostedJetTagsProducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@

#include "DataFormats/BTauReco/interface/DeepBoostedJetTagInfo.h"

#include "PhysicsTools/MXNet/interface/MXNetCppPredictor.h"

#include <iostream>
#include <fstream>
#include "PhysicsTools/MXNet/interface/Predictor.h"

// Hold the mxnet model block (symbol + params) in the edm::GlobalCache.
struct MXBlockCache {
Expand Down Expand Up @@ -82,7 +81,7 @@ class DeepBoostedJetTagsProducer : public edm::stream::EDProducer<edm::GlobalCac
std::unordered_map<std::string, PreprocessParams> prep_info_map_; // preprocessing info for each input group

std::vector<std::vector<float>> data_;
std::unique_ptr<mxnet::cpp::MXNetCppPredictor> predictor_;
std::unique_ptr<mxnet::cpp::Predictor> predictor_;

bool debug_ = false;
};
Expand Down Expand Up @@ -128,7 +127,7 @@ DeepBoostedJetTagsProducer::DeepBoostedJetTagsProducer(const edm::ParameterSet&
}

// init MXNetPredictor
predictor_.reset(new mxnet::cpp::MXNetCppPredictor(*cache->block));
predictor_.reset(new mxnet::cpp::Predictor(*cache->block));
predictor_->set_input_shapes(input_names_, input_shapes_);

// get output names from flav_table
Expand Down

0 comments on commit 2fcd0fb

Please sign in to comment.