Skip to content

Commit

Permalink
Fix test and doc (#44735)
Browse files Browse the repository at this point in the history
* fix test and doc
  • Loading branch information
zhangkaihuo committed Aug 1, 2022
1 parent cd94be6 commit 3e8708b
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 66 deletions.
131 changes: 67 additions & 64 deletions python/paddle/fluid/tests/unittests/test_sparse_norm_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,87 +18,90 @@
import paddle
from paddle.incubate.sparse import nn
import paddle.fluid as fluid
from paddle.fluid.framework import _test_eager_guard
import copy


class TestSparseBatchNorm(unittest.TestCase):

def test(self):
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
with _test_eager_guard():
paddle.seed(0)
channels = 4
shape = [2, 3, 6, 6, channels]
#there is no zero in dense_x
dense_x = paddle.randn(shape)
dense_x.stop_gradient = False

batch_norm = paddle.nn.BatchNorm3D(channels, data_format="NDHWC")
dense_y = batch_norm(dense_x)
dense_y.backward(dense_y)

sparse_dim = 4
dense_x2 = copy.deepcopy(dense_x)
dense_x2.stop_gradient = False
sparse_x = dense_x2.to_sparse_coo(sparse_dim)
sparse_batch_norm = paddle.incubate.sparse.nn.BatchNorm(channels)
# set same params
sparse_batch_norm._mean.set_value(batch_norm._mean)
sparse_batch_norm._variance.set_value(batch_norm._variance)
sparse_batch_norm.weight.set_value(batch_norm.weight)

sparse_y = sparse_batch_norm(sparse_x)
# compare the result with dense batch_norm
assert np.allclose(dense_y.flatten().numpy(),
sparse_y.values().flatten().numpy(),
atol=1e-5,
rtol=1e-5)

# test backward
sparse_y.backward(sparse_y)
assert np.allclose(dense_x.grad.flatten().numpy(),
sparse_x.grad.values().flatten().numpy(),
atol=1e-5,
rtol=1e-5)
paddle.seed(0)
channels = 4
shape = [2, 3, 6, 6, channels]
#there is no zero in dense_x
dense_x = paddle.randn(shape)
dense_x.stop_gradient = False

batch_norm = paddle.nn.BatchNorm3D(channels, data_format="NDHWC")
dense_y = batch_norm(dense_x)
dense_y.backward(dense_y)

sparse_dim = 4
dense_x2 = copy.deepcopy(dense_x)
dense_x2.stop_gradient = False
sparse_x = dense_x2.to_sparse_coo(sparse_dim)
sparse_batch_norm = paddle.incubate.sparse.nn.BatchNorm(channels)
# set same params
sparse_batch_norm._mean.set_value(batch_norm._mean)
sparse_batch_norm._variance.set_value(batch_norm._variance)
sparse_batch_norm.weight.set_value(batch_norm.weight)

sparse_y = sparse_batch_norm(sparse_x)
# compare the result with dense batch_norm
assert np.allclose(dense_y.flatten().numpy(),
sparse_y.values().flatten().numpy(),
atol=1e-5,
rtol=1e-5)

# test backward
sparse_y.backward(sparse_y)
assert np.allclose(dense_x.grad.flatten().numpy(),
sparse_x.grad.values().flatten().numpy(),
atol=1e-5,
rtol=1e-5)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})

def test_error_layout(self):
with _test_eager_guard():
with self.assertRaises(ValueError):
shape = [2, 3, 6, 6, 3]
x = paddle.randn(shape)
sparse_x = x.to_sparse_coo(4)
sparse_batch_norm = paddle.incubate.sparse.nn.BatchNorm(
3, data_format='NCDHW')
sparse_batch_norm(sparse_x)
with self.assertRaises(ValueError):
shape = [2, 3, 6, 6, 3]
x = paddle.randn(shape)
sparse_x = x.to_sparse_coo(4)
sparse_batch_norm = paddle.incubate.sparse.nn.BatchNorm(
3, data_format='NCDHW')
sparse_batch_norm(sparse_x)

def test2(self):
with _test_eager_guard():
paddle.seed(123)
channels = 3
x_data = paddle.randn((1, 6, 6, 6, channels)).astype('float32')
dense_x = paddle.to_tensor(x_data)
sparse_x = dense_x.to_sparse_coo(4)
batch_norm = paddle.incubate.sparse.nn.BatchNorm(channels)
batch_norm_out = batch_norm(sparse_x)
print(batch_norm_out.shape)
# [1, 6, 6, 6, 3]
paddle.seed(123)
channels = 3
x_data = paddle.randn((1, 6, 6, 6, channels)).astype('float32')
dense_x = paddle.to_tensor(x_data)
sparse_x = dense_x.to_sparse_coo(4)
batch_norm = paddle.incubate.sparse.nn.BatchNorm(channels)
batch_norm_out = batch_norm(sparse_x)
dense_bn = paddle.nn.BatchNorm1D(channels)
dense_x = dense_x.reshape((-1, dense_x.shape[-1]))
dense_out = dense_bn(dense_x)
assert np.allclose(dense_out.numpy(), batch_norm_out.values().numpy())
# [1, 6, 6, 6, 3]


class TestSyncBatchNorm(unittest.TestCase):

def test_sync_batch_norm(self):
with _test_eager_guard():
x = np.array([[[[0.3, 0.4], [0.3, 0.07]],
[[0.83, 0.37], [0.18, 0.93]]]]).astype('float32')
x = paddle.to_tensor(x)
x = x.to_sparse_coo(len(x.shape) - 1)

if paddle.is_compiled_with_cuda():
sync_batch_norm = nn.SyncBatchNorm(2)
hidden1 = sync_batch_norm(x)
print(hidden1)
x = np.array([[[[0.3, 0.4], [0.3, 0.07]],
[[0.83, 0.37], [0.18, 0.93]]]]).astype('float32')
x = paddle.to_tensor(x)
sparse_x = x.to_sparse_coo(len(x.shape) - 1)

if paddle.is_compiled_with_cuda():
sparse_sync_bn = nn.SyncBatchNorm(2)
sparse_hidden = sparse_sync_bn(sparse_x)

dense_sync_bn = paddle.nn.SyncBatchNorm(2)
x = x.reshape((-1, x.shape[-1]))
dense_hidden = dense_sync_bn(x)
assert np.allclose(sparse_hidden.values().numpy(),
dense_hidden.numpy())

def test_convert(self):
base_model = paddle.nn.Sequential(nn.Conv3D(3, 5, 3), nn.BatchNorm(5),
Expand Down
6 changes: 4 additions & 2 deletions python/paddle/incubate/sparse/nn/layer/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ class SyncBatchNorm(paddle.nn.SyncBatchNorm):
Shapes:
input: Tensor that the dimension from 2 to 5.
output: Tensor with the same shape as input.
Examples:
Expand Down Expand Up @@ -278,7 +279,7 @@ def forward(self, x):

@classmethod
def convert_sync_batchnorm(cls, layer):
"""
r"""
Helper function to convert :class: `paddle.incubate.sparse.nn.BatchNorm` layers in the model to :class: `paddle.incubate.sparse.nn.SyncBatchNorm` layers.
Parameters:
Expand All @@ -290,13 +291,14 @@ def convert_sync_batchnorm(cls, layer):
Examples:
.. code-block:: python
import paddle
import paddle.incubate.sparse.nn as nn
model = paddle.nn.Sequential(nn.Conv3D(3, 5, 3), nn.BatchNorm(5))
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
"""

layer_output = layer
if isinstance(layer, _BatchNormBase):
if layer._weight_attr != None and not isinstance(
Expand Down

0 comments on commit 3e8708b

Please sign in to comment.