Skip to content

Commit

Permalink
Add working Onnxruntime Producer (#254)
Browse files Browse the repository at this point in the history
* restructure cmake files into multiple files

* add onnxruntime

* format cmake files

* add onnxruntime to templates

* add generic onnxruntime producer

* add ml namespace to docs
  • Loading branch information
harrypuuter committed May 8, 2024
1 parent e870781 commit 78bd176
Show file tree
Hide file tree
Showing 9 changed files with 423 additions and 3 deletions.
5 changes: 5 additions & 0 deletions code_generation/analysis_template.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "include/jets.hxx"
#include "include/lorentzvectors.hxx"
#include "include/met.hxx"
#include "include/ml.hxx"
#include "include/utility/OnnxSessionManager.hxx"
#include "include/metfilter.hxx"
#include "include/pairselection.hxx"
#include "include/tripleselection.hxx"
Expand Down Expand Up @@ -104,6 +106,9 @@ int main(int argc, char *argv[]) {
// file logging
Logger::enableFileLogging("logs/main.txt");

// start an onnx session manager
OnnxSessionManager onnxSessionManager;

// {MULTITHREADING}

// initialize df
Expand Down
5 changes: 5 additions & 0 deletions code_generation/analysis_template_friends.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "include/jets.hxx"
#include "include/lorentzvectors.hxx"
#include "include/met.hxx"
#include "include/ml.hxx"
#include "include/utility/OnnxSessionManager.hxx"
#include "include/metfilter.hxx"
#include "include/pairselection.hxx"
#include "include/physicsobjects.hxx"
Expand Down Expand Up @@ -131,6 +133,9 @@ int main(int argc, char *argv[]) {
// file logging
Logger::enableFileLogging("logs/main.txt");

// start an onnx session manager
OnnxSessionManager onnxSessionManager;

// {MULTITHREADING}

// build a tchain from input file with all friends
Expand Down
6 changes: 4 additions & 2 deletions code_generation/code_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ def write(self):
log.debug("folder: {}, file_name: {}".format(self.folder, self.file_name))
# write the header file if it does not exist or is different
with open(self.headerfile + ".new", "w") as f:
f.write(f"ROOT::RDF::RNode {self.name}(ROOT::RDF::RNode df);")
f.write(
f"ROOT::RDF::RNode {self.name}(ROOT::RDF::RNode df, OnnxSessionManager &onnxSessionManager);"
)
if os.path.isfile(self.headerfile):
if filecmp.cmp(self.headerfile + ".new", self.headerfile):
log.debug("--> Identical header file, skipping")
Expand Down Expand Up @@ -144,7 +146,7 @@ def call(self, inputscope: str, outputscope: str) -> str:
Returns:
str: the call to the code subset
"""
call = f" auto {outputscope} = {self.name}({inputscope}); \n"
call = f" auto {outputscope} = {self.name}({inputscope}, onnxSessionManager); \n"
return call

def include(self) -> str:
Expand Down
4 changes: 3 additions & 1 deletion code_generation/subset_template.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "include/jets.hxx"
#include "include/lorentzvectors.hxx"
#include "include/met.hxx"
#include "include/ml.hxx"
#include "include/utility/OnnxSessionManager.hxx"
#include "include/metfilter.hxx"
#include "include/pairselection.hxx"
#include "include/tripleselection.hxx"
Expand All @@ -25,7 +27,7 @@
#include <TTree.h>
#include <regex>
#include <string>
ROOT::RDF::RNode {subsetname} (ROOT::RDF::RNode df0) {
ROOT::RDF::RNode {subsetname} (ROOT::RDF::RNode df0, OnnxSessionManager &onnxSessionManager) {

// { commands }
}
6 changes: 6 additions & 0 deletions docs/sphinx_source/c_namespaces/ml.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Namespace: Ml
=============
.. doxygennamespace:: ml
   :members:
   :undoc-members:
   :private-members:
98 changes: 98 additions & 0 deletions include/ml.hxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#ifndef GUARD_ML_H
#define GUARD_ML_H

#include "../include/utility/OnnxSessionManager.hxx"
#include "TMVA/RModel.hxx"
#include "TMVA/RModelParser_ONNX.hxx"
#include "utility/utility.hxx"
#include <cstddef>

namespace ml {

/// Standardise one input column as (value - mean) / std for NN evaluation.
/// The implementation lives in src/ml.cxx.
///
/// \param df the dataframe to add the quantity to
/// \param inputname name of the column to transform; also used as the key
/// into the parameter files
/// \param outputname name of the output column
/// \param paramfile path to the json parameter file; the path has to contain
/// the placeholder "EVTID", which is replaced by "odd" and "even"
/// \param var_type data type of the input column: a string starting with "i"
/// selects int, one starting with "f" selects float
///
/// \returns a dataframe with the new column
ROOT::RDF::RNode StandardTransformer(ROOT::RDF::RNode df,
const std::string &inputname,
const std::string &outputname,
const std::string &paramfile,
const std::string &var_type);

/// Generic function to evaluate an ONNX model using the ONNX Runtime.
///
/// This is a function template, so its definition has to be visible at every
/// point of instantiation and therefore lives in this header. Moving the body
/// into a source file (without explicit instantiations) leads to undefined
/// references at link time.
///
/// This generic implementation currently supports only NNs with one input
/// tensor and one output tensor.
///
/// \tparam nParameters number of input columns handed to the model; has to be
/// equal to input_vec.size()
/// \param df the dataframe to add the quantity to
/// \param onnxSessionManager The OnnxSessionManager object to handle the
/// runtime session. By default this is called onnxSessionManager and created in
/// the main function
/// \param outputname Name of the output column
/// \param model_file_path Path to the ONNX model file
/// \param input_vec Vector of input variable names,
/// the order of the variables must match the order of the input
/// nodes in the ONNX model
///
/// \throws std::runtime_error if nParameters differs from input_vec.size()
///
/// \returns a dataframe with the new column holding the model output
template <std::size_t nParameters>
inline ROOT::RDF::RNode GenericOnnxEvaluator(
    ROOT::RDF::RNode df, OnnxSessionManager &onnxSessionManager,
    const std::string &outputname, const std::string &model_file_path,
    const std::vector<std::string> &input_vec) {

    // print the requested input columns
    std::size_t position = 0;
    for (const auto &column : input_vec) {
        ++position;
        Logger::get("OnnxEvaluate")
            ->debug("input: {} ( {} / {} )", column, position, nParameters);
    }

    // the template parameter fixes the callable's arity at compile time, so
    // it has to agree with the number of columns given at run time
    if (nParameters != input_vec.size()) {
        Logger::get("OnnxEvaluate")
            ->error("Number of input parameters does not match the number of "
                    "input variables: {} vs {}",
                    nParameters, input_vec.size());
        throw std::runtime_error("Number of input parameters does not match");
    }

    // Fetch the (cached) session for this model and query its layout
    std::vector<int64_t> input_node_dims;
    std::vector<int64_t> output_node_dims;
    int num_input_nodes = 0;
    int num_output_nodes = 0;
    Ort::AllocatorWithDefaultOptions allocator;

    auto session = onnxSessionManager.getSession(model_file_path);

    onnxhelper::prepare_model(session, allocator, input_node_dims,
                              output_node_dims, num_input_nodes,
                              num_output_nodes);

    // Per-event callable: the model layout is captured by value; the session
    // is captured as a raw pointer that stays owned by onnxSessionManager.
    auto NNEvaluator = [session, allocator, input_node_dims, output_node_dims,
                        num_input_nodes,
                        num_output_nodes](std::vector<float> inputs) {
        TStopwatch timer;
        timer.Start();

        std::vector<float> output = onnxhelper::run_interference(
            session, allocator, inputs, input_node_dims, output_node_dims,
            num_input_nodes, num_output_nodes);

        timer.Stop();
        Logger::get("OnnxEvaluate")
            ->debug("Inference time: {} mus", timer.RealTime() * 1000 * 1000);
        return output;
    };

    return df.Define(outputname,
                     utility::PassAsVec<nParameters, float>(NNEvaluator),
                     input_vec);
}
} // end namespace ml
#endif /* GUARD_ML_H */
57 changes: 57 additions & 0 deletions include/utility/OnnxSessionManager.hxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#ifndef GUARD_SESSION_MANAGER
#define GUARD_SESSION_MANAGER

#include "Logger.hxx"
#include <memory>
#include <onnxruntime_cxx_api.h>
#include <string>
#include <unordered_map>

class OnnxSessionManager {
public:
OnnxSessionManager() {
OrtLoggingLevel logging_level =
ORT_LOGGING_LEVEL_WARNING; // ORT_LOGGING_LEVEL_VERBOSE

env = Ort::Env(logging_level, "Default");
session_options.SetInterOpNumThreads(1);
session_options.SetIntraOpNumThreads(1);
};
Ort::Session *getSession(const std::string &modelPath) {
// check if session already exists in the sessions map
if (sessions_map.count(modelPath) == 0) {
sessions_map[modelPath] = std::make_unique<Ort::Session>(
env, modelPath.c_str(), session_options);
Logger::get("OnnxSessionManager")
->info("Created session for model: {}", modelPath);
} else {
Logger::get("OnnxSessionManager")
->info("Session already exists for model: {}", modelPath);
}
return sessions_map[modelPath].get();
};

private:
std::unordered_map<std::string, std::unique_ptr<Ort::Session>> sessions_map;
Ort::Env env;
Ort::SessionOptions session_options;
};

// Free helper functions around an Ort::Session. Only the declarations are
// visible here; the notes below are derived from the signatures and from the
// call sites in ml::GenericOnnxEvaluator.
namespace onnxhelper {
// Called once per model before any event is processed. The dimension vectors
// and node counts are taken by non-const reference, so they are presumably
// filled from the session's model description -- TODO confirm against the
// implementation.
void prepare_model(Ort::Session *session,
Ort::AllocatorWithDefaultOptions allocator,
std::vector<int64_t> &input_node_dims,
std::vector<int64_t> &output_node_dims, int &num_input_nodes,
int &num_output_nodes);

// Called once per event with the flat list of input values; returns a vector
// of floats that is stored as the output column. evt_input is taken by
// non-const reference, so callers have to pass an lvalue.
// NOTE(review): the name presumably means "inference"; renaming it would have
// to be done together with all callers.
std::vector<float> run_interference(Ort::Session *session,
Ort::AllocatorWithDefaultOptions allocator,
std::vector<float> &evt_input,
std::vector<int64_t> input_node_dims,
std::vector<int64_t> output_node_dims,
const int num_input_nodes,
const int num_output_nodes);

} // namespace onnxhelper

#endif /* GUARD_SESSION_MANAGER */
123 changes: 123 additions & 0 deletions src/ml.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
#ifndef GUARD_ML_H
#define GUARD_ML_H

#include "../include/basefunctions.hxx"
#include "../include/defaults.hxx"
#include "../include/utility/Logger.hxx"
#include "../include/utility/utility.hxx"
#include "../include/ml.hxx"
#include "../include/utility/OnnxSessionManager.hxx"
#include "../include/vectoroperations.hxx"
#include "ROOT/RDataFrame.hxx"
#include "ROOT/RVec.hxx"
#include <Math/Vector4D.h>
#include <Math/VectorUtil.h>

#include "TMVA/RModel.hxx"
#include "TInterpreter.h"
#include "TMVA/RModelParser_ONNX.hxx"
#include "TSystem.h"
#include <memory>
#include <onnxruntime_cxx_api.h>
#include <assert.h>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <nlohmann/json.hpp>
#include <stdexcept>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <string>

using json = nlohmann::json;


namespace ml {
/**
* @brief Function to perform a standard transformation of input variables for
* NN evaluation.
*
* @param df The input dataframe
* @param inputname name of the variable which should be transformed
* @param outputname name of the output column
* @param param_file path to a json file with a dictionary of mean and std
* values
* @param var_type variable data type for correct processing e.g. "i" for
* integer or "f" for float
* @return a new dataframe containing the new column
*/
ROOT::RDF::RNode StandardTransformer(ROOT::RDF::RNode df,
const std::string &inputname,
const std::string &outputname,
const std::string &param_file,
const std::string &var_type) {
// read params from file
Logger::get("StandardTransformer")->debug("reading file {}", param_file);
std::string replace_str = std::string("EVTID");
std::string odd_file_path =
std::string(param_file)
.replace(param_file.find(replace_str), replace_str.length(),
std::string("odd"));
std::string even_file_path =
std::string(param_file)
.replace(param_file.find(replace_str), replace_str.length(),
std::string("even"));

std::ifstream odd_file(odd_file_path);
json odd_info = json::parse(odd_file);
std::ifstream even_file(even_file_path);
json even_info = json::parse(even_file);

// odd or even files mean that they are trained on odd or even events, so it
// has to be applied on the opposite
auto transform_int = [odd_info, even_info,
inputname](const unsigned long long event_id,
const int input_var) {
float shifted = -10;
if (int(event_id) % 2 == 0) {
shifted = (float(input_var) - float(odd_info[inputname]["mean"])) /
float(odd_info[inputname]["std"]);
} else if (int(event_id) % 2 == 1) {
shifted = (float(input_var) - float(even_info[inputname]["mean"])) /
float(even_info[inputname]["std"]);
}
Logger::get("StandardTransformer")
->debug("transforming var {} from {} to {}", inputname, input_var,
shifted);
return shifted;
};
auto transform_float = [odd_info, even_info,
inputname](const unsigned long long event_id,
const float input_var) {
float shifted = -10;
if (int(event_id) % 2 == 0) {
shifted = (float(input_var) - float(odd_info[inputname]["mean"])) /
float(odd_info[inputname]["std"]);
} else if (int(event_id) % 2 == 1) {
shifted = (float(input_var) - float(even_info[inputname]["mean"])) /
float(even_info[inputname]["std"]);
}
Logger::get("StandardTransformer")
->debug("transforming var {} from {} to {}", inputname, input_var,
shifted);
return shifted;
};

const std::string event_id = std::string("event");
if (var_type.rfind("i", 0) == 0) {
auto df1 = df.Define(outputname, transform_int, {event_id, inputname});
return df1;
} else if (var_type.rfind("f", 0) == 0) {
auto df1 =
df.Define(outputname, transform_float, {event_id, inputname});
return df1;
} else {
Logger::get("StandardTransformer")
->debug("transformation failed due to wrong variable type: {}",
var_type);
return df;
}
}

} // end namespace ml
#endif /* GUARD_ML_H */
Loading

0 comments on commit 78bd176

Please sign in to comment.