
Commit

【Hackathon No.55】 add channel_shuffle FP16/BF16 support and tests (#51884)

* No55 add channel_shuffle FP16/BF16 support and tests
superwinner1 committed Apr 13, 2023
1 parent 205094f commit 48ccb78
Showing 4 changed files with 60 additions and 5 deletions.
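
For context, channel shuffle splits the channel dimension into groups, swaps the group axis with the per-group channel axis, and flattens the result back, so values only move between channels, never between spatial positions. The tests in this commit compare the GPU kernels against a NumPy oracle, channel_shuffle_np, which is defined earlier in test_channel_shuffle.py and is not part of this diff. A minimal sketch of such a reference for the NCHW layout (the name channel_shuffle_ref and the NCHW-only restriction are illustrative assumptions, not the helper's actual code):

import numpy as np

def channel_shuffle_ref(x, groups):
    # Illustrative NCHW-only reference; the real channel_shuffle_np in the
    # test file also supports NHWC via its data_format argument.
    n, c, h, w = x.shape
    x = x.reshape(n, groups, c // groups, h, w)  # split channels into groups
    x = x.transpose(0, 2, 1, 3, 4)               # swap group and channel axes
    return x.reshape(n, c, h, w)                 # flatten back to NCHW
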
1 change: 0 additions & 1 deletion paddle/phi/kernels/channel_shuffle_grad_kernel.h
@@ -17,7 +17,6 @@
 #include <string>

 #include "paddle/phi/core/dense_tensor.h"
-
 namespace phi {

 template <typename T, typename Context>
Expand Down
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/channel_shuffle_grad_kernel.cu
@@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(channel_shuffle_grad,
                    ALL_LAYOUT,
                    phi::ChannelShuffleGradKernel,
                    float,
-                   double) {}
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/channel_shuffle_kernel.cu
@@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(channel_shuffle,
                    ALL_LAYOUT,
                    phi::ChannelShuffleKernel,
                    float,
-                   double) {}
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
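
With both the forward and backward kernels now registered for phi::dtype::float16 and phi::dtype::bfloat16, half-precision tensors can flow through the Python API on CUDA devices. A hypothetical smoke test, not part of this commit, assuming a CUDA build of Paddle:

import paddle
import paddle.nn.functional as F

if paddle.is_compiled_with_cuda():
    paddle.set_device('gpu')
    # 9 channels split into 3 groups of 3, NCHW layout.
    x = paddle.randn([2, 9, 4, 4]).astype('float16')
    y = F.channel_shuffle(x, groups=3)
    print(y.dtype, y.shape)  # paddle.float16, [2, 9, 4, 4]
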
56 changes: 54 additions & 2 deletions python/paddle/fluid/tests/unittests/test_channel_shuffle.py
@@ -15,7 +15,7 @@
 import unittest

 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16

 import paddle
 import paddle.nn.functional as F
@@ -45,6 +45,7 @@ def channel_shuffle_np(x, groups, data_format="NCHW"):
 class TestChannelShuffleOp(OpTest):
     def setUp(self):
         self.op_type = "channel_shuffle"
+        self.init_dtype()
         self.init_data_format()
         n, c, h, w = 2, 9, 4, 4
         self.python_api = paddle.nn.functional.channel_shuffle
@@ -56,13 +57,16 @@ def setUp(self):

         groups = 3

-        x = np.random.random(shape).astype("float64")
+        x = np.random.random(shape).astype(self.dtype)
         npresult = channel_shuffle_np(x, groups, self.format)

         self.inputs = {'X': x}
         self.outputs = {'Out': npresult}
         self.attrs = {'groups': groups, "data_format": self.format}

+    def init_dtype(self):
+        self.dtype = 'float64'
+
     def init_data_format(self):
         self.format = "NCHW"

@@ -268,5 +272,53 @@ def error_data_format_layer():
         self.assertRaises(ValueError, error_data_format_layer)


+class TestChannelShuffleFP16OP(TestChannelShuffleOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestChannelShuffleBF16OP(OpTest):
+    def setUp(self):
+        self.op_type = "channel_shuffle"
+        self.init_data_format()
+        n, c, h, w = 2, 9, 4, 4
+        self.python_api = paddle.nn.functional.channel_shuffle
+        self.dtype = np.uint16
+        self.use_mkldnn = False
+
+        if self.format == "NCHW":
+            shape = [n, c, h, w]
+        if self.format == "NHWC":
+            shape = [n, h, w, c]
+
+        groups = 3
+
+        x = np.random.random(shape).astype('float32')
+        out = channel_shuffle_np(x, groups, self.format)
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.attrs = {'groups': groups, "data_format": self.format}
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+
+    def init_data_format(self):
+        self.format = "NCHW"
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(
+            place,
+            ['X'],
+            'Out',
+        )
+
+
 if __name__ == '__main__':
     unittest.main()

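A note on the BF16 test above: OpTest has no native bfloat16 dtype, so the test declares np.uint16 and feeds inputs and expected outputs through convert_float_to_uint16, i.e. as the raw 16-bit payloads of bfloat16 values. Conceptually that conversion keeps the upper half of the float32 bit pattern; a rough sketch under that assumption (the real helper in eager_op_test may round rather than truncate):

import numpy as np

def float32_to_bfloat16_bits(x):
    # bfloat16 shares float32's exponent, so its bit pattern is (roughly)
    # the high 16 bits of the IEEE-754 float32 representation.
    x = np.ascontiguousarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)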