Skip to content

Commit

Permalink
Add a test unit for MXNetPredictor.
Browse files Browse the repository at this point in the history
  • Loading branch information
hqucms committed Sep 11, 2018
1 parent d0e4e1a commit 794b953
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 0 deletions.
8 changes: 8 additions & 0 deletions PhysicsTools/MXNet/test/BuildFile.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
<bin name="testMXNetPredictor" file="testRunner.cpp, testMXNetPredictor.cc">
<use name="boost_filesystem" />
<use name="cppunit" />

<use name="PhysicsTools/MXNet" />
<use name="FWCore/ParameterSet" />
<use name="FWCore/Utilities" />
</bin>
Binary file added PhysicsTools/MXNet/test/data/testmxnet-0000.params
Binary file not shown.
79 changes: 79 additions & 0 deletions PhysicsTools/MXNet/test/data/testmxnet-symbol.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
{
"nodes": [
{
"op": "null",
"name": "data",
"inputs": []
},
{
"op": "null",
"name": "dense0_weight",
"attrs": {
"__dtype__": "0",
"__lr_mult__": "1.0",
"__shape__": "(7, 0)",
"__wd_mult__": "1.0"
},
"inputs": []
},
{
"op": "null",
"name": "dense0_bias",
"attrs": {
"__dtype__": "0",
"__init__": "zeros",
"__lr_mult__": "1.0",
"__shape__": "(7,)",
"__wd_mult__": "1.0"
},
"inputs": []
},
{
"op": "FullyConnected",
"name": "dense0_fwd",
"attrs": {
"flatten": "True",
"no_bias": "False",
"num_hidden": "7"
},
"inputs": [[0, 0, 0], [1, 0, 0], [2, 0, 0]]
},
{
"op": "null",
"name": "dense1_weight",
"attrs": {
"__dtype__": "0",
"__lr_mult__": "1.0",
"__shape__": "(3, 0)",
"__wd_mult__": "1.0"
},
"inputs": []
},
{
"op": "null",
"name": "dense1_bias",
"attrs": {
"__dtype__": "0",
"__init__": "zeros",
"__lr_mult__": "1.0",
"__shape__": "(3,)",
"__wd_mult__": "1.0"
},
"inputs": []
},
{
"op": "FullyConnected",
"name": "dense1_fwd",
"attrs": {
"flatten": "True",
"no_bias": "False",
"num_hidden": "3"
},
"inputs": [[3, 0, 0], [4, 0, 0], [5, 0, 0]]
}
],
"arg_nodes": [0, 1, 2, 4, 5],
"node_row_ptr": [0, 1, 2, 3, 4, 5, 6, 7],
"heads": [[6, 0, 0]],
"attrs": {"mxnet_version": ["int", 10200]}
}
58 changes: 58 additions & 0 deletions PhysicsTools/MXNet/test/testMXNetPredictor.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#include <cppunit/extensions/HelperMacros.h>

#include "FWCore/ParameterSet/interface/FileInPath.h"
#include "PhysicsTools/MXNet/interface/MXNetPredictor.h"

class testMXNetPredictor : public CppUnit::TestFixture
{
CPPUNIT_TEST_SUITE(testMXNetPredictor);
CPPUNIT_TEST(checkAll);
CPPUNIT_TEST_SUITE_END();

public:
void checkAll();
};

CPPUNIT_TEST_SUITE_REGISTRATION(testMXNetPredictor);

void testMXNetPredictor::checkAll()
{

// load model and params into BufferFile
std::string model_path = edm::FileInPath("PhysicsTools/MXNet/test/data/testmxnet-symbol.json").fullPath();
std::string param_path = edm::FileInPath("PhysicsTools/MXNet/test/data/testmxnet-0000.params").fullPath();

mxnet::BufferFile *model_file = new mxnet::BufferFile(model_path);
CPPUNIT_ASSERT(model_file != nullptr);
CPPUNIT_ASSERT(model_file->GetLength() > 0);

mxnet::BufferFile *param_file = new mxnet::BufferFile(param_path);
CPPUNIT_ASSERT(param_file != nullptr);
CPPUNIT_ASSERT(param_file->GetLength() > 0);

// create predictor
mxnet::MXNetPredictor predictor;

// set input shape
std::vector<std::string> input_names {"data"};
std::vector<std::vector<unsigned>> input_shapes {{1, 3}};
CPPUNIT_ASSERT_NO_THROW( predictor.set_input_shapes(input_names, input_shapes) );

// load model from BufferFile
CPPUNIT_ASSERT_NO_THROW( predictor.load_model(model_file, param_file) );

// run predictor
std::vector<std::vector<float>> data {{1, 2, 3,}};
std::vector<float> outputs;
CPPUNIT_ASSERT_NO_THROW( outputs = predictor.predict(data) );

// check outputs
CPPUNIT_ASSERT(outputs.size() == 3);
CPPUNIT_ASSERT(outputs.at(0) == 42);
CPPUNIT_ASSERT(outputs.at(1) == 42);
CPPUNIT_ASSERT(outputs.at(2) == 42);

delete model_file;
delete param_file;

}
1 change: 1 addition & 0 deletions PhysicsTools/MXNet/test/testRunner.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#include <Utilities/Testing/interface/CppUnit_testdriver.icpp>

0 comments on commit 794b953

Please sign in to comment.