Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable HLT tau filtering based on jet-tags (ParticleNet) #43100

Merged
merged 4 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion RecoTauTag/HLTProducers/src/L2TauTagFilter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class L2TauTagFilter : public HLTFilter {
}
for (size_t l1_idx = 0; l1_idx < l1Taus.size(); l1_idx++) {
if (L2Outcomes[l1_idx] >= discrWP_ || l1Taus[l1_idx]->pt() > l1PtTh_) {
filterproduct.addObject(nTauPassed, l1Taus[l1_idx]);
filterproduct.addObject(trigger::TriggerL1Tau, l1Taus[l1_idx]);
nTauPassed++;
}
}
Expand Down
171 changes: 171 additions & 0 deletions RecoTauTag/HLTProducers/src/TauTagFilter.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
/*
* \class TauTagFilter
*
* Filter tau candidates based on tagger scores.
*
* \author Konstantin Androsov, EPFL and ETHZ
*/

#include "DataFormats/BTauReco/interface/JetTag.h"
#include "DataFormats/HLTReco/interface/TriggerTypeDefs.h"
#include "FWCore/Framework/interface/Event.h"
#include "FWCore/Utilities/interface/Exception.h"
#include "FWCore/Framework/interface/Frameworkfwd.h"
#include "FWCore/Utilities/interface/InputTag.h"
#include "FWCore/ParameterSet/interface/ParameterSet.h"
#include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
#include "HLTrigger/HLTcore/interface/HLTFilter.h"
#include "RecoTauTag/RecoTau/interface/TauWPThreshold.h"

class TauTagFilter : public HLTFilter {
public:
using TauCollection = reco::PFJetCollection;
using TauTagCollection = reco::JetTagCollection;
using TauRef = reco::PFJetRef;
using Selector = tau::TauWPThreshold;
using LorentzVectorM = math::PtEtaPhiMLorentzVector;

explicit TauTagFilter(const edm::ParameterSet& cfg)
: HLTFilter(cfg),
nExpected_(cfg.getParameter<int>("nExpected")),
tausSrc_(cfg.getParameter<edm::InputTag>("taus")),
tausToken_(consumes<TauCollection>(tausSrc_)),
tauTagsToken_(consumes<TauTagCollection>(cfg.getParameter<edm::InputTag>("tauTags"))),
tauPtCorrToken_(mayConsume<TauTagCollection>(cfg.getParameter<edm::InputTag>("tauPtCorr"))),
seedsSrc_(mayConsume<trigger::TriggerFilterObjectWithRefs>(cfg.getParameter<edm::InputTag>("seeds"))),
seedTypes_(cfg.getParameter<std::vector<int>>("seedTypes")),
selector_(cfg.getParameter<std::string>("selection")),
minPt_(cfg.getParameter<double>("minPt")),
maxEta_(cfg.getParameter<double>("maxEta")),
usePtCorr_(cfg.getParameter<bool>("usePtCorr")),
matchWithSeeds_(cfg.getParameter<bool>("matchWithSeeds") && cfg.getParameter<double>("matchingdR") >= 0),
matchingdR2_(std::pow(cfg.getParameter<double>("matchingdR"), 2)) {
if (cfg.getParameter<bool>("matchWithSeeds") && cfg.getParameter<double>("matchingdR") < 0)
edm::LogWarning("TauTagFilter") << "Matching with seeds is disabled because matchingdR < 0";

extractMomenta(); // checking that all seed types are supported
}

static void fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
edm::ParameterSetDescription desc;
makeHLTFilterDescription(desc);
desc.add<int>("nExpected", 2)->setComment("number of expected taus per event");
desc.add<edm::InputTag>("taus", edm::InputTag(""))->setComment("input collection of taus");
desc.add<edm::InputTag>("tauTags", edm::InputTag(""))->setComment("input collection of tau tagger scores");
desc.add<edm::InputTag>("tauPtCorr", edm::InputTag(""))
->setComment("input collection of multiplicative tau pt corrections");
desc.add<edm::InputTag>("seeds", edm::InputTag(""))->setComment("input collection of seeds");
desc.add<std::vector<int>>("seedTypes",
{trigger::TriggerL1Tau, trigger::TriggerL1Jet, trigger::TriggerTau, trigger::TriggerJet})
->setComment("list of seed object types");
desc.add<std::string>("selection", "0")->setComment("selection formula");
desc.add<double>("minPt", 20)->setComment("minimal tau pt");
desc.add<double>("maxEta", 2.5)->setComment("maximal tau abs(eta)");
desc.add<bool>("usePtCorr", false)->setComment("use multiplicative tau pt corrections");
desc.add<bool>("matchWithSeeds", false)->setComment("apply match with seeds");
desc.add<double>("matchingdR", 0.5)->setComment("deltaR for matching with seeds");
descriptions.addWithDefaultLabel(desc);
}

