Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CODEGEN] ARM Popcount lowering rule and codegen updates #1235

Merged
merged 5 commits into from
Jun 12, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions src/codegen/llvm/codegen_arm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,90 @@ class CodeGenARM final : public CodeGenCPU {
native_vector_bits_ = 16 * 8;
CodeGenCPU::InitTarget(tm);
}
llvm::Value* CreateIntrinsic(const Call* op) override;

private:
Expr ARMPopcount(const Call* op);
};

llvm::Value* CodeGenARM::CreateIntrinsic(const Call* op) {
if (op->is_intrinsic("llvm_intrin")) {
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
op->args[0].as<UIntImm>()->value);
if (id == ::llvm::Intrinsic::ctpop) {
Expr e = ARMPopcount(op);
return CodeGenCPU::CreateIntrinsic(e.as<Call>());
}
}
return CodeGenCPU::CreateIntrinsic(op);
}

Expr CodeGenARM::ARMPopcount(const Call *call) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will need a regression test for this rule. please add a test case to arm popcount, to a new file tests/python/unittest/test_codegen_arm.py .

Since we don't have ARM device to verify, what we can do is to dump out the asm file(Maybe we can patch GetSource in llvm module to support get_source("asm") ) and verify the neons sequence is as expected.

using namespace ir;
const Expr& e = call->args[2];
::llvm::Intrinsic::ID ctpop_id = ::llvm::Intrinsic::ctpop;
::llvm::Intrinsic::ID vpaddlu_id = ::llvm::Intrinsic::arm_neon_vpaddlu;

// Fallback to default llvm lowering rule if input type not a full vector or half vector length
int total_size = call->type.bits() * call->type.lanes();
if (!call->type.is_vector() || call->type.bits() == 8 ||
(total_size != 128 && total_size != 64)) {
Array<Expr> vcnt_args;
vcnt_args.push_back(ir::UIntImm::make(UInt(32), ctpop_id));
vcnt_args.push_back(ir::UIntImm::make(UInt(32), 1));
vcnt_args.push_back(e);
return ir::Call::make(call->type, "llvm_intrin", vcnt_args, Call::PureIntrinsic);
}

// Popcount lowering rule:
// Reinterpret input vector as a vector of 8bit values and preform popcount
// Pairwise add between adjacent elements and double width with vpaddlu
// to return back to original input type

// Dvisions are always divisible (number of bits = 64 or 128)
Type uint8_type = Type(e.type().code(), 8, e.type().bits() * e.type().lanes() / 8);
Type uint16_type = Type(uint8_type.code(), 16, uint8_type.bits() * uint8_type.lanes() / 16);
Type uint32_type = Type(uint16_type.code(), 32, uint8_type.bits() * uint8_type.lanes() / 32);

// Interpret input as vector of 8bit values
Expr input8 = reinterpret(uint8_type, e);
// Popcount 8bit->8bit
const Call* c0 = input8.as<Call>();
CHECK(c0 != nullptr);
Array<Expr> vcnt8_args;
vcnt8_args.push_back(ir::UIntImm::make(UInt(32), ctpop_id));
vcnt8_args.push_back(ir::UIntImm::make(UInt(32), 1));
vcnt8_args.push_back(input8);
Expr vcnt8 = ir::Call::make(uint8_type, "llvm_intrin", vcnt8_args, Call::PureIntrinsic);

// Accumulation 8->16bit
Array<Expr> vcnt16_args;
vcnt16_args.push_back(ir::UIntImm::make(UInt(32), vpaddlu_id));
vcnt16_args.push_back(ir::UIntImm::make(UInt(32), 1));
vcnt16_args.push_back(vcnt8);
Expr vcnt16 = ir::Call::make(uint16_type, "llvm_intrin", vcnt16_args, Call::PureIntrinsic);
if (call->type.bits() == 16) {
return vcnt16;
}

// Accumulation 16->32bit
Array<Expr> vcnt32_args;
vcnt32_args.push_back(ir::UIntImm::make(UInt(32), vpaddlu_id));
vcnt32_args.push_back(ir::UIntImm::make(UInt(32), 1));
vcnt32_args.push_back(vcnt16);
Expr vcnt32 = ir::Call::make(uint32_type, "llvm_intrin", vcnt32_args, Call::PureIntrinsic);
if (call->type.bits() == 32) {
return vcnt32;
}

// Accumulation 32->64bit
Array<Expr> vcnt64_args;
vcnt64_args.push_back(ir::UIntImm::make(UInt(32), vpaddlu_id));
vcnt64_args.push_back(ir::UIntImm::make(UInt(32), 1));
vcnt64_args.push_back(vcnt32);
return ir::Call::make(call->type, "llvm_intrin", vcnt64_args, Call::PureIntrinsic);
}

TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
CodeGenLLVM* cg = new CodeGenARM();
Expand Down
26 changes: 25 additions & 1 deletion src/codegen/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
if (extent == num_elems && begin == 0) return vec;
CHECK_LT(begin + extent, num_elems);
CHECK_LE(begin + extent, num_elems);
std::vector<unsigned> indices;
for (int i = 0; i < extent; ++i) {
indices.push_back(begin + i);
Expand Down Expand Up @@ -562,6 +562,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
sig_type.push_back(arg_value.back()->getType());
}
}
llvm::Type *return_type = LLVMType(op->type);
if (sig_type.size() > 0 && return_type != sig_type[0]) {
sig_type.insert(sig_type.begin(), return_type);
}
llvm::Function* f = llvm::Intrinsic::getDeclaration(
module_.get(), id, sig_type);
return builder_->CreateCall(f, arg_value);
Expand Down Expand Up @@ -628,6 +632,26 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
value->addIncoming(then_value, then_value_block);
value->addIncoming(else_value, else_value_block);
return value;
} else if (op->is_intrinsic(Call::reinterpret)) {
llvm::Type * target = LLVMType(op->type);
return builder_->CreateBitCast(MakeValue(op->args[0]), target);
} else if (op->is_intrinsic("vectorlow")) {
llvm::Value *v = MakeValue(op->args[0]);
int l = v->getType()->getVectorNumElements();
return CreateVecSlice(v, 0, l/2);
} else if (op->is_intrinsic("vectorhigh")) {
llvm::Value *v = MakeValue(op->args[0]);
int l = v->getType()->getVectorNumElements();
return CreateVecSlice(v, l/2, l/2);
} else if (op->is_intrinsic("vectorcombine")) {
llvm::Value *v0 = MakeValue(op->args[0]);
llvm::Value *v1 = MakeValue(op->args[1]);
int num_elems = static_cast<int>(v0->getType()->getVectorNumElements()) * 2;
std::vector<unsigned> indices;
for (int i = 0; i < num_elems; ++i) {
indices.push_back(i);
}
return builder_->CreateShuffleVector(v0, v1, indices);
} else {
LOG(FATAL) << "unknown intrinsic " << op->name;
return nullptr;
Expand Down
38 changes: 34 additions & 4 deletions src/codegen/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,41 @@ class LLVMModuleNode final : public runtime::ModuleNode {
}

std::string GetSource(const std::string& format) final {
std::string fmt = runtime::GetFileFormat("", format);
std::string type_str;
llvm::raw_string_ostream rso(type_str);
CHECK(mptr_ != nullptr);
mptr_->print(rso, nullptr);
return rso.str();
llvm::SmallString<256> str;
llvm::raw_svector_ostream rso(str);

if (fmt == "s" || fmt == "asm") {
#if TVM_LLVM_VERSION <= 60
std::unique_ptr<llvm::Module> m = llvm::CloneModule(mptr_);
#else
std::unique_ptr<llvm::Module> m = llvm::CloneModule(*mptr_);
#endif
llvm::legacy::PassManager pass;
CHECK(tm_);
#if TVM_LLVM_VERSION <= 60
CHECK(tm_->addPassesToEmitFile(
pass, rso, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
<< "Cannot emit target CGFT_AssemblyFile";
#else
CHECK(tm_->addPassesToEmitFile(
pass, rso, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
<< "Cannot emit target CGFT_AssemblyFile";
#endif
pass.run(*m);
return rso.str().str();
} else if (fmt == "" || fmt == "ll") {
std::string type_str;
llvm::raw_string_ostream rso(type_str);
CHECK(mptr_ != nullptr);
mptr_->print(rso, nullptr);
return rso.str();
} else {
LOG(FATAL) << "Do not know how to get source code with format: "
<< format << "\'";
}
return "";
}

void Init(const Array<LoweredFunc>& funcs, std::string target) {
Expand Down
30 changes: 30 additions & 0 deletions tests/python/unittest/test_codegen_arm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import tvm
import re
import os
import ctypes

def test_popcount():
target = 'llvm -target=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon'

def check_correct_assembly(type, elements, counts):
n = tvm.convert(elements)
A = tvm.placeholder(n, dtype=type, name='A')
B = tvm.compute(A.shape, lambda i: tvm.popcount(A[i]), name='B')
s = tvm.create_schedule(B.op)
s[B].vectorize(s[B].op.axis[0])
f = tvm.build(s, [A, B], target)

# Verify we see the correct number of vpaddl and vcnt instructions in the assembly
assembly = f.get_source('asm')
matches = re.findall("vpaddl", assembly)
assert (len(matches) == counts)
matches = re.findall("vcnt", assembly)
assert (len(matches) == 1)
check_correct_assembly('uint16', 8, 1)
check_correct_assembly('uint16', 4, 1)
check_correct_assembly('uint32', 4, 2)
check_correct_assembly('uint32', 2, 2)
check_correct_assembly('uint64', 2, 3)

if __name__ == "__main__":
test_popcount()