Skip to content

Commit

Permalink
[PluggableDevice] custom kernel supports multi cpp_dtype registering
Browse files Browse the repository at this point in the history
  • Loading branch information
Aganlengzi committed Feb 8, 2022
1 parent 41eb259 commit 8eb9a11
Show file tree
Hide file tree
Showing 3 changed files with 695 additions and 63 deletions.
114 changes: 76 additions & 38 deletions paddle/fluid/framework/custom_kernel_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,12 @@ limitations under the License. */
// user kernel function
namespace custom_kernel {

// Here we use dot <CPU, ANY, UINT8> for test
// This test will fail when these two kernels are aupported in framework
// Here we use fake_dot for test
// input 3: two Tensors and one std::vector<Tensor>
// attribute 11: fake_attributes
// output 2: one Tensor* and one std::vector<Tensor*>
template <typename T>
void FakeDot(const paddle::CPUContext& dev_ctx, const paddle::Tensor& x,
template <typename T, typename Context>
void FakeDot(const Context& dev_ctx, const paddle::Tensor& x,
const paddle::Tensor& y,
const std::vector<paddle::Tensor>& fake_input_vec,
bool fake_attr_bool, int fake_attr_int, float fake_attr_float,
Expand Down Expand Up @@ -93,53 +92,91 @@ void FakeDot(const paddle::CPUContext& dev_ctx, const paddle::Tensor& x,
}
} // namespace custom_kernel

PD_REGISTER_KERNEL(dot, CPU, ALL_LAYOUT, UINT8,
custom_kernel::FakeDot<uint8_t>) {
/* do some args define here
* the only param can be used is OpKernelInfo* kernel */
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UINT8);
}
PD_REGISTER_KERNEL(fake_dot, CPU, ALL_LAYOUT, custom_kernel::FakeDot, float,
double, int, int64_t, int8_t, uint8_t) {}

// Upper code will store dot kernels info into OpKernelInfoMap
TEST(CustomKernel, custom_kernel_dot) {
std::string op_name = "dot";
std::string op_name = "fake_dot";
pten::Backend backend = pten::Backend::CPU;
pten::DataLayout layout = pten::DataLayout::ANY;
pten::DataType dtype = pten::DataType::UINT8;
pten::DataLayout layout = pten::DataLayout::ALL_LAYOUT;

// 1.custom kernel info parsed and store
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance().GetMap().find("dot") !=
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance().GetMap().find(op_name) !=
paddle::OpKernelInfoMap::Instance().GetMap().end());

// 2.info check
EXPECT_EQ(
1, static_cast<int>(paddle::OpKernelInfoMap::Instance()["dot"].size()));
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()["dot"][0].GetBackend() ==
6, static_cast<int>(paddle::OpKernelInfoMap::Instance()[op_name].size()));
// index 0
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetBackend() ==
backend);
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()["dot"][0].GetDataLayout() ==
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetDataLayout() ==
layout);
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()["dot"][0].GetDataType() ==
dtype);

// 3.register
EXPECT_TRUE(pten::KernelFactory::Instance().kernels().end() !=
pten::KernelFactory::Instance().kernels().find("dot"));

pten::KernelKey kernel_key(backend, layout, dtype);
EXPECT_TRUE(
pten::KernelFactory::Instance().kernels()["dot"].find(kernel_key) ==
pten::KernelFactory::Instance().kernels()["dot"].end());

EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetDataType() ==
pten::DataType::FLOAT32);
// index 5
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetBackend() ==
backend);
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetDataLayout() ==
layout);
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetDataType() ==
pten::DataType::UINT8);

// 3.before register
auto& kernel_factory_instance = pten::KernelFactory::Instance();
auto& kernels = pten::KernelFactory::Instance().kernels();
EXPECT_TRUE(!kernel_factory_instance.HasCompatiblePtenKernel(op_name));

// mock fake_dot is supported by pten for HasCompatiblePtenKernel check while
// registering
auto& fake_dot_kernels = kernels[op_name];

EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::FLOAT32)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::FLOAT64)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT32)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT64)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT8)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::UINT8)) ==
fake_dot_kernels.end());

// register
paddle::framework::RegisterKernelWithMetaInfoMap(
paddle::OpKernelInfoMap::Instance());

EXPECT_TRUE(
pten::KernelFactory::Instance().kernels()["dot"].find(kernel_key) !=
pten::KernelFactory::Instance().kernels()["dot"].end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::FLOAT32)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::FLOAT64)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT32)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT64)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT8)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::UINT8)) !=
fake_dot_kernels.end());

// 4.kernel select
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
op_name, kernel_key);
auto kernel = kernel_factory_instance.SelectKernelOrThrowError(
op_name, pten::KernelKey(backend, layout, pten::DataType::UINT8));

// 5.prepare parameters for kernel
const auto alloc = std::make_unique<paddle::experimental::DefaultAllocator>(
Expand Down Expand Up @@ -252,10 +289,10 @@ TEST(CustomKernel, custom_kernel_dot) {
// test OpKernelInfoHelper
TEST(OpKernelInfoHelper, op_kernel_info_help_getters) {
using OpKernelInfoHelper = paddle::framework::OpKernelInfoHelper;
std::string op_name = "dot";
std::string op_name = "fake_dot";
pten::Backend backend = pten::Backend::CPU;
pten::DataLayout layout = pten::DataLayout::ANY;
pten::DataType dtype = pten::DataType::UINT8;
pten::DataType dtype = pten::DataType::FLOAT32;

auto op_kernel_info = paddle::OpKernelInfoMap::Instance()[op_name][0];

Expand All @@ -268,10 +305,11 @@ TEST(OpKernelInfoHelper, op_kernel_info_help_getters) {
OpKernelInfoHelper::GetKernelKey(op_kernel_info));

paddle::CustomKernelFunc kernel_fn =
PD_PT_KERNEL(custom_kernel::FakeDot<uint8_t>);
PD_PT_KERNEL(custom_kernel::FakeDot<float, paddle::CPUContext>);
EXPECT_EQ(kernel_fn, OpKernelInfoHelper::GetKernelFn(op_kernel_info));

void* variadic_func = PD_PT_VARIADIC_KERNEL(custom_kernel::FakeDot<uint8_t>);
void* variadic_func =
PD_PT_VARIADIC_KERNEL(custom_kernel::FakeDot<float, paddle::CPUContext>);
EXPECT_EQ(variadic_func,
OpKernelInfoHelper::GetVariadicKernelFn(op_kernel_info));

Expand Down
Loading

1 comment on commit 8eb9a11

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.