Added Conv2D BF16 BWD oneDNN kernel #38507

Merged (6 commits) on Dec 30, 2021

5 changes: 5 additions & 0 deletions paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -1012,6 +1012,11 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN,
ops::kConvMKLDNNFP32,
ops::ConvMKLDNNGradOpKernel<float, float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace, BF16,
ops::kConvMKLDNNFP32,
ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16, float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d, MKLDNN,
::paddle::platform::CPUPlace, FP32,
ops::kConvMKLDNNFP32,
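
For reference, the BF16 data type registered above is bfloat16: the upper 16 bits of an IEEE-754 float32, which the Python tests below carry around as np.uint16 via convert_float_to_uint16. A minimal NumPy sketch of that representation (the helper names here are illustrative, and the real convert_float_to_uint16 may round to nearest even rather than truncate):

import numpy as np

def float_to_bf16_bits(x):
    # Keep the upper 16 bits of each float32 (truncation, no rounding).
    x = np.asarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float(b):
    # Inverse: put the stored 16 bits back into the high half of a float32.
    b = np.asarray(b, dtype=np.uint16)
    return (b.astype(np.uint32) << 16).view(np.float32)

vals = np.array([1.0, 3.140625, -0.5], dtype=np.float32)
print(bf16_bits_to_float(float_to_bf16_bits(vals)))  # roughly 2-3 decimal digits preserved (8-bit mantissa)
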
156 changes: 141 additions & 15 deletions python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_bf16_mkldnn_op.py
@@ -19,7 +19,7 @@
import struct

import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16, OpTestTool
from paddle.fluid.tests.unittests.test_conv2d_op import conv2d_forward_naive, TestConv2DOp


@@ -31,7 +31,7 @@ def conv2d_residual_naive(out, residual):

@unittest.skipIf(not core.supports_bfloat16(),
"place does not support BF16 evaluation")
class TestConv2DBf16Op(TestConv2DOp):
class TestConv2DBF16Op(TestConv2DOp):
def setUp(self):
self.op_type = "conv2d"
self.use_cudnn = False
@@ -51,15 +51,19 @@ def setUp(self):
self.init_data_type()
self.init_force_fp32_output()

conv2d_param = {
self.conv2d_param = {
'stride': self.stride,
'pad': self.pad,
'dilation': self.dilations
}

self.input = np.random.random(self.input_size).astype(np.float32)
self.filter = np.random.random(self.filter_size).astype(np.float32)
conv_out, _, _, _, _ = conv2d_forward_naive(self.input, self.filter,
self.groups, conv2d_param)

self.inputs_fp32 = {'Input': self.input, 'Filter': self.filter}

conv_out, _, _, _, _ = conv2d_forward_naive(
self.input, self.filter, self.groups, self.conv2d_param)
self.conv_output_float = conv_out

if self.fuse_residual:
@@ -71,12 +75,16 @@ def setUp(self):
self.outputs = {'Output': self.conv_output}
elif self.force_fp32_output:
self.outputs = {'Output': self.conv_output_float.astype(np.float32)}
else:
self.outputs = {
'Output': convert_float_to_uint16(self.conv_output_float)
}

if self.input_type is not np.float32:
self.input = convert_float_to_uint16(self.input)

self.inputs = {
'Input': self.input.view(self.input_type),
'Input': self.input,
'Filter': OpTest.np_dtype_to_fluid_dtype(
self.filter.astype(self.weight_type))
}
@@ -111,14 +119,18 @@ def test_check_grad_no_input(self):

def init_test_case(self):
TestConv2DOp.init_test_case(self)
self.input_size = [1, 1, 5, 5] # NCHW
self.input_size = [1, 6, 12, 12] # NCHW
f_c = self.input_size[1] // self.groups
self.input_residual_size = [1, 2, 3, 3]
self.filter_size = [2, f_c, 3, 3]
o_c = 15
self.input_residual_size = [1, o_c, 10, 10]
self.filter_size = [o_c, f_c, 3, 3]

def init_padding(self):
pass

def init_data_type(self):
self.weight_type = np.float32
self.input_type = np.float32
self.input_type = np.uint16

def init_force_fp32_output(self):
self.force_fp32_output = False
@@ -130,7 +142,121 @@ def init_fuse_residual(self):
self.fuse_residual = True


class TestConv2D(TestConv2DBf16Op):
@OpTestTool.skip_if_not_cpu_bf16()
class TestConv2DWithGradBF16Op(TestConv2DBF16Op):
def init_fuse_relu(self):
self.fuse_activation = None

def init_fuse_residual(self):
self.fuse_residual = None

def test_check_grad(self):
dout = self.conv_output_float
x = self.inputs_fp32['Input']
w = self.inputs_fp32['Filter']

dx, dweights = conv_backward(dout, x, w, self.conv2d_param)

self.check_grad_with_place(
core.CPUPlace(), ["Input", "Filter"],
"Output",
user_defined_grads=[dx, dweights],
user_defined_grad_outputs=[convert_float_to_uint16(dout)])

def test_check_grad_no_filter(self):
dout = self.conv_output_float
x = self.inputs_fp32['Input']
w = self.inputs_fp32['Filter']

dx, _ = conv_backward(dout, x, w, self.conv2d_param)

self.check_grad_with_place(
core.CPUPlace(), ["Input"],
"Output",
set(['Filter']),
user_defined_grads=[dx],
user_defined_grad_outputs=[convert_float_to_uint16(dout)])

def test_check_grad_no_input(self):
dout = self.conv_output_float
x = self.inputs_fp32['Input']
w = self.inputs_fp32['Filter']

_, dweights = conv_backward(dout, x, w, self.conv2d_param)

self.check_grad_with_place(
core.CPUPlace(), ["Filter"],
"Output",
set(['Input']),
user_defined_grads=[dweights],
user_defined_grad_outputs=[convert_float_to_uint16(dout)])


def conv_backward(dout, x, w, params):
padding = params['pad'][0]
stride = params['stride']

dx = np.zeros_like(x)
dweights = np.zeros_like(w)

N, IC, H, W = x.shape
OC, _, KH, KW = w.shape

H_out = int(1 + (H + 2 * padding - KH) / stride[0])
W_out = int(1 + (W + 2 * padding - KW) / stride[1])

x_padded = np.pad(x, ((0, ), (0, ), (padding, ), (padding, )), 'constant')

for n in range(N):
for oc in range(OC):
for i in range(KH):
for j in range(KW):
for k in range(H_out):
for l in range(W_out):
for ic in range(IC):
dweights[oc, ic, i, j] += x_padded[
n, ic, i + k * stride[0], j + l * stride[
1]] * dout[n, oc, k, l]

dx_padded = np.pad(dx, ((0, ), (0, ), (padding, ), (padding, )), 'constant')

w_ = np.zeros_like(w)
for i in range(KH):
for j in range(KW):
w_[:, :, i, j] = w[:, :, KH - i - 1, KW - j - 1]

for n in range(N):
for oc in range(OC):
for i in range(H_out):
for j in range(W_out):
for kh in range(KH):
for kw in range(KW):
for ic in range(IC):
dx_padded[n, ic, stride[0] * i + kh, stride[1] *
j + kw] += dout[n, oc, i, j] * w[
oc, ic, kh, kw]

if padding == 0:
dx = dx_padded
else:
dx = dx_padded[:, :, padding:-padding, padding:-padding]

return dx.astype(np.float32), dweights.astype(np.float32)
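
The conv_backward helper above computes reference input and weight gradients for a plain (dilation-1) convolution; the test_check_grad_* methods above feed its results to check_grad_with_place as user-defined gradients. As a side note rather than part of this diff, a minimal sketch of cross-checking it with a finite-difference probe, assuming it runs inside this test module so that np, conv2d_forward_naive, and conv_backward are all in scope (shapes and indices are arbitrary):

# Finite-difference sanity check for conv_backward (illustrative only).
params = {'stride': [1, 1], 'pad': [0, 0], 'dilation': [1, 1]}
x = np.random.random([1, 2, 6, 6]).astype(np.float32)
w = np.random.random([3, 2, 3, 3]).astype(np.float32)

out, _, _, _, _ = conv2d_forward_naive(x, w, 1, params)
dout = np.ones_like(out)                # gradient of sum(out) w.r.t. out
dx, dw = conv_backward(dout, x, w, params)

eps = 1e-2
x_pert = x.copy()
x_pert[0, 1, 3, 3] += eps               # perturb a single input element
out_pert, _, _, _, _ = conv2d_forward_naive(x_pert, w, 1, params)
fd = (out_pert.sum() - out.sum()) / eps
# The output is linear in the input for fixed weights, so fd matches
# dx[0, 1, 3, 3] up to float32 round-off.
print(abs(fd - dx[0, 1, 3, 3]))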


class TestConv2DBF16WithPadding1(TestConv2DWithGradBF16Op):
def init_test_case(self):
TestConv2DWithGradBF16Op.init_test_case(self)
self.pad = [1, 1]


class TestConv2DBF16WithStride2(TestConv2DWithGradBF16Op):
def init_test_case(self):
TestConv2DWithGradBF16Op.init_test_case(self)
self.stride = [2, 3]


class TestConv2D(TestConv2DBF16Op):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
@@ -156,7 +282,7 @@ def init_group(self):
self.groups = 3


class TestWithStride(TestConv2DBf16Op):
class TestWithStride(TestConv2DBF16Op):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [2, 2]
@@ -170,7 +296,7 @@ def init_data_type(self):
self.input_type = np.uint16


class TestWithDilations(TestConv2DBf16Op):
class TestWithDilations(TestConv2DBF16Op):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
@@ -185,7 +311,7 @@ def init_data_type(self):
self.input_type = np.uint16


class TestWith1x1ForceFP32Output(TestConv2DBf16Op):
class TestWith1x1ForceFP32Output(TestConv2DBF16Op):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
@@ -201,7 +327,7 @@ def init_fuse_residual(self):
self.fuse_residual = False


class TestWithInput1x1Filter1x1(TestConv2DBf16Op):
class TestWithInput1x1Filter1x1(TestConv2DBF16Op):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]