Skip to content

Commit

Permalink
[LLVM,TIR] Print LLVM intrinsic names instead of ids (apache#9964)
Browse files Browse the repository at this point in the history
* [LLVM,TIR] Print LLVM intrinsic names instead of ids

This makes it much easy to understand what is happening with llvm
intrinsics.

* add test, version llvm
  • Loading branch information
Tristan Konolige authored and ylc committed Feb 16, 2022
1 parent 5102e19 commit 435ed93
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 5 deletions.
16 changes: 16 additions & 0 deletions python/tvm/target/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,22 @@ def llvm_lookup_intrinsic_id(name):
return _ffi_api.llvm_lookup_intrinsic_id(name)


def llvm_get_intrinsic_name(intrin_id: int) -> str:
"""Get the name of an intrinsic for a given id.
Parameters
----------
intrin_id : int
The id of the intrinsic.
Returns
-------
name : str
The name of the intrinsic.
"""
return _ffi_api.llvm_get_intrinsic_name(intrin_id)


def llvm_version_major(allow_none=False):
"""Get the major LLVM version.
Expand Down
23 changes: 18 additions & 5 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -392,19 +392,32 @@ Doc TIRTextPrinter::VisitExpr_(const LetNode* op) {

Doc TIRTextPrinter::VisitExpr_(const CallNode* op) {
Doc doc;
std::vector<Doc> func_args;
if (auto* ptr_op = op->op.as<OpNode>()) {
doc << "@" << Doc::Text(ptr_op->name) << "(";
if (ptr_op->name == "tir.call_llvm_pure_intrin") {
auto f = tvm::runtime::Registry::Get("target.llvm_get_intrinsic_name");
ICHECK(f != nullptr)
<< "Cannot find target.llvm_get_intrinsic_name. Compile with USE_LLVM=On";
func_args.push_back(Print((*f)(Downcast<IntImm>(op->args[0])->value)));
for (size_t i = 1; i < op->args.size(); i++) {
func_args.push_back(Print(op->args[i]));
}
} else {
for (const auto& arg : op->args) {
func_args.push_back(Print(arg));
}
}
} else {
// TODO(bohan): Print out the name by he global var in the module.
auto* op_gvar = op->op.as<GlobalVarNode>();
ICHECK(op_gvar != nullptr);
doc << "@" << Doc::Text(op_gvar->name_hint) << "(";
for (const auto& arg : op->args) {
func_args.push_back(Print(arg));
}
}
std::vector<Doc> args;
for (const auto& arg : op->args) {
args.push_back(Print(arg));
}
doc << PrintSep(args, Doc::Text(", ")) << ", dtype=" << PrintDType(op->dtype) << ")";
doc << PrintSep(func_args, Doc::Text(", ")) << ", dtype=" << PrintDType(op->dtype) << ")";
return doc;
}

Expand Down
15 changes: 15 additions & 0 deletions src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,21 @@ TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id")
return static_cast<int64_t>(llvm::Function::lookupIntrinsicID(name));
});

TVM_REGISTER_GLOBAL("target.llvm_get_intrinsic_name").set_body_typed([](int64_t id) -> String {
#if TVM_LLVM_VERSION >= 130
return std::string(llvm::Intrinsic::getBaseName(static_cast<llvm::Intrinsic::ID>(id)));
#elif TVM_LLVM_VERSION >= 40
// This is the version of Intrinsic::getName that works for overloaded
// intrinsics. Helpfully, if we provide no types to this function, it
// will give us the overloaded name without the types appended. This
// should be enough information for most uses.
return std::string(llvm::Intrinsic::getName(static_cast<llvm::Intrinsic::ID>(id), {}));
#else
// Nothing to do, just return the intrinsic id number
return std::to_string(id);
#endif
});

TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int {
return TVM_LLVM_VERSION / 10;
});
Expand Down
9 changes: 9 additions & 0 deletions tests/python/unittest/test_target_codegen_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tvm import te
from tvm.relay.backend import Runtime
from tvm.contrib import utils, clang
from tvm.target.codegen import llvm_lookup_intrinsic_id, llvm_get_intrinsic_name
import tvm.script.tir as T
import numpy as np

Expand Down Expand Up @@ -57,6 +58,14 @@ def test_llvm_void_intrin():
fcode = tvm.build(mod, None, "llvm")


@tvm.testing.requires_llvm
def test_llvm_intrinsic_id():
orig_name = "llvm.x86.sse2.pmadd.wd"
intrin_id = llvm_lookup_intrinsic_id(orig_name)
name = llvm_get_intrinsic_name(intrin_id)
assert orig_name == name


@tvm.testing.requires_llvm
def test_llvm_overloaded_intrin():
# Name lookup for overloaded intrinsics in LLVM 4- requires a name
Expand Down

0 comments on commit 435ed93

Please sign in to comment.