Skip to content

Commit

Permalink
Addition of fp16 type support for Compare OP (#44405)
Browse files Browse the repository at this point in the history
* first commit

* add fp16 ctest files for compare op

* add cpu register of float16 for compare ops
  • Loading branch information
JamesLim-sy committed Aug 4, 2022
1 parent c693a02 commit 6506668
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 13 deletions.
18 changes: 12 additions & 6 deletions paddle/phi/kernels/cpu/compare_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ PD_REGISTER_KERNEL(less_than,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(less_equal,
CPU,
ALL_LAYOUT,
Expand All @@ -90,7 +91,8 @@ PD_REGISTER_KERNEL(less_equal,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(greater_than,
CPU,
ALL_LAYOUT,
Expand All @@ -100,7 +102,8 @@ PD_REGISTER_KERNEL(greater_than,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(greater_equal,
CPU,
ALL_LAYOUT,
Expand All @@ -110,7 +113,8 @@ PD_REGISTER_KERNEL(greater_equal,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(equal,
CPU,
ALL_LAYOUT,
Expand All @@ -120,7 +124,8 @@ PD_REGISTER_KERNEL(equal,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(not_equal,
CPU,
ALL_LAYOUT,
Expand All @@ -130,7 +135,8 @@ PD_REGISTER_KERNEL(not_equal,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}

PD_REGISTER_KERNEL(equal_all,
CPU,
Expand Down
18 changes: 12 additions & 6 deletions paddle/phi/kernels/kps/compare_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ PD_REGISTER_KERNEL(less_than,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(less_equal,
KPS,
ALL_LAYOUT,
Expand All @@ -123,7 +124,8 @@ PD_REGISTER_KERNEL(less_equal,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(greater_than,
KPS,
ALL_LAYOUT,
Expand All @@ -133,7 +135,8 @@ PD_REGISTER_KERNEL(greater_than,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(greater_equal,
KPS,
ALL_LAYOUT,
Expand All @@ -143,7 +146,8 @@ PD_REGISTER_KERNEL(greater_equal,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(equal,
KPS,
ALL_LAYOUT,
Expand All @@ -153,7 +157,8 @@ PD_REGISTER_KERNEL(equal,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(not_equal,
KPS,
ALL_LAYOUT,
Expand All @@ -163,7 +168,8 @@ PD_REGISTER_KERNEL(not_equal,
int,
int64_t,
float,
double) {}
double,
phi::dtype::float16) {}

PD_REGISTER_KERNEL(equal_all,
KPS,
Expand Down
4 changes: 3 additions & 1 deletion python/paddle/fluid/tests/unittests/test_compare_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,11 @@ def test_errors(self):
globals()[cls_name] = Cls


for _type_name in {'float32', 'float64', 'int32', 'int64'}:
for _type_name in {'float32', 'float64', 'int32', 'int64', 'float16'}:
if _type_name == 'float64' and core.is_compiled_with_rocm():
_type_name = 'float32'
if _type_name == 'float16' and (not core.is_compiled_with_cuda()):
continue

create_test_class('less_than', _type_name, lambda _a, _b: _a < _b)
create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b)
Expand Down

0 comments on commit 6506668

Please sign in to comment.