Added Conv2D BF16 BWD oneDNN kernel (#38507)
* working test for padding only

* added full conv2d grad kernel

* removed some trash

* minor change

* CI fix

* format fix
jakpiase committed Dec 30, 2021
1 parent 04496d8 commit ed8ba01
Showing 2 changed files with 146 additions and 15 deletions.
5 changes: 5 additions & 0 deletions paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -1012,6 +1012,11 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN,
ops::kConvMKLDNNFP32,
ops::ConvMKLDNNGradOpKernel<float, float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace, BF16,
ops::kConvMKLDNNFP32,
ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16, float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d, MKLDNN,
::paddle::platform::CPUPlace, FP32,
ops::kConvMKLDNNFP32,
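Note (not part of the commit): the hunk above registers the oneDNN conv2d_grad kernel for bfloat16, the `<paddle::platform::bfloat16, float>` instantiation mirroring the existing `<float, float>` FP32 registration. The tests in the next file carry BF16 data as uint16 arrays produced by `convert_float_to_uint16`, i.e. conceptually the upper 16 bits of the float32 bit pattern. A minimal sketch of that conversion, assuming plain truncation (the real helper may round to nearest, so the last bit can differ); `float_to_bf16_bits` and `bf16_bits_to_float` are illustrative names, not Paddle APIs:

import numpy as np

def float_to_bf16_bits(x):
    # Assumed sketch: keep the high 16 bits of each float32 value.
    x = np.ascontiguousarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float(bits):
    # Inverse mapping: put the stored 16 bits back into the high half of a
    # float32 word, zero-filling the low mantissa bits.
    bits = np.ascontiguousarray(bits, dtype=np.uint16)
    return (bits.astype(np.uint32) << 16).view(np.float32)

a = np.array([1.0, 3.14159, -2.5], dtype=np.float32)
print(bf16_bits_to_float(float_to_bf16_bits(a)))  # close to a, ~3 significant decimal digits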
156 changes: 141 additions & 15 deletions python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_bf16_mkldnn_op.py
@@ -19,7 +19,7 @@
import struct

import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16, OpTestTool
from paddle.fluid.tests.unittests.test_conv2d_op import conv2d_forward_naive, TestConv2DOp


@@ -31,7 +31,7 @@ def conv2d_residual_naive(out, residual):

@unittest.skipIf(not core.supports_bfloat16(),
"place does not support BF16 evaluation")
class TestConv2DBf16Op(TestConv2DOp):
class TestConv2DBF16Op(TestConv2DOp):
def setUp(self):
self.op_type = "conv2d"
self.use_cudnn = False
@@ -51,15 +51,19 @@ def setUp(self):
self.init_data_type()
self.init_force_fp32_output()

conv2d_param = {
self.conv2d_param = {
'stride': self.stride,
'pad': self.pad,
'dilation': self.dilations
}

self.input = np.random.random(self.input_size).astype(np.float32)
self.filter = np.random.random(self.filter_size).astype(np.float32)
conv_out, _, _, _, _ = conv2d_forward_naive(self.input, self.filter,
self.groups, conv2d_param)

self.inputs_fp32 = {'Input': self.input, 'Filter': self.filter}

conv_out, _, _, _, _ = conv2d_forward_naive(
self.input, self.filter, self.groups, self.conv2d_param)
self.conv_output_float = conv_out

if self.fuse_residual:
@@ -71,12 +75,16 @@ def setUp(self):
self.outputs = {'Output': self.conv_output}
elif self.force_fp32_output:
self.outputs = {'Output': self.conv_output_float.astype(np.float32)}
else:
self.outputs = {
'Output': convert_float_to_uint16(self.conv_output_float)
}

if self.input_type is not np.float32:
self.input = convert_float_to_uint16(self.input)

self.inputs = {
'Input': self.input.view(self.input_type),
'Input': self.input,
'Filter': OpTest.np_dtype_to_fluid_dtype(
self.filter.astype(self.weight_type))
}
@@ -111,14 +119,18 @@ def test_check_grad_no_input(self):

def init_test_case(self):
TestConv2DOp.init_test_case(self)
self.input_size = [1, 1, 5, 5] # NCHW
self.input_size = [1, 6, 12, 12] # NCHW
f_c = self.input_size[1] // self.groups
self.input_residual_size = [1, 2, 3, 3]
self.filter_size = [2, f_c, 3, 3]
o_c = 15
self.input_residual_size = [1, o_c, 10, 10]
self.filter_size = [o_c, f_c, 3, 3]

def init_padding(self):
pass

def init_data_type(self):
self.weight_type = np.float32
self.input_type = np.float32
self.input_type = np.uint16

def init_force_fp32_output(self):
self.force_fp32_output = False
@@ -130,7 +142,121 @@ def init_fuse_residual(self):
self.fuse_residual = True


class TestConv2D(TestConv2DBf16Op):
@OpTestTool.skip_if_not_cpu_bf16()
class TestConv2DWithGradBF16Op(TestConv2DBF16Op):
def init_fuse_relu(self):
self.fuse_activation = None

def init_fuse_residual(self):
self.fuse_residual = None

def test_check_grad(self):
dout = self.conv_output_float
x = self.inputs_fp32['Input']
w = self.inputs_fp32['Filter']

dx, dweights = conv_backward(dout, x, w, self.conv2d_param)

self.check_grad_with_place(
core.CPUPlace(), ["Input", "Filter"],
"Output",
user_defined_grads=[dx, dweights],
user_defined_grad_outputs=[convert_float_to_uint16(dout)])

def test_check_grad_no_filter(self):
dout = self.conv_output_float
x = self.inputs_fp32['Input']
w = self.inputs_fp32['Filter']

dx, _ = conv_backward(dout, x, w, self.conv2d_param)

self.check_grad_with_place(
core.CPUPlace(), ["Input"],
"Output",
set(['Filter']),
user_defined_grads=[dx],
user_defined_grad_outputs=[convert_float_to_uint16(dout)])

def test_check_grad_no_input(self):
dout = self.conv_output_float
x = self.inputs_fp32['Input']
w = self.inputs_fp32['Filter']

_, dweights = conv_backward(dout, x, w, self.conv2d_param)

self.check_grad_with_place(
core.CPUPlace(), ["Filter"],
"Output",
set(['Input']),
user_defined_grads=[dweights],
user_defined_grad_outputs=[convert_float_to_uint16(dout)])


def conv_backward(dout, x, w, params):
padding = params['pad'][0]
stride = params['stride']

dx = np.zeros_like(x)
dweights = np.zeros_like(w)

N, IC, H, W = x.shape
OC, _, KH, KW = w.shape

H_out = int(1 + (H + 2 * padding - KH) / stride[0])
W_out = int(1 + (W + 2 * padding - KW) / stride[1])

x_padded = np.pad(x, ((0, ), (0, ), (padding, ), (padding, )), 'constant')

for n in range(N):
for oc in range(OC):
for i in range(KH):
for j in range(KW):
for k in range(H_out):
for l in range(W_out):
for ic in range(IC):
dweights[oc, ic, i, j] += x_padded[
n, ic, i + k * stride[0], j + l * stride[
1]] * dout[n, oc, k, l]

dx_padded = np.pad(dx, ((0, ), (0, ), (padding, ), (padding, )), 'constant')

w_ = np.zeros_like(w)
for i in range(KH):
for j in range(KW):
w_[:, :, i, j] = w[:, :, KH - i - 1, KW - j - 1]

for n in range(N):
for oc in range(OC):
for i in range(H_out):
for j in range(W_out):
for kh in range(KH):
for kw in range(KW):
for ic in range(IC):
dx_padded[n, ic, stride[0] * i + kh, stride[1] *
j + kw] += dout[n, oc, i, j] * w[
oc, ic, kh, kw]

if padding == 0:
dx = dx_padded
else:
dx = dx_padded[:, :, padding:-padding, padding:-padding]

return dx.astype(np.float32), dweights.astype(np.float32)


class TestConv2DBF16WithPadding1(TestConv2DWithGradBF16Op):
def init_test_case(self):
TestConv2DWithGradBF16Op.init_test_case(self)
self.pad = [1, 1]


class TestConv2DBF16WithStride2(TestConv2DWithGradBF16Op):
def init_test_case(self):
TestConv2DWithGradBF16Op.init_test_case(self)
self.stride = [2, 3]


class TestConv2D(TestConv2DBF16Op):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
@@ -156,7 +282,7 @@ def init_group(self):
self.groups = 3


class TestWithStride(TestConv2DBf16Op):
class TestWithStride(TestConv2DBF16Op):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [2, 2]
@@ -170,7 +296,7 @@ def init_data_type(self):
self.input_type = np.uint16


class TestWithDilations(TestConv2DBf16Op):
class TestWithDilations(TestConv2DBF16Op):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
@@ -185,7 +311,7 @@ def init_data_type(self):
self.input_type = np.uint16


class TestWith1x1ForceFP32Output(TestConv2DBf16Op):
class TestWith1x1ForceFP32Output(TestConv2DBF16Op):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
@@ -201,7 +327,7 @@ def init_fuse_residual(self):
self.fuse_residual = False


class TestWithInput1x1Filter1x1(TestConv2DBf16Op):
class TestWithInput1x1Filter1x1(TestConv2DBF16Op):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
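Note (not part of the commit): `conv_backward` above is the naive NumPy reference used to supply `user_defined_grads` for the BF16 gradient checks. A minimal standalone sketch of calling it, assuming the base-class defaults of stride [1, 1], zero padding, and groups = 1 together with the shapes set in `init_test_case`:

import numpy as np

# Illustrative shapes matching the test defaults: NCHW input, OIHW filter.
x = np.random.random((1, 6, 12, 12)).astype(np.float32)
w = np.random.random((15, 6, 3, 3)).astype(np.float32)
params = {'stride': [1, 1], 'pad': [0, 0], 'dilation': [1, 1]}

# With no padding and stride 1: H_out = 12 - 3 + 1 = 10, W_out = 10,
# so the upstream gradient has shape (N, OC, H_out, W_out).
dout = np.random.random((1, 15, 10, 10)).astype(np.float32)

dx, dweights = conv_backward(dout, x, w, params)
assert dx.shape == x.shape            # (1, 6, 12, 12)
assert dweights.shape == w.shape      # (15, 6, 3, 3)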