bool hltFilter(edm::Event& event,
const edm::EventSetup& eventsetup,
trigger::TriggerFilterObjectWithRefs& filterproduct) const override {
if (saveTags())
filterproduct.addCollectionTag(tausSrc_);

int nTauPassed = 0;

const auto tausHandle = event.getHandle(tausToken_);
const auto& taus = *tausHandle;

const std::vector<LorentzVectorM> seed_p4s = extractMomenta(&event);
auto hasMatch = [&](const LorentzVectorM& p4) {
for (const auto& seed_p4 : seed_p4s) {
if (reco::deltaR2(p4, seed_p4) < matchingdR2_)
return true;
}
return false;
};

const auto& tauTags = event.get(tauTagsToken_);
const TauTagCollection* tauPtCorr = nullptr;
if (usePtCorr_)
tauPtCorr = &event.get(tauPtCorrToken_);

if (taus.size() != tauTags.size())
throw cms::Exception("Inconsistent Data", "TauTagFilter::hltFilter") << "taus.size() != tauTags.size()";
if (usePtCorr_ && taus.size() != tauPtCorr->size())
throw cms::Exception("Inconsistent Data", "TauTagFilter::hltFilter") << "taus.size() != tauPtCorr.size()";

for (size_t tau_idx = 0; tau_idx < taus.size(); ++tau_idx) {
const auto& tau = taus[tau_idx];
double pt = tau.pt();
if (usePtCorr_)
pt *= (*tauPtCorr)[tau_idx].second;
const double eta = std::abs(tau.eta());
if (pt > minPt_ && eta < maxEta_ && (!matchWithSeeds_ || hasMatch(tau.polarP4()))) {
const double tag = tauTags[tau_idx].second;
const double tag_thr = selector_(tau);
if (tag > tag_thr) {
filterproduct.addObject(trigger::TriggerTau, TauRef(tausHandle, tau_idx));
nTauPassed++;
}
}
}

return nTauPassed >= nExpected_;
}

private:
std::vector<LorentzVectorM> extractMomenta(const edm::Event* event = nullptr) const {
std::vector<LorentzVectorM> seed_p4s;
if (matchWithSeeds_) {
const trigger::TriggerFilterObjectWithRefs* seeds = nullptr;
if (event)
seeds = &event->get(seedsSrc_);
for (const int seedType : seedTypes_) {
if (seedType == trigger::TriggerL1Tau) {
extractMomenta<l1t::TauVectorRef>(seeds, seedType, seed_p4s);
} else if (seedType == trigger::TriggerL1Jet) {
extractMomenta<l1t::JetVectorRef>(seeds, seedType, seed_p4s);
} else if (seedType == trigger::TriggerTau) {
extractMomenta<std::vector<reco::PFTauRef>>(seeds, seedType, seed_p4s);
} else if (seedType == trigger::TriggerJet) {
extractMomenta<std::vector<reco::PFJetRef>>(seeds, seedType, seed_p4s);
} else
throw cms::Exception("Invalid seed type", "TauTagFilter::extractMomenta")
<< "Unsupported seed type: " << seedType;
}
}
return seed_p4s;
}

template <typename Collection>
static void extractMomenta(const trigger::TriggerRefsCollections* triggerObjects,
int objType,
std::vector<LorentzVectorM>& p4s) {
if (triggerObjects) {
Collection objects;
triggerObjects->getObjects(objType, objects);
for (const auto& obj : objects)
p4s.push_back(obj->polarP4());
}
}

private:
const int nExpected_;
const edm::InputTag tausSrc_;
const edm::EDGetTokenT<TauCollection> tausToken_;
const edm::EDGetTokenT<TauTagCollection> tauTagsToken_, tauPtCorrToken_;
const edm::EDGetTokenT<trigger::TriggerFilterObjectWithRefs> seedsSrc_;
const std::vector<int> seedTypes_;
const Selector selector_;
const double minPt_, maxEta_;
const bool usePtCorr_;
const bool matchWithSeeds_;
const double matchingdR2_;
};

//define this as a plug-in
#include "FWCore/Framework/interface/MakerMacros.h"
DEFINE_FWK_MODULE(TauTagFilter);
64 changes: 64 additions & 0 deletions RecoTauTag/RecoTau/interface/TauWPThreshold.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#ifndef RecoTauTag_RecoTau_TauWPThreshold_h
#define RecoTauTag_RecoTau_TauWPThreshold_h

#include "DataFormats/TauReco/interface/BaseTau.h"
#include "DataFormats/PatCandidates/interface/Tau.h"
#include <TF1.h>

#include <cerrno>
#include <cstddef>
#include <cstdlib>
#include <string>

namespace tau {
class TauWPThreshold {
public:
/// Builds a working-point threshold from a configuration string.
///
/// `cut_str` is first tried as a plain floating-point literal: it counts as a
/// "simple value" only when the whole string is consumed by the conversion and
/// the result is in range. The outcome is recorded in `simple_value`, which the
/// remainder of the constructor uses to decide whether a formula is needed.
explicit TauWPThreshold(const std::string& cut_str) {
  bool simple_value = false;
  {
    // std::strtod reports failure through its end pointer and errno, so no
    // exception is thrown (and none has to be caught) for a string that is
    // simply not a number.
    const char* const begin = cut_str.c_str();
    char* end = nullptr;
    errno = 0;
    value_ = std::strtod(begin, &end);
    // Same acceptance rule as std::stod followed by "pos == size":
    //  - something was converted (end != begin),
    //  - the conversion consumed the entire string,
    //  - the value did not overflow or underflow (ERANGE).
    const bool converted = (end != begin);
    const bool consumedAll = converted && static_cast<std::size_t>(end - begin) == cut_str.size();
    simple_value = consumedAll && errno != ERANGE;
  }
Comment on lines +13 to +19
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This violates the CMSSW coding rules:

7.11 In general, do not catch exceptions – leave them to the Framework (see Exception Guidelines).

These changes should not have been accepted in the current state. Please fix this misuse ASAP.

Copy link
Contributor

@mmusich mmusich Feb 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mbluj can you have a look? This is in the way of further development of the HLT menu for 2024.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I understand correctly that the try-catch can simply be removed, as the exceptions will be caught by the framework?
@fwyzard out of curiosity: were those exception catches found because the exceptions had actually been thrown, or by other means?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I understand correctly that the try-catch can simply be removed, as the exceptions will be caught by the framework?

Exceptions caught by the framework will usually cause the jobs to terminate.

@fwyzard out of curiosity: were those exception catches found because the exceptions had actually been thrown, or by other means?

These exceptions were noticed while running cmsRun hlt.py under gdb with exception catching enabled, in order to investigate where some failures were happening.
These try/catch blocks add a lot of noise to the debugging sessions, making it harder to figure out where the real problems are happening.

In general, in CMSSW it is discouraged to use try/catch as a way to check if a condition is valid or not; it's better to do an explicit check.

This approach is idiomatic in python, where checking for an exception is almost as fast as checking for a return value. In C++, the regular return values are faster, and the exceptions are slower.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mbluj having a parsing error at this stage is ok; it just means that WP is not a "simple value". So I think your code should be modified like this:

const char* cut_cstr = cut_str.c_str();
char* end_cstr;
value_ = std::strtod(cut_cstr, &end_cstr);
simple_value = end_cstr != cut_cstr && std::isfinite(value_);

(without any throw)

Yes, indeed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, the code I proposed above is also wrong. For example, for cut_str="1 + 1", it will return simple_value=true and value_=1, which is not intended. So it should be something like this:

    const char* cut_cstr = cut_str.c_str();
    char* end_cstr;
    value_ = std::strtod(cut_cstr, &end_cstr);
    const bool simple_value = end_cstr != cut_cstr && static_cast<std::size_t>(end_cstr - cut_cstr) == cut_str.size() && value_ != HUGE_VALF;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is enough to check that end_cstr points to the end of the string (!*end_cstr), so the code can look like this:

      const char* cut_cstr = cut_str.c_str();
      char* end_cstr{};
      value_ = std::strtod(cut_cstr, &end_cstr);
      simple_value = (!*end_cstr && cut_cstr != end_cstr && std::isfinite(value_));
      if (!simple_value) {

I think std::isfinite(value_) is better than value_ != HUGE_VAL as it does not accept nan. BTW, current code accepts both nan and inf as a valid "simple_value".

Copy link
Contributor

@kandrosov kandrosov Feb 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay. Perhaps, as in your original proposal, an exception should be thrown for inf and especially nan - I don't see any use cases where one would want to use them in valid configs...
Is std::isfinite(HUGE_VALF) == false guaranteed, or could it be compiler-dependent? To make sure that we account for all possible return scenarios.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quote from cppreference:

On implementations that support floating-point infinities, these macros (HUGE_VALF, HUGE_VAL, HUGE_VALL) always expand to the positive infinities of float, double, and long double, respectively.

So, I think it is OK to check if value_ is finite.

In the current implementation not-finite numbers are treated as not being a "simple_value" and as they are not correct threshold expressions an exception is thrown.

BTW, for double one should use HUGE_VAL rather than HUGE_VALF that is for float.

// cut_str is not a plain number: turn it into the source text of a lambda and
// let TF1 build a function from that text.
if (!simple_value) {
// Opening part of the lambda source. It exposes the three function parameters
// p[0], p[1], p[2] to the expression under the names decayMode, pt and eta.
static const std::string prefix =
"[&](double *x, double *p) { const int decayMode = p[0];"
"const double pt = p[1]; const double eta = p[2];";
// Number of parameters of the TF1 (decayMode, pt, eta).
static const int n_params = 3;
// Error handler that does nothing; installed only while the TF1 is constructed.
// NOTE(review): presumably meant to silence ROOT messages for a malformed
// formula, since validity is checked explicitly via IsValid() below - confirm.
static const auto handler = [](int, Bool_t, const char*, const char*) -> void {};

std::string fn_str = prefix;
// An expression that does not contain the text "return" is wrapped as
// " return <expr>;"; otherwise the text is used as the lambda body as-is.
if (cut_str.find("return") == std::string::npos)
fn_str += " return " + cut_str + ";}";
else
fn_str += cut_str + "}";
// Swap in the no-op handler, build the function, then restore the previous
// handler. The restore is skipped if the TF1 construction throws.
auto old_handler = SetErrorHandler(handler);
fn_ = std::make_unique<TF1>("fn_", fn_str.c_str(), 0, 1, n_params);
SetErrorHandler(old_handler);
// Reject a formula that TF1 reports as invalid.
if (!fn_->IsValid())
throw cms::Exception("TauWPThreshold: invalid formula") << "Invalid WP cut formula = '" << cut_str << "'.";
}
}

/// Returns the threshold for the given decay mode, pt and eta.
///
/// When no formula was built (fn_ is null) the stored constant value_ is
/// returned. Otherwise the formula is evaluated with parameters {dm, pt, eta}.
///
/// The parameters are passed to TF1::EvalPar in a local array rather than
/// written into the shared TF1 with SetParameter. This const method therefore
/// no longer modifies the parameters stored in fn_, so interleaved calls on the
/// same object cannot end up evaluating the formula with each other's inputs.
double operator()(int dm, double pt, double eta) const {
  if (!fn_)
    return value_;

  const double params[3] = {static_cast<double>(dm), pt, eta};
  // All coordinates are zero, as in the former Eval(0) call (y, z, t default to 0).
  const double x[4] = {0., 0., 0., 0.};
  return fn_->EvalPar(x, params);
}

/// Returns the threshold for a tau object, using its own decay mode, pt and eta.
///
/// `isPFTau` selects the concrete type `tau` is cast to in order to read the
/// decay mode: reco::PFTau when true, pat::Tau when false. The reference
/// dynamic_cast throws std::bad_cast if `tau` is not of the selected type.
double operator()(const reco::BaseTau& tau, bool isPFTau) const {
  if (isPFTau) {
    const auto& pfTau = dynamic_cast<const reco::PFTau&>(tau);
    const int decayMode = pfTau.decayMode();
    return (*this)(decayMode, tau.pt(), tau.eta());
  }
  const auto& patTau = dynamic_cast<const pat::Tau&>(tau);
  const int decayMode = patTau.decayMode();
  return (*this)(decayMode, tau.pt(), tau.eta());
}

/// Returns the threshold for a generic candidate. No decay mode is read from
/// the candidate; -1 is passed as the decay mode together with its pt and eta.
double operator()(const reco::Candidate& tau) const {
  const int decayMode = -1;
  const double pt = tau.pt();
  const double eta = tau.eta();
  return (*this)(decayMode, pt, eta);
}

private:
std::unique_ptr<TF1> fn_;
double value_;
};
} // namespace tau

#endif
53 changes: 2 additions & 51 deletions RecoTauTag/RecoTau/plugins/DeepTauId.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@
#include "DataFormats/PatCandidates/interface/PATTauDiscriminator.h"
#include "CommonTools/Utils/interface/StringObjectFunction.h"
#include "RecoTauTag/RecoTau/interface/PFRecoTauClusterVariables.h"
#include "RecoTauTag/RecoTau/interface/TauWPThreshold.h"
#include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
#include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
#include "DataFormats/Common/interface/View.h"
#include "DataFormats/Common/interface/RefToBase.h"
#include "DataFormats/Provenance/interface/ProductProvenance.h"
#include "DataFormats/Provenance/interface/ProcessHistoryID.h"
#include "FWCore/Common/interface/Provenance.h"
#include <TF1.h>
#include <map>
#include "RecoTauTag/RecoTau/interface/DeepTauScaling.h"
#include "FWCore/Utilities/interface/isFinite.h"
Expand All @@ -47,55 +47,6 @@ namespace deep_tau {
PUcorrPtSum
};

/// Working-point threshold given either as a constant or as a formula of the
/// tau decay mode, pt and eta.
class TauWPThreshold {
public:
  /// Parses `cut_str`. A string that std::stod converts completely is stored as
  /// a constant; anything else is compiled into a TF1 built from a lambda
  /// source string. Throws cms::Exception if TF1 reports the formula as invalid.
  explicit TauWPThreshold(const std::string& cut_str) {
    bool isPlainNumber = false;
    try {
      size_t nConsumed = 0;
      value_ = std::stod(cut_str, &nConsumed);
      isPlainNumber = (nConsumed == cut_str.size());
    } catch (std::invalid_argument&) {
      // not a number: fall through to the formula branch
    } catch (std::out_of_range&) {
      // out of range for double: fall through to the formula branch
    }
    if (isPlainNumber)
      return;

    // Lambda preamble naming the three parameters decayMode, pt and eta.
    static const std::string prefix =
        "[&](double *x, double *p) { const int decayMode = p[0];"
        "const double pt = p[1]; const double eta = p[2];";
    static const int n_params = 3;
    // No-op error handler, active only while the TF1 is being constructed.
    static const auto handler = [](int, Bool_t, const char*, const char*) -> void {};

    // Text without "return" is wrapped as a return statement; otherwise it is
    // taken as the lambda body as-is.
    const bool hasReturn = cut_str.find("return") != std::string::npos;
    std::string lambdaSource = prefix;
    if (hasReturn) {
      lambdaSource += cut_str + "}";
    } else {
      lambdaSource += " return " + cut_str + ";}";
    }

    auto previousHandler = SetErrorHandler(handler);
    fn_ = std::make_unique<TF1>("fn_", lambdaSource.c_str(), 0, 1, n_params);
    SetErrorHandler(previousHandler);

    if (!fn_->IsValid()) {
      throw cms::Exception("TauWPThreshold: invalid formula") << "Invalid WP cut formula = '" << cut_str << "'.";
    }
  }

  /// Returns the threshold for `tau`: the stored constant if no formula was
  /// built, otherwise the formula evaluated with the tau's decay mode, pt and
  /// eta. `isPFTau` selects whether `tau` is cast to reco::PFTau or pat::Tau to
  /// read the decay mode.
  double operator()(const reco::BaseTau& tau, bool isPFTau) const {
    if (!fn_)
      return value_;

    if (isPFTau) {
      const auto& pfTau = dynamic_cast<const reco::PFTau&>(tau);
      fn_->SetParameter(0, pfTau.decayMode());
    } else {
      const auto& patTau = dynamic_cast<const pat::Tau&>(tau);
      fn_->SetParameter(0, patTau.decayMode());
    }
    fn_->SetParameter(1, tau.pt());
    fn_->SetParameter(2, tau.eta());
    return fn_->Eval(0);
  }

private:
  std::unique_ptr<TF1> fn_;  // null when the threshold is a plain constant
  double value_;             // constant threshold, used only when fn_ is null
};

class DeepTauCache {
public:
using GraphPtr = std::shared_ptr<tensorflow::GraphDef>;
Expand Down Expand Up @@ -951,7 +902,7 @@ class DeepTauId : public edm::stream::EDProducer<edm::GlobalCache<deep_tau::Deep
using ElectronCollection = pat::ElectronCollection;
using MuonCollection = pat::MuonCollection;
using LorentzVectorXYZ = ROOT::Math::LorentzVector<ROOT::Math::PxPyPzE4D<double>>;
using Cutter = deep_tau::TauWPThreshold;
using Cutter = tau::TauWPThreshold;
using CutterPtr = std::unique_ptr<Cutter>;
using WPList = std::vector<CutterPtr>;

Expand Down