From 831fe4850f22b87ea8f8b53b2b2fcb628df41bea Mon Sep 17 00:00:00 2001 From: Meiyim Date: Fri, 12 Mar 2021 17:57:33 +0800 Subject: [PATCH] [npu] support npu kernel for `less than` --- .../operators/controlflow/compare_op_npu.cc | 32 ++++++++++- .../unittests/npu/test_compare_op_npu.py | 55 +++++++++++++++++++ 2 files changed, 86 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/controlflow/compare_op_npu.cc b/paddle/fluid/operators/controlflow/compare_op_npu.cc index 58401302bc3a7..f9f65dba69092 100644 --- a/paddle/fluid/operators/controlflow/compare_op_npu.cc +++ b/paddle/fluid/operators/controlflow/compare_op_npu.cc @@ -1,5 +1,4 @@ /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -21,6 +20,7 @@ limitations under the License. */ #include "paddle/fluid/operators/controlflow/compare_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/npu_op_runner.h" +#ifdef PADDLE_WITH_ASCEND_CL namespace paddle { namespace operators { @@ -42,6 +42,23 @@ class EqualNPUKernel : public framework::OpKernel { } }; +template +class LessThanNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* z = ctx.Output("Out"); + // int axis = context.Attr("axis"); + z->mutable_data(ctx.GetPlace()); // allocate + auto runner = NpuOpRunner("Less", {*x, *y}, {*z}); + auto stream = + ctx.template device_context() + .stream(); + runner.Run(stream); + } +}; + } // namespace operators } // namespace paddle @@ -51,3 +68,16 @@ namespace plat = paddle::platform; REGISTER_OP_NPU_KERNEL(equal, ops::EqualNPUKernel, ops::EqualNPUKernel, ops::EqualNPUKernel); + +REGISTER_OP_NPU_KERNEL( + less_than, + ops::LessThanNPUKernel, + ops::LessThanNPUKernel, + ops::LessThanNPUKernel, + ops::LessThanNPUKernel, + ops::LessThanNPUKernel, + ops::LessThanNPUKernel, + ops::LessThanNPUKernel); + +#endif diff --git a/python/paddle/fluid/tests/unittests/npu/test_compare_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_compare_op_npu.py index c82897d54b8f6..5a8e0efdeed97 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_compare_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_compare_op_npu.py @@ -56,6 +56,36 @@ def test_check_output(self): self.check_output_with_place(self.place, check_dygraph=False) +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestLessthan(OpTest): + def setUp(self): + self.set_npu() + self.op_type = "less_than" + self.place = paddle.NPUPlace(0) + + self.init_dtype() + np.random.seed(SEED) + x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype) + y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype) + out = x < y # all elements are not equal + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(x), + 'Y': OpTest.np_dtype_to_fluid_dtype(y) + } + self.outputs = {'Out': out} + + def set_npu(self): + self.__class__.use_npu = True + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place, check_dygraph=False) + + class TestEqual2(TestEqual): def setUp(self): self.set_npu() @@ -76,6 +106,26 @@ def setUp(self): self.outputs = {'Out': out} +class TestLessthan2(TestLessthan): + def setUp(self): + self.set_npu() + self.op_type = "less_than" + self.place = paddle.NPUPlace(0) + + self.init_dtype() + np.random.seed(SEED) + x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype) + y = x.copy() + y[0][1] = 1 + out = x < y # all elements are equal, except position [0][1] + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(x), + 'Y': OpTest.np_dtype_to_fluid_dtype(y) + } + self.outputs = {'Out': out} + + class TestEqual2FP16(TestEqual2): def init_dtype(self): self.dtype = np.float16 @@ -86,5 +136,10 @@ def init_dtype(self): self.dtype = np.int32 +class TestLessthan2FP16(TestLessthan2): + def init_dtype(self): + self.dtype = np.float16 + + if __name__ == '__main__': unittest.main()