Skip to content

Commit

Permalink
[npu] support NPU kernel for the less_than operator
Browse files Browse the repository at this point in the history
  • Loading branch information
Meiyim committed Mar 12, 2021
1 parent 15823bb commit 831fe48
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 1 deletion.
32 changes: 31 additions & 1 deletion paddle/fluid/operators/controlflow/compare_op_npu.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Expand All @@ -21,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/operators/controlflow/compare_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/npu_op_runner.h"
#ifdef PADDLE_WITH_ASCEND_CL

namespace paddle {
namespace operators {
Expand All @@ -42,6 +42,23 @@ class EqualNPUKernel : public framework::OpKernel<T> {
}
};

template <typename DeviceContext, typename T>
class LessThanNPUKernel : public framework::OpKernel<T> {
 public:
  // Computes Out = (X < Y) element-wise on the NPU by dispatching the
  // Ascend "Less" operator. The output tensor is always bool, regardless
  // of the input dtype T.
  //
  // NOTE(review): the compare op's "axis" attribute is ignored here; this
  // presumably relies on the Ascend "Less" operator broadcasting like the
  // CPU kernel — TODO confirm for inputs of unequal shape.
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::LoDTensor>("X");
    auto* y = ctx.Input<framework::LoDTensor>("Y");
    auto* z = ctx.Output<framework::LoDTensor>("Out");
    z->mutable_data<bool>(ctx.GetPlace());  // allocate the bool output buffer
    auto runner = NpuOpRunner("Less", {*x, *y}, {*z});
    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();
    runner.Run(stream);
  }
};

} // namespace operators
} // namespace paddle

Expand All @@ -51,3 +68,16 @@ namespace plat = paddle::platform;
// Register the NPU kernel for the equal op (float/float16/int only).
REGISTER_OP_NPU_KERNEL(equal, ops::EqualNPUKernel<float>,
                       ops::EqualNPUKernel<plat::float16>,
                       ops::EqualNPUKernel<int>);

// Register the NPU kernel for the less_than op across the dtypes the
// framework's compare op supports.
// NOTE(review): assumes the Ascend "Less" operator accepts double,
// uint8/int8 and int64 inputs — confirm against the CANN operator spec.
REGISTER_OP_NPU_KERNEL(
    less_than,
    ops::LessThanNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::LessThanNPUKernel<paddle::platform::NPUDeviceContext, double>,
    ops::LessThanNPUKernel<paddle::platform::NPUDeviceContext,
                           paddle::platform::float16>,
    ops::LessThanNPUKernel<paddle::platform::NPUDeviceContext, int>,
    ops::LessThanNPUKernel<paddle::platform::NPUDeviceContext, uint8_t>,
    ops::LessThanNPUKernel<paddle::platform::NPUDeviceContext, int8_t>,
    ops::LessThanNPUKernel<paddle::platform::NPUDeviceContext, int64_t>);

#endif
55 changes: 55 additions & 0 deletions python/paddle/fluid/tests/unittests/npu/test_compare_op_npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,36 @@ def test_check_output(self):
self.check_output_with_place(self.place, check_dygraph=False)


@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestLessthan(OpTest):
    """Checks the NPU less_than kernel against numpy's elementwise x < y."""

    def setUp(self):
        self.set_npu()
        self.op_type = "less_than"
        self.place = paddle.NPUPlace(0)

        self.init_dtype()
        np.random.seed(SEED)
        x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
        y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
        out = x < y  # elementwise reference result for independent random inputs

        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(x),
            'Y': OpTest.np_dtype_to_fluid_dtype(y)
        }
        self.outputs = {'Out': out}

    def set_npu(self):
        # Marks the test class so OpTest runs it against the NPU place.
        self.__class__.use_npu = True

    def init_dtype(self):
        # Overridden by subclasses to exercise other dtypes.
        self.dtype = np.float32

    def test_check_output(self):
        # check_dygraph=False: presumably no dygraph NPU path for compare
        # ops at this point — TODO confirm.
        self.check_output_with_place(self.place, check_dygraph=False)


class TestEqual2(TestEqual):
def setUp(self):
self.set_npu()
Expand All @@ -76,6 +106,26 @@ def setUp(self):
self.outputs = {'Out': out}


class TestLessthan2(TestLessthan):
    """less_than on nearly-identical tensors: exactly one position differs."""

    def setUp(self):
        self.set_npu()
        self.op_type = "less_than"
        self.place = paddle.NPUPlace(0)

        self.init_dtype()
        np.random.seed(SEED)
        lhs = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
        rhs = lhs.copy()
        rhs[0][1] = 1
        # Inputs are identical everywhere except position [0][1].
        expected = lhs < rhs

        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(lhs),
            'Y': OpTest.np_dtype_to_fluid_dtype(rhs)
        }
        self.outputs = {'Out': expected}


class TestEqual2FP16(TestEqual2):
    # Same scenario as TestEqual2, run with float16 inputs.
    def init_dtype(self):
        self.dtype = np.float16
Expand All @@ -86,5 +136,10 @@ def init_dtype(self):
self.dtype = np.int32


class TestLessthan2FP16(TestLessthan2):
    # Same single-difference scenario as TestLessthan2, with float16 inputs.
    def init_dtype(self):
        self.dtype = np.float16


if __name__ == '__main__':
unittest.main()

0 comments on commit 831fe48

Please sign in to comment.