[cherry-pick to 2.1] [Modify spectralnorm #32633] #32667

Merged: 1 commit, Apr 29, 2021
139 changes: 139 additions & 0 deletions python/paddle/fluid/tests/unittests/test_dygraph_spectral_norm.py
@@ -0,0 +1,139 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import collections
import paddle
import paddle.nn as nn
from paddle.nn.utils import spectral_norm


class TestDygraphSpectralNorm(unittest.TestCase):
def setUp(self):
self.init_test_case()
self.set_data()

def init_test_case(self):
self.batch_size = 3
self.data_desc = (['x', [2, 12, 12]], )
self.n_power_iterations = 1
self.eps = 1e-12
self.dim = None

def set_data(self):
self.data = collections.OrderedDict()
for desc in self.data_desc:
data_name = desc[0]
data_shape = desc[1]
data_value = np.random.random(
size=[self.batch_size] + data_shape).astype('float32')
self.data[data_name] = data_value

def spectral_normalize(self, weight, u, v, dim, power_iters, eps):
shape = weight.shape
weight_mat = weight.copy()
h = shape[dim]
w = np.prod(shape) // h
if dim != 0:
perm = [dim] + [d for d in range(len(shape)) if d != dim]
weight_mat = weight_mat.transpose(perm)
weight_mat = weight_mat.reshape((h, w))

u = u.reshape((h, 1))
v = v.reshape((w, 1))
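        # power iteration: alternately update v and u to approximate the top
        # right/left singular vectors of the reshaped weight matrix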
for i in range(power_iters):
v = np.matmul(weight_mat.T, u)
v_norm = np.sqrt((v * v).sum())
v = v / (v_norm + eps)
u = np.matmul(weight_mat, v)
u_norm = np.sqrt((u * u).sum())
u = u / (u_norm + eps)
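        # sigma = u^T W v estimates the largest singular value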
sigma = (u * np.matmul(weight_mat, v)).sum()
return weight / sigma

def test_check_output(self):
linear = paddle.nn.Conv2D(2, 1, 3)
before_weight = linear.weight.numpy().copy()
        if self.dim is None:
if isinstance(linear, (nn.Conv1DTranspose, nn.Conv2DTranspose,
nn.Conv3DTranspose, nn.Linear)):
self.dim = 1
else:
self.dim = 0
else:
            self.dim = (self.dim + before_weight.ndim) % before_weight.ndim

sn = spectral_norm(
linear,
n_power_iterations=self.n_power_iterations,
eps=self.eps,
dim=self.dim)
u = sn.weight_u.numpy().copy()
v = sn.weight_v.numpy().copy()
outputs = []
for name, data in self.data.items():
output = linear(paddle.to_tensor(data))
outputs.append(output.numpy())
self.actual_outputs = linear.weight.numpy()

expect_output = self.spectral_normalize(
before_weight, u, v, self.dim, self.n_power_iterations, self.eps)

for expect, actual in zip(expect_output, self.actual_outputs):
self.assertTrue(
np.allclose(
np.array(actual), np.array(expect), atol=0.001))


class TestDygraphSpectralNormCase(TestDygraphSpectralNorm):
def init_test_case(self):
self.batch_size = 2
self.data_desc = (['x', [2, 3, 3]], )
self.n_power_iterations = 1
self.eps = 1e-12
self.dim = None


class TestDygraphSpectralNormWithIterations(TestDygraphSpectralNorm):
def init_test_case(self):
self.batch_size = 2
self.data_desc = (['x', [2, 3, 3]], )
self.n_power_iterations = 2
self.eps = 1e-12
self.dim = None


class TestDygraphSpectralNormWithDim(TestDygraphSpectralNorm):
def init_test_case(self):
self.batch_size = 2
self.data_desc = (['x', [2, 3, 3]], )
self.n_power_iterations = 1
self.eps = 1e-12
self.dim = 1


class TestDygraphSpectralNormWithEps(TestDygraphSpectralNorm):
def init_test_case(self):
self.batch_size = 2
self.data_desc = (['x', [2, 3, 3]], )
self.n_power_iterations = 1
self.eps = 1e-10
self.dim = None


if __name__ == '__main__':
unittest.main()
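As a sanity check on the NumPy reference above, here is a minimal standalone sketch (helper names hypothetical, not part of this PR) showing that the reference update rule drives the largest singular value of the normalized weight toward 1 as the number of power iterations grows:

import numpy as np

def reference_normalize(weight, u, v, dim, power_iters, eps=1e-12):
    # same update rule as TestDygraphSpectralNorm.spectral_normalize above
    shape = weight.shape
    perm = [dim] + [d for d in range(len(shape)) if d != dim]
    mat = weight.transpose(perm).reshape((shape[dim], -1))
    u = u.reshape((-1, 1))
    v = v.reshape((-1, 1))
    for _ in range(power_iters):
        v = mat.T @ u
        v = v / (np.sqrt((v * v).sum()) + eps)
        u = mat @ v
        u = u / (np.sqrt((u * u).sum()) + eps)
    sigma = (u * (mat @ v)).sum()
    return weight / sigma

w = np.random.random((2, 2, 3, 3)).astype('float32')
w_sn = reference_normalize(w, np.random.random(2), np.random.random(18),
                           dim=0, power_iters=50)
# top singular value of the normalized weight matrix should be ~1
print(np.linalg.svd(w_sn.reshape(2, -1), compute_uv=False)[0])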
2 changes: 2 additions & 0 deletions python/paddle/nn/__init__.py
@@ -126,6 +126,8 @@
from .layer.vision import PixelShuffle # noqa: F401
from .layer.container import LayerDict # noqa: F401

from .utils.spectral_norm_hook import spectral_norm

# TODO: remove loss, keep it for too many used in unitests
from .layer import loss # noqa: F401
from ..fluid.dygraph.layers import Layer # noqa: F401
3 changes: 2 additions & 1 deletion python/paddle/nn/utils/__init__.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .spectral_norm_hook import spectral_norm
from .weight_norm_hook import weight_norm, remove_weight_norm # noqa: F401

__all__ = [ #noqa
-    'weight_norm', 'remove_weight_norm'
+    'weight_norm', 'remove_weight_norm', 'spectral_norm'
]
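With both re-exports in place, the hook resolves to the same function from either namespace; a quick hypothetical check:

import paddle.nn as nn
from paddle.nn.utils import spectral_norm

assert nn.spectral_norm is spectral_norm  # one function object, two import paths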
210 changes: 210 additions & 0 deletions python/paddle/nn/utils/spectral_norm_hook.py
@@ -0,0 +1,210 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import numpy as np

import paddle
from ..layer.conv import Conv1DTranspose, Conv2DTranspose, Conv3DTranspose
from ..layer.common import Linear
from .. import functional as F

__all__ = []


def normal_(x, mean=0., std=1.):
temp_value = paddle.normal(mean, std, shape=x.shape)
x.set_value(temp_value)
return x


class SpectralNorm(object):
def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
self.name = name
self.dim = dim
if n_power_iterations <= 0:
raise ValueError('Expected n_power_iterations to be positive, but '
'got n_power_iterations={}'.format(
n_power_iterations))
self.n_power_iterations = n_power_iterations
self.eps = eps

def reshape_weight_to_matrix(self, weight):
weight_mat = weight
if self.dim != 0:
# transpose dim to front
weight_mat = weight_mat.transpose([self.dim] + [
d for d in range(weight_mat.dim()) if d != self.dim
])

height = weight_mat.shape[0]

return weight_mat.reshape([height, -1])

def compute_weight(self, layer, do_power_iteration):
weight = getattr(layer, self.name + '_orig')
u = getattr(layer, self.name + '_u')
v = getattr(layer, self.name + '_v')
weight_mat = self.reshape_weight_to_matrix(weight)

if do_power_iteration:
with paddle.no_grad():
for _ in range(self.n_power_iterations):
v.set_value(
F.normalize(
paddle.matmul(
weight_mat,
u,
transpose_x=True,
transpose_y=False),
axis=0,
epsilon=self.eps, ))

u.set_value(
F.normalize(
paddle.matmul(weight_mat, v),
axis=0,
epsilon=self.eps, ))
if self.n_power_iterations > 0:
u = u.clone()
v = v.clone()

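        # sigma = u^T (W v) estimates the largest singular value of weight_mat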
sigma = paddle.dot(u, paddle.mv(weight_mat, v))
weight = weight / sigma
return weight

def __call__(self, layer, inputs):
setattr(
layer,
self.name,
self.compute_weight(
layer, do_power_iteration=layer.training))

@staticmethod
def apply(layer, name, n_power_iterations, dim, eps):
for k, hook in layer._forward_pre_hooks.items():
if isinstance(hook, SpectralNorm) and hook.name == name:
raise RuntimeError("Cannot register two spectral_norm hooks on "
"the same parameter {}".format(name))

fn = SpectralNorm(name, n_power_iterations, dim, eps)
weight = layer._parameters[name]

with paddle.no_grad():
weight_mat = fn.reshape_weight_to_matrix(weight)
h, w = weight_mat.shape

# randomly initialize u and v
u = layer.create_parameter([h])
u = normal_(u, 0., 1.)
v = layer.create_parameter([w])
v = normal_(v, 0., 1.)
u = F.normalize(u, axis=0, epsilon=fn.eps)
v = F.normalize(v, axis=0, epsilon=fn.eps)

        # delete fn.name from parameters; otherwise we cannot set the attribute
del layer._parameters[fn.name]
layer.add_parameter(fn.name + "_orig", weight)
# still need to assign weight back as fn.name because all sorts of
# things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign it, as it could be a Parameter and
        # would get added as a parameter. Instead, we register weight * 1.0 as a plain
# attribute.
setattr(layer, fn.name, weight * 1.0)
layer.register_buffer(fn.name + "_u", u)
layer.register_buffer(fn.name + "_v", v)
layer.register_forward_pre_hook(fn)
return fn


def spectral_norm(layer,
name='weight',
n_power_iterations=1,
eps=1e-12,
dim=None):
r"""
    This spectral_norm layer applies spectral normalization to a parameter according to the
    following calculation:

    Step 1:
        Generate a vector U of shape [H] and a vector V of shape [W], where H
        is the size of the :attr:`dim`-th dimension of the input weight and W
        is the product of the remaining dimensions.

    Step 2:
        :attr:`n_power_iterations` is a positive integer; perform the following
        updates of U and V for :attr:`n_power_iterations` rounds.

.. math::

        \mathbf{v} := \frac{\mathbf{W}^{T} \mathbf{u}}{\|\mathbf{W}^{T} \mathbf{u}\|_2}

        \mathbf{u} := \frac{\mathbf{W} \mathbf{v}}{\|\mathbf{W} \mathbf{v}\|_2}

Step 3:
Calculate :math:`\sigma(\mathbf{W})` and normalize weight values.

.. math::

\sigma(\mathbf{W}) = \mathbf{u}^{T} \mathbf{W} \mathbf{v}

        \mathbf{W} = \frac{\mathbf{W}}{\sigma(\mathbf{W})}


Refer to `Spectral Normalization <https://arxiv.org/abs/1802.05957>`_ .

Parameters:
        layer(Layer): The Paddle layer that holds the weight parameter to be normalized.
name(str, optional): Name of the weight parameter. Default: 'weight'.
n_power_iterations(int, optional): The number of power iterations to calculate spectral norm. Default: 1.
eps(float, optional): The epsilon for numerical stability in calculating norms. Default: 1e-12.
        dim(int, optional): The index of the dimension that is permuted to the first axis before reshaping the weight into a matrix. If None, it defaults to 1 for Linear and transposed convolution layers, whose output channels sit on axis 1, and to 0 for all other layers. Default: None.

Returns:
        The original layer with the spectral norm hook registered.

Examples:
.. code-block:: python

from paddle.nn import Conv2D
            from paddle.nn.utils import spectral_norm

conv = Conv2D(3, 1, 3)
sn_conv = spectral_norm(conv)
print(sn_conv)
# Conv2D(3, 1, kernel_size=[3, 3], data_format=NCHW)
print(sn_conv.weight)
# Tensor(shape=[1, 3, 3, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
# [[[[-0.21090528, 0.18563725, -0.14127982],
# [-0.02310637, 0.03197737, 0.34353802],
# [-0.17117859, 0.33152047, -0.28408015]],
#
# [[-0.13336606, -0.01862637, 0.06959272],
# [-0.02236020, -0.27091628, -0.24532901],
# [ 0.27254242, 0.15516677, 0.09036587]],
#
# [[ 0.30169338, -0.28146112, -0.11768346],
# [-0.45765871, -0.12504843, -0.17482486],
# [-0.36866254, -0.19969313, 0.08783543]]]])

"""

if dim is None:
if isinstance(layer, (Conv1DTranspose, Conv2DTranspose, Conv3DTranspose,
Linear)):
dim = 1
else:
dim = 0
SpectralNorm.apply(layer, name, n_power_iterations, dim, eps)
return layer
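For reviewers who want to exercise the new public API end to end, a minimal usage sketch (assuming a Paddle 2.1 build with this patch applied; exact values vary with the seed):

import numpy as np
import paddle
from paddle.nn.utils import spectral_norm

paddle.seed(0)
linear = paddle.nn.Linear(4, 6)
linear = spectral_norm(linear)  # dim defaults to 1 for Linear

# each forward pass in training mode runs one more power iteration,
# refining the buffers weight_u and weight_v
for _ in range(20):
    _ = linear(paddle.randn([3, 4]))

# once the estimate converges, the largest singular value of the
# normalized weight should be close to 1
w = linear.weight.numpy()                   # shape [4, 6]
mat = w.transpose([1, 0]).reshape((6, -1))  # dim=1 axis moved to front
print(np.linalg.svd(mat, compute_uv=False)[0])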