Skip to content

Commit

Permalink
[CODEGEN] ARM Popcount lowering rule and codegen updates (apache#1235)
Browse files Browse the repository at this point in the history
  • Loading branch information
Meghan Cowan authored and sergei-mironov committed Aug 8, 2018
1 parent bd8cb9b commit 53d65da
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 5 deletions.
82 changes: 82 additions & 0 deletions src/codegen/llvm/codegen_arm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,90 @@ class CodeGenARM final : public CodeGenCPU {
native_vector_bits_ = 16 * 8;
CodeGenCPU::InitTarget(tm);
}
llvm::Value* CreateIntrinsic(const Call* op) override;

private:
Expr ARMPopcount(const Call* op);
};

llvm::Value* CodeGenARM::CreateIntrinsic(const Call* op) {
if (op->is_intrinsic("llvm_intrin")) {
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
op->args[0].as<UIntImm>()->value);
if (id == ::llvm::Intrinsic::ctpop) {
Expr e = ARMPopcount(op);
return CodeGenCPU::CreateIntrinsic(e.as<Call>());
}
}
return CodeGenCPU::CreateIntrinsic(op);
}

Expr CodeGenARM::ARMPopcount(const Call *call) {
using namespace ir;
const Expr& e = call->args[2];
::llvm::Intrinsic::ID ctpop_id = ::llvm::Intrinsic::ctpop;
::llvm::Intrinsic::ID vpaddlu_id = ::llvm::Intrinsic::arm_neon_vpaddlu;

// Fallback to default llvm lowering rule if input type not a full vector or half vector length
int total_size = call->type.bits() * call->type.lanes();
if (!call->type.is_vector() || call->type.bits() == 8 ||
(total_size != 128 && total_size != 64)) {
Array<Expr> vcnt_args;
vcnt_args.push_back(ir::UIntImm::make(UInt(32), ctpop_id));
vcnt_args.push_back(ir::UIntImm::make(UInt(32), 1));
vcnt_args.push_back(e);
return ir::Call::make(call->type, "llvm_intrin", vcnt_args, Call::PureIntrinsic);
}

// Popcount lowering rule:
// Reinterpret input vector as a vector of 8bit values and preform popcount
// Pairwise add between adjacent elements and double width with vpaddlu
// to return back to original input type

// Dvisions are always divisible (number of bits = 64 or 128)
Type uint8_type = Type(e.type().code(), 8, e.type().bits() * e.type().lanes() / 8);
Type uint16_type = Type(uint8_type.code(), 16, uint8_type.bits() * uint8_type.lanes() / 16);
Type uint32_type = Type(uint16_type.code(), 32, uint8_type.bits() * uint8_type.lanes() / 32);

// Interpret input as vector of 8bit values
Expr input8 = reinterpret(uint8_type, e);
// Popcount 8bit->8bit
const Call* c0 = input8.as<Call>();
CHECK(c0 != nullptr);
Array<Expr> vcnt8_args;
vcnt8_args.push_back(ir::UIntImm::make(UInt(32), ctpop_id));
vcnt8_args.push_back(ir::UIntImm::make(UInt(32), 1));
vcnt8_args.push_back(input8);
Expr vcnt8 = ir::Call::make(uint8_type, "llvm_intrin", vcnt8_args, Call::PureIntrinsic);

// Accumulation 8->16bit
Array<Expr> vcnt16_args;
vcnt16_args.push_back(ir::UIntImm::make(UInt(32), vpaddlu_id));
vcnt16_args.push_back(ir::UIntImm::make(UInt(32), 1));
vcnt16_args.push_back(vcnt8);
Expr vcnt16 = ir::Call::make(uint16_type, "llvm_intrin", vcnt16_args, Call::PureIntrinsic);
if (call->type.bits() == 16) {
return vcnt16;
}

// Accumulation 16->32bit
Array<Expr> vcnt32_args;
vcnt32_args.push_back(ir::UIntImm::make(UInt(32), vpaddlu_id));
vcnt32_args.push_back(ir::UIntImm::make(UInt(32), 1));
vcnt32_args.push_back(vcnt16);
Expr vcnt32 = ir::Call::make(uint32_type, "llvm_intrin", vcnt32_args, Call::PureIntrinsic);
if (call->type.bits() == 32) {
return vcnt32;
}

// Accumulation 32->64bit
Array<Expr> vcnt64_args;
vcnt64_args.push_back(ir::UIntImm::make(UInt(32), vpaddlu_id));
vcnt64_args.push_back(ir::UIntImm::make(UInt(32), 1));
vcnt64_args.push_back(vcnt32);
return ir::Call::make(call->type, "llvm_intrin", vcnt64_args, Call::PureIntrinsic);
}

TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
CodeGenLLVM* cg = new CodeGenARM();
Expand Down
26 changes: 25 additions & 1 deletion src/codegen/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
if (extent == num_elems && begin == 0) return vec;
CHECK_LT(begin + extent, num_elems);
CHECK_LE(begin + extent, num_elems);
std::vector<unsigned> indices;
for (int i = 0; i < extent; ++i) {
indices.push_back(begin + i);
Expand Down Expand Up @@ -562,6 +562,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
sig_type.push_back(arg_value.back()->getType());
}
}
llvm::Type *return_type = LLVMType(op->type);
if (sig_type.size() > 0 && return_type != sig_type[0]) {
sig_type.insert(sig_type.begin(), return_type);
}
llvm::Function* f = llvm::Intrinsic::getDeclaration(
module_.get(), id, sig_type);
return builder_->CreateCall(f, arg_value);
Expand Down Expand Up @@ -628,6 +632,26 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
value->addIncoming(then_value, then_value_block);
value->addIncoming(else_value, else_value_block);
return value;
} else if (op->is_intrinsic(Call::reinterpret)) {
llvm::Type * target = LLVMType(op->type);
return builder_->CreateBitCast(MakeValue(op->args[0]), target);
} else if (op->is_intrinsic("vectorlow")) {
llvm::Value *v = MakeValue(op->args[0]);
int l = v->getType()->getVectorNumElements();
return CreateVecSlice(v, 0, l/2);
} else if (op->is_intrinsic("vectorhigh")) {
llvm::Value *v = MakeValue(op->args[0]);
int l = v->getType()->getVectorNumElements();
return CreateVecSlice(v, l/2, l/2);
} else if (op->is_intrinsic("vectorcombine")) {
llvm::Value *v0 = MakeValue(op->args[0]);
llvm::Value *v1 = MakeValue(op->args[1]);
int num_elems = static_cast<int>(v0->getType()->getVectorNumElements()) * 2;
std::vector<unsigned> indices;
for (int i = 0; i < num_elems; ++i) {
indices.push_back(i);
}
return builder_->CreateShuffleVector(v0, v1, indices);
} else {
LOG(FATAL) << "unknown intrinsic " << op->name;
return nullptr;
Expand Down
38 changes: 34 additions & 4 deletions src/codegen/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,41 @@ class LLVMModuleNode final : public runtime::ModuleNode {
}

std::string GetSource(const std::string& format) final {
std::string fmt = runtime::GetFileFormat("", format);
std::string type_str;
llvm::raw_string_ostream rso(type_str);
CHECK(mptr_ != nullptr);
mptr_->print(rso, nullptr);
return rso.str();
llvm::SmallString<256> str;
llvm::raw_svector_ostream rso(str);

if (fmt == "s" || fmt == "asm") {
#if TVM_LLVM_VERSION <= 60
std::unique_ptr<llvm::Module> m = llvm::CloneModule(mptr_);
#else
std::unique_ptr<llvm::Module> m = llvm::CloneModule(*mptr_);
#endif
llvm::legacy::PassManager pass;
CHECK(tm_);
#if TVM_LLVM_VERSION <= 60
CHECK(tm_->addPassesToEmitFile(
pass, rso, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
<< "Cannot emit target CGFT_AssemblyFile";
#else
CHECK(tm_->addPassesToEmitFile(
pass, rso, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
<< "Cannot emit target CGFT_AssemblyFile";
#endif
pass.run(*m);
return rso.str().str();
} else if (fmt == "" || fmt == "ll") {
std::string type_str;
llvm::raw_string_ostream rso(type_str);
CHECK(mptr_ != nullptr);
mptr_->print(rso, nullptr);
return rso.str();
} else {
LOG(FATAL) << "Do not know how to get source code with format: "
<< format << "\'";
}
return "";
}

void Init(const Array<LoweredFunc>& funcs, std::string target) {
Expand Down
30 changes: 30 additions & 0 deletions tests/python/unittest/test_codegen_arm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import tvm
import re
import os
import ctypes

def test_popcount():
target = 'llvm -target=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon'

def check_correct_assembly(type, elements, counts):
n = tvm.convert(elements)
A = tvm.placeholder(n, dtype=type, name='A')
B = tvm.compute(A.shape, lambda i: tvm.popcount(A[i]), name='B')
s = tvm.create_schedule(B.op)
s[B].vectorize(s[B].op.axis[0])
f = tvm.build(s, [A, B], target)

# Verify we see the correct number of vpaddl and vcnt instructions in the assembly
assembly = f.get_source('asm')
matches = re.findall("vpaddl", assembly)
assert (len(matches) == counts)
matches = re.findall("vcnt", assembly)
assert (len(matches) == 1)
check_correct_assembly('uint16', 8, 1)
check_correct_assembly('uint16', 4, 1)
check_correct_assembly('uint32', 4, 2)
check_correct_assembly('uint32', 2, 2)
check_correct_assembly('uint64', 2, 3)

if __name__ == "__main__":
test_popcount()

0 comments on commit 53d65da

Please sign in to comment.