【Hackathon No57】add_fp16_bf16_for_dot & bf16_for_cross (#52426)
* add_fp_bf_for_dot & bf_for_cross

* fix error

* fix some error

* fix some error

* change something

* fix magic number
Difers committed Apr 13, 2023
1 parent e0e044c commit 205094f
Showing 6 changed files with 259 additions and 6 deletions.
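
This commit registers float16 and bfloat16 for the GPU dot / dot_grad kernels and bfloat16 for the GPU cross / cross_grad kernels, then adds unit tests for the new dtypes. A minimal usage sketch of what the change enables, assuming a CUDA build of Paddle that includes this commit (the shapes and casts below are illustrative, not taken from the PR):

import paddle

paddle.set_device("gpu")  # the new registrations are GPU-only

x = paddle.rand([1024, 3], dtype="float32")
y = paddle.rand([1024, 3], dtype="float32")

# dot: float16 and bfloat16 are newly registered on GPU
d = paddle.dot(x.flatten().cast("float16"), y.flatten().cast("float16"))

# cross: float16 was already registered; bfloat16 is the new dtype
c = paddle.cross(x.cast("bfloat16"), y.cast("bfloat16"), axis=1)

print(d.dtype, c.dtype)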
1 change: 1 addition & 0 deletions paddle/phi/kernels/gpu/cross_grad_kernel.cu
@@ -191,6 +191,7 @@ PD_REGISTER_KERNEL(cross_grad,
ALL_LAYOUT,
phi::CrossGradKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float,
double,
int,
1 change: 1 addition & 0 deletions paddle/phi/kernels/gpu/cross_kernel.cu
@@ -168,6 +168,7 @@ PD_REGISTER_KERNEL(cross,
ALL_LAYOUT,
phi::CrossKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float,
double,
int,
6 changes: 5 additions & 1 deletion paddle/phi/kernels/gpu/dot_grad_kernel.cu
@@ -15,7 +15,9 @@ limitations under the License. */
#include "paddle/phi/kernels/dot_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/dot_grad_kernel_impl.h"

@@ -28,4 +30,6 @@ PD_REGISTER_KERNEL(dot_grad,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>,
phi::dtype::float16,
phi::dtype::bfloat16) {}
6 changes: 5 additions & 1 deletion paddle/phi/kernels/gpu/dot_kernel.cu
@@ -15,6 +15,8 @@
#include "paddle/phi/kernels/dot_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"

@@ -61,4 +63,6 @@ PD_REGISTER_KERNEL(dot,
int,
int64_t,
complex64,
complex128) {}
complex128,
phi::dtype::float16,
phi::dtype::bfloat16) {}
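
The four registrations above only extend the dtype lists of the existing GPU kernels; numerically, float16 and bfloat16 results are expected to track a float32 reference only loosely (float16 keeps 10 mantissa bits, bfloat16 only 7), which is why the new tests below use atol=0.125 and atol=0.5. A rough sanity-check sketch, again assuming a CUDA build that includes this commit; the shapes and tolerances mirror the tests rather than any documented accuracy contract:

import numpy as np
import paddle

paddle.set_device("gpu")
x = paddle.uniform([121], min=0.1, max=1.0, dtype="float32")
y = paddle.uniform([121], min=1.0, max=3.0, dtype="float32")
ref = paddle.dot(x, y).numpy()

for dtype, atol in (("float16", 0.125), ("bfloat16", 0.5)):
    low = paddle.dot(x.cast(dtype), y.cast(dtype))  # dispatches to the newly registered kernel
    np.testing.assert_allclose(low.cast("float32").numpy(), ref, atol=atol)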
52 changes: 50 additions & 2 deletions python/paddle/fluid/tests/unittests/test_cross_op.py
@@ -15,11 +15,11 @@
import unittest

import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16

import paddle
from paddle import fluid
from paddle.fluid import Program, program_guard
from paddle.fluid import Program, core, program_guard


class TestCrossOp(OpTest):
@@ -65,6 +65,9 @@ def init_output(self):
self.outputs = {'Out': np.array(z_list).reshape(self.shape)}


@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestCrossFP16Op(TestCrossOp):
def initTestCase(self):
self.shape = (2048, 3)
@@ -77,6 +80,51 @@ def init_output(self):
self.outputs = {'Out': np.array(z_list).reshape(self.shape)}


@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and not support the bfloat16",
)
class TestCrossBF16Op(OpTest):
def setUp(self):
self.op_type = "cross"
self.python_api = paddle.cross
self.initTestCase()
self.x = np.random.random(self.shape).astype(np.float32)
self.y = np.random.random(self.shape).astype(np.float32)
self.inputs = {
'X': convert_float_to_uint16(self.x),
'Y': convert_float_to_uint16(self.y),
}
self.init_output()

def initTestCase(self):
self.attrs = {'dim': -2}
self.dtype = np.uint16
self.shape = (1024, 3, 1)

def init_output(self):
x = np.squeeze(self.x, 2)
y = np.squeeze(self.y, 2)
z_list = []
for i in range(1024):
z_list.append(np.cross(x[i], y[i]))
out = np.array(z_list).astype(np.float32).reshape(self.shape)
self.outputs = {'Out': convert_float_to_uint16(out)}

def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_bfloat16_supported(place):
self.check_output_with_place(place)

def test_check_grad_normal(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_bfloat16_supported(place):
self.check_grad_with_place(place, ['X', 'Y'], 'Out')


class TestCrossAPI(unittest.TestCase):
def input_data(self):
self.data_x = np.array(
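
A note on the np.uint16 dtype in TestCrossBF16Op: NumPy has no native bfloat16, so the OpTest helper convert_float_to_uint16 stores each value as the upper 16 bits of its float32 bit pattern, and the kernel reinterprets those bits as bfloat16. The helpers below (to_bf16_bits / from_bf16_bits are illustrative names, not Paddle APIs) sketch that round trip with plain truncation; Paddle's helper may additionally round:

import numpy as np

def to_bf16_bits(a):
    # keep the top 16 bits of float32 (1 sign, 8 exponent, 7 mantissa bits)
    return (np.asarray(a, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

def from_bf16_bits(b):
    # zero-pad the low 16 bits and reinterpret as float32
    return (b.astype(np.uint32) << 16).view(np.float32)

x = np.random.random((4, 3)).astype(np.float32)
assert np.allclose(from_bf16_bits(to_bf16_bits(x)), x, atol=1e-2)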
199 changes: 197 additions & 2 deletions python/paddle/fluid/tests/unittests/test_dot_op.py
@@ -15,7 +15,7 @@
import unittest

import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16

import paddle
from paddle import fluid
@@ -85,7 +85,7 @@ def test_check_grad_ingore_y(self):
def init_input_output(self):
self.x = np.random.uniform(0.1, 1, [121]).astype(self.dtype)
self.y = np.random.uniform(1, 3, [121]).astype(self.dtype)
self.out = np.dot(self.x, self.y)
self.out = np.dot(self.x, self.y).astype(self.dtype)

def init_dtype(self):
self.dtype = np.float64
@@ -314,6 +314,201 @@ def test_check_grad_ingore_y(self):
)


@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestDotFP16Op(OpTest):
def setUp(self):
self.op_type = "dot"
self.python_api = paddle.dot
self.init_dtype()
self.init_input_output()

self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
}
self.outputs = {'Out': self.out}
self.attrs = {}

def init_dtype(self):
self.dtype = np.float16

def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=0.125)

def test_check_grad_normal(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_grad_with_place(place, ['X', 'Y'], 'Out')

def test_check_grad_ingore_x(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_grad_with_place(
place, ['Y'], 'Out', no_grad_set=set("X")
)

def test_check_grad_ingore_y(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_grad_with_place(
place, ['X'], 'Out', no_grad_set=set("Y")
)

def init_input_output(self):
self.x = np.random.uniform(0.1, 1, [121]).astype(self.dtype)
self.y = np.random.uniform(1, 3, [121]).astype(self.dtype)
self.out = np.dot(self.x, self.y)


@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class DotFP16OpBatch(TestDotFP16Op):
def init_input_output(self):
self.x = (
np.random.uniform(0.1, 1, [132])
.astype(self.dtype)
.reshape([11, 12])
)
self.y = (
np.random.uniform(1, 3, [132]).astype(self.dtype).reshape([11, 12])
)
self.out = np.sum(self.x * self.y, axis=1).reshape([11, 1])


@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and not support the bfloat16",
)
class TestDotBF16Op(OpTest):
def setUp(self):
self.op_type = "dot"
self.python_api = paddle.dot
self.init_dtype()
self.init_input_output()

self.inputs = {
'X': convert_float_to_uint16(self.x),
'Y': convert_float_to_uint16(self.y),
}
self.outputs = {'Out': convert_float_to_uint16(self.out)}
self.attrs = {}

def init_dtype(self):
self.dtype = np.uint16

def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_bfloat16_supported(place):
self.check_output_with_place(place, atol=0.5)

def test_check_grad_normal(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_bfloat16_supported(place):
self.check_grad_with_place(
place,
['X', 'Y'],
'Out',
user_defined_grads=[self.inputs['Y'], self.inputs['X']],
)

def test_check_grad_ingore_x(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_bfloat16_supported(place):
self.check_grad_with_place(
place,
['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[self.inputs['X']],
)

def test_check_grad_ingore_y(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_bfloat16_supported(place):
self.check_grad_with_place(
place,
['X'],
'Out',
no_grad_set=set("Y"),
user_defined_grads=[self.inputs['Y']],
)

def init_input_output(self):
self.x = np.random.uniform(0.1, 1, [121]).astype(np.float32)
self.y = np.random.uniform(1, 3, [121]).astype(np.float32)
self.out = np.dot(self.x, self.y)


@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and not support the bfloat16",
)
class DotBF16OpBatch(TestDotBF16Op):
def init_input_output(self):
self.x = (
np.random.uniform(0.1, 1, [132])
.astype(np.float32)
.reshape([11, 12])
)
self.y = (
np.random.uniform(1, 3, [132]).astype(np.float32).reshape([11, 12])
)
self.out = np.sum(self.x * self.y, axis=1).reshape([11, 1])

def test_check_grad_normal(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_bfloat16_supported(place):
self.check_grad_with_place(
place,
['X', 'Y'],
'Out',
user_defined_grads=[
self.y / self.y.shape[0],
self.x / self.x.shape[0],
],
)

def test_check_grad_ingore_x(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_bfloat16_supported(place):
self.check_grad_with_place(
place,
['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[self.x / self.x.shape[0]],
)

def test_check_grad_ingore_y(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_bfloat16_supported(place):
self.check_grad_with_place(
place,
['X'],
'Out',
no_grad_set=set("Y"),
user_defined_grads=[self.y / self.y.shape[0]],
)


if __name__ == '__main__':
paddle.enable_static()
unittest.main()
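
For reference, the user_defined_grads passed to check_grad_with_place in the bfloat16 tests are the analytic gradients of dot: for out = dot(x, y), d(out)/dx = y and d(out)/dy = x. In the batched variants they are divided by the batch size, which is consistent with the gradient checker averaging the outputs into a scalar loss (an assumption about the OpTest convention, not something stated in this diff). A small float64 sketch of that derivation:

import numpy as np

x = np.random.uniform(0.1, 1, [11, 12])
y = np.random.uniform(1, 3, [11, 12])

out = np.sum(x * y, axis=1, keepdims=True)   # batched dot, shape [11, 1]
loss = out.mean()                            # mean over the 11 outputs

grad_x = y / x.shape[0]                      # analytic d(loss)/dx
grad_y = x / x.shape[0]                      # analytic d(loss)/dy

# finite-difference spot check of one element; the loss is linear in x,
# so the difference quotient matches the analytic value almost exactly
eps = 1e-3
x_pert = x.copy()
x_pert[0, 0] += eps
num = (np.sum(x_pert * y, axis=1, keepdims=True).mean() - loss) / eps
assert np.isclose(num, grad_x[0, 0], rtol=1e-4)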
