diff --git a/PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h b/PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h
index f49ddcbd05b6a..f311961fd46cf 100644
--- a/PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h
+++ b/PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h
@@ -22,6 +22,11 @@ namespace cms::Ort {
 
   typedef std::vector<std::vector<float>> FloatArrays;
 
+  enum class Backend {
+    cpu,
+    cuda,
+  };
+
   class ONNXRuntime {
   public:
     ONNXRuntime(const std::string& model_path, const ::Ort::SessionOptions* session_options = nullptr);
@@ -29,6 +34,8 @@ namespace cms::Ort {
     ONNXRuntime& operator=(const ONNXRuntime&) = delete;
     ~ONNXRuntime();
 
+    static ::Ort::SessionOptions defaultSessionOptions(Backend backend = Backend::cpu);
+
     // Run inference and get outputs
     // input_names: list of the names of the input nodes.
     // input_values: list of input arrays for each input node. The order of `input_values` must match `input_names`.
diff --git a/PhysicsTools/ONNXRuntime/src/ONNXRuntime.cc b/PhysicsTools/ONNXRuntime/src/ONNXRuntime.cc
index 1845d0cc64d82..130e4544585b9 100644
--- a/PhysicsTools/ONNXRuntime/src/ONNXRuntime.cc
+++ b/PhysicsTools/ONNXRuntime/src/ONNXRuntime.cc
@@ -27,9 +27,7 @@ namespace cms::Ort {
     if (session_options) {
       session_ = std::make_unique<Session>(env_, model_path.c_str(), *session_options);
     } else {
-      SessionOptions sess_opts;
-      sess_opts.SetIntraOpNumThreads(1);
-      session_ = std::make_unique<Session>(env_, model_path.c_str(), sess_opts);
+      session_ = std::make_unique<Session>(env_, model_path.c_str(), defaultSessionOptions());
     }
 
     AllocatorWithDefaultOptions allocator;
@@ -78,6 +76,17 @@ namespace cms::Ort {
 
   ONNXRuntime::~ONNXRuntime() {}
 
+  SessionOptions ONNXRuntime::defaultSessionOptions(Backend backend) {
+    SessionOptions sess_opts;
+    sess_opts.SetIntraOpNumThreads(1);
+    if (backend == Backend::cuda) {
+      // https://www.onnxruntime.ai/docs/reference/execution-providers/CUDA-ExecutionProvider.html
+      OrtCUDAProviderOptions options;
+      sess_opts.AppendExecutionProvider_CUDA(options);
+    }
+    return sess_opts;
+  }
+
   FloatArrays ONNXRuntime::run(const std::vector<std::string>& input_names,
                                FloatArrays& input_values,
                                const std::vector<std::vector<int64_t>>& input_shapes,
@@ -104,6 +113,10 @@ namespace cms::Ort {
       } else {
         input_dims = input_shapes[input_pos];
         // rely on the given input_shapes to set the batch size
+        if (input_dims[0] != batch_size) {
+          throw cms::Exception("RuntimeError") << "The first element of `input_shapes` (" << input_dims[0]
+                                               << ") does not match the given `batch_size` (" << batch_size << ")";
+        }
       }
       auto expected_len = std::accumulate(input_dims.begin(), input_dims.end(), 1, std::multiplies<int64_t>());
       if (expected_len != (int64_t)value->size()) {
diff --git a/PhysicsTools/ONNXRuntime/test/BuildFile.xml b/PhysicsTools/ONNXRuntime/test/BuildFile.xml
index b8af87ffa32de..cb02c30d2f9c9 100644
--- a/PhysicsTools/ONNXRuntime/test/BuildFile.xml
+++ b/PhysicsTools/ONNXRuntime/test/BuildFile.xml
@@ -4,5 +4,6 @@
+  <use name="HeterogeneousCore/CUDAUtilities"/>
diff --git a/PhysicsTools/ONNXRuntime/test/testONNXRuntime.cc b/PhysicsTools/ONNXRuntime/test/testONNXRuntime.cc
index 5de1da9b9aa44..29ed226f6acc0 100644
--- a/PhysicsTools/ONNXRuntime/test/testONNXRuntime.cc
+++ b/PhysicsTools/ONNXRuntime/test/testONNXRuntime.cc
@@ -2,26 +2,32 @@
 
 #include "PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h"
 #include "FWCore/ParameterSet/interface/FileInPath.h"
+#include "HeterogeneousCore/CUDAUtilities/interface/requireDevices.h"
 
-#include
 #include
 
 using namespace cms::Ort;
 
 class testONNXRuntime : public CppUnit::TestFixture {
   CPPUNIT_TEST_SUITE(testONNXRuntime);
-  CPPUNIT_TEST(checkAll);
+  CPPUNIT_TEST(checkCPU);
+  CPPUNIT_TEST(checkGPU);
   CPPUNIT_TEST_SUITE_END();
 
+private:
+  void test(Backend backend);
+
 public:
-  void checkAll();
+  void checkCPU();
+  void checkGPU();
 };
 
 CPPUNIT_TEST_SUITE_REGISTRATION(testONNXRuntime);
 
-void testONNXRuntime::checkAll() {
+void testONNXRuntime::test(Backend backend) {
   std::string model_path = edm::FileInPath("PhysicsTools/ONNXRuntime/test/data/model.onnx").fullPath();
-  ONNXRuntime rt(model_path);
+  auto session_options = ONNXRuntime::defaultSessionOptions(backend);
+  ONNXRuntime rt(model_path, &session_options);
   for (const unsigned batch_size : {1, 2, 4}) {
     FloatArrays input_values{
         std::vector<float>(batch_size * 2, 1),
@@ -35,3 +41,11 @@ void testONNXRuntime::checkAll() {
     }
   }
 }
+
+void testONNXRuntime::checkCPU() { test(Backend::cpu); }
+
+void testONNXRuntime::checkGPU() {
+  if (cms::cudatest::testDevices()) {
+    test(Backend::cuda);
+  }
+}
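
For context, a minimal sketch of how client code might select the CUDA backend through the new `defaultSessionOptions` helper. The model path, the input/output node names, and the trailing `run()` arguments (`output_names`, `batch_size`) are illustrative assumptions not shown in this diff; only `Backend`, `defaultSessionOptions`, the constructor, and the leading `run()` parameters come from the changes above.

```cpp
#include "PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h"

#include <string>
#include <vector>

using namespace cms::Ort;

int main() {
  // Hypothetical model with a single input node of shape {batch_size, 2};
  // the path and node names below are placeholders, not part of this PR.
  const std::string model_path = "path/to/model.onnx";

  // Backend::cpu reproduces the previous default (single intra-op thread);
  // Backend::cuda additionally appends the CUDA execution provider.
  auto session_options = ONNXRuntime::defaultSessionOptions(Backend::cuda);
  ONNXRuntime rt(model_path, &session_options);

  // One input array per input node; with explicit input_shapes the first
  // (batch) dimension must now match batch_size, otherwise run() throws.
  const int64_t batch_size = 2;
  FloatArrays input_values{std::vector<float>(batch_size * 2, 1.f)};
  FloatArrays outputs = rt.run({"input"}, input_values, {{batch_size, 2}}, {"output"}, batch_size);

  return outputs.empty() ? 1 : 0;
}
```

As in the unit test above, the CUDA path should only be exercised when `cms::cudatest::testDevices()` reports a usable GPU, so machines without one still run the CPU check.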