From 2fcd0fb67cf4e51d86daa9fae6ba54a083812ce3 Mon Sep 17 00:00:00 2001 From: Huilin Qu Date: Fri, 31 Aug 2018 18:21:13 +0200 Subject: [PATCH] Rename MXNet Predictor. --- .../{MXNetCppPredictor.h => Predictor.h} | 10 +++---- .../{MXNetCppPredictor.cc => Predictor.cc} | 26 ++++++++++--------- ...tMXNetCppPredictor.cc => testPredictor.cc} | 4 +-- .../plugins/DeepBoostedJetTagsProducer.cc | 7 +++-- 4 files changed, 24 insertions(+), 23 deletions(-) rename PhysicsTools/MXNet/interface/{MXNetCppPredictor.h => Predictor.h} (91%) rename PhysicsTools/MXNet/src/{MXNetCppPredictor.cc => Predictor.cc} (78%) rename PhysicsTools/MXNet/test/{testMXNetCppPredictor.cc => testPredictor.cc} (93%) diff --git a/PhysicsTools/MXNet/interface/MXNetCppPredictor.h b/PhysicsTools/MXNet/interface/Predictor.h similarity index 91% rename from PhysicsTools/MXNet/interface/MXNetCppPredictor.h rename to PhysicsTools/MXNet/interface/Predictor.h index ed97731ac1574..693cbf4926d0f 100644 --- a/PhysicsTools/MXNet/interface/MXNetCppPredictor.h +++ b/PhysicsTools/MXNet/interface/Predictor.h @@ -47,12 +47,12 @@ class Block { // Simple helper class to run prediction // this cannot be shared between threads -class MXNetCppPredictor { +class Predictor { public: - MXNetCppPredictor(); - MXNetCppPredictor(const Block &block); - MXNetCppPredictor(const Block &block, const std::string &output_node); - virtual ~MXNetCppPredictor(); + Predictor(); + Predictor(const Block &block); + Predictor(const Block &block, const std::string &output_node); + virtual ~Predictor(); // set input array shapes void set_input_shapes(const std::vector& input_names, const std::vector>& input_shapes); diff --git a/PhysicsTools/MXNet/src/MXNetCppPredictor.cc b/PhysicsTools/MXNet/src/Predictor.cc similarity index 78% rename from PhysicsTools/MXNet/src/MXNetCppPredictor.cc rename to PhysicsTools/MXNet/src/Predictor.cc index 2ed1bb6b72be6..f23e0e630dd7c 100644 --- a/PhysicsTools/MXNet/src/MXNetCppPredictor.cc +++ b/PhysicsTools/MXNet/src/Predictor.cc @@ -5,10 +5,11 @@ * Author: hqu */ +#include "PhysicsTools/MXNet/interface/Predictor.h" + #include #include "FWCore/Utilities/interface/Exception.h" -#include "PhysicsTools/MXNet/interface/MXNetCppPredictor.h" namespace mxnet { @@ -42,33 +43,34 @@ void Block::load_parameters(const std::string& param_file) { } } -std::mutex MXNetCppPredictor::mutex_; -const Context MXNetCppPredictor::context_ = Context(DeviceType::kCPU, 0); +std::mutex Predictor::mutex_; +const Context Predictor::context_ = Context(DeviceType::kCPU, 0); -MXNetCppPredictor::MXNetCppPredictor() { +Predictor::Predictor() { } -MXNetCppPredictor::MXNetCppPredictor(const Block& block) : sym_(block.symbol()), arg_map_(block.arg_map()), aux_map_(block.aux_map()) { +Predictor::Predictor(const Block& block) +: sym_(block.symbol()), arg_map_(block.arg_map()), aux_map_(block.aux_map()) { } -MXNetCppPredictor::MXNetCppPredictor(const Block &block, const std::string &output_node) : sym_(block.symbol(output_node)), arg_map_(block.arg_map()), aux_map_(block.aux_map()) { +Predictor::Predictor(const Block &block, const std::string &output_node) +: sym_(block.symbol(output_node)), arg_map_(block.arg_map()), aux_map_(block.aux_map()) { } -MXNetCppPredictor::~MXNetCppPredictor() { +Predictor::~Predictor() { } -void MXNetCppPredictor::set_input_shapes(const std::vector& input_names, const std::vector >& input_shapes) { +void Predictor::set_input_shapes(const std::vector& input_names, const std::vector >& input_shapes) { assert(input_names.size() == input_shapes.size()); input_names_ = input_names; // init the input NDArrays and add them to the arg_map for (unsigned i=0; i& MXNetCppPredictor::predict(const std::vector >& input_data) { +const std::vector& Predictor::predict(const std::vector >& input_data) { assert(input_names_.size() == input_data.size()); try { @@ -90,7 +92,7 @@ const std::vector& MXNetCppPredictor::predict(const std::vector lock(mutex_); diff --git a/PhysicsTools/MXNet/test/testMXNetCppPredictor.cc b/PhysicsTools/MXNet/test/testPredictor.cc similarity index 93% rename from PhysicsTools/MXNet/test/testMXNetCppPredictor.cc rename to PhysicsTools/MXNet/test/testPredictor.cc index 8d691f3f9c2b6..e50922c9eb9f2 100644 --- a/PhysicsTools/MXNet/test/testMXNetCppPredictor.cc +++ b/PhysicsTools/MXNet/test/testPredictor.cc @@ -1,7 +1,7 @@ #include +#include "PhysicsTools/MXNet/interface/Predictor.h" #include "FWCore/ParameterSet/interface/FileInPath.h" -#include "PhysicsTools/MXNet/interface/MXNetCppPredictor.h" using namespace mxnet::cpp; @@ -29,7 +29,7 @@ void testMXNetCppPredictor::checkAll() CPPUNIT_ASSERT(block!=nullptr); // create predictor - MXNetCppPredictor predictor(*block); + Predictor predictor(*block); // set input shape std::vector input_names {"data"}; diff --git a/RecoBTag/DeepBoostedJet/plugins/DeepBoostedJetTagsProducer.cc b/RecoBTag/DeepBoostedJet/plugins/DeepBoostedJetTagsProducer.cc index c6d37d5adeab6..f7afdaa5dfe6d 100644 --- a/RecoBTag/DeepBoostedJet/plugins/DeepBoostedJetTagsProducer.cc +++ b/RecoBTag/DeepBoostedJet/plugins/DeepBoostedJetTagsProducer.cc @@ -13,10 +13,9 @@ #include "DataFormats/BTauReco/interface/DeepBoostedJetTagInfo.h" -#include "PhysicsTools/MXNet/interface/MXNetCppPredictor.h" - #include #include +#include "PhysicsTools/MXNet/interface/Predictor.h" // Hold the mxnet model block (symbol + params) in the edm::GlobalCache. struct MXBlockCache { @@ -82,7 +81,7 @@ class DeepBoostedJetTagsProducer : public edm::stream::EDProducer prep_info_map_; // preprocessing info for each input group std::vector> data_; - std::unique_ptr predictor_; + std::unique_ptr predictor_; bool debug_ = false; }; @@ -128,7 +127,7 @@ DeepBoostedJetTagsProducer::DeepBoostedJetTagsProducer(const edm::ParameterSet& } // init MXNetPredictor - predictor_.reset(new mxnet::cpp::MXNetCppPredictor(*cache->block)); + predictor_.reset(new mxnet::cpp::Predictor(*cache->block)); predictor_->set_input_shapes(input_names_, input_shapes_); // get output names from flav_table