[Relay/TOPI][Op] Add variance and layer norm op (apache#3700)
* Add LayerNorm op

* update

* fix

* Add mean_std and mean_variance

* add std and update doc

* add license

* x

* lint

* x

* fix

* fix doc
icemelon authored and wweic committed Sep 6, 2019
1 parent d3f965b commit cdbfb33
Showing 15 changed files with 583 additions and 35 deletions.
8 changes: 8 additions & 0 deletions docs/langref/relay_op.rst
@@ -150,6 +150,10 @@ This level enables additional math and transform operators.
tvm.relay.max
tvm.relay.min
tvm.relay.mean
tvm.relay.variance
tvm.relay.std
tvm.relay.mean_variance
tvm.relay.mean_std
tvm.relay.prod
tvm.relay.strided_slice
tvm.relay.broadcast_to
@@ -297,6 +301,10 @@ Level 4 Definitions
.. autofunction:: tvm.relay.max
.. autofunction:: tvm.relay.min
.. autofunction:: tvm.relay.mean
.. autofunction:: tvm.relay.variance
.. autofunction:: tvm.relay.std
.. autofunction:: tvm.relay.mean_variance
.. autofunction:: tvm.relay.mean_std
.. autofunction:: tvm.relay.prod
.. autofunction:: tvm.relay.strided_slice
.. autofunction:: tvm.relay.broadcast_to
21 changes: 21 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -470,6 +470,27 @@ struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
}; // struct BatchNormAttrs


/*! \brief Attributes used in layer_norm operator */
struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> {
int axis;
double epsilon;
bool center;
bool scale;

TVM_DECLARE_ATTRS(LayerNormAttrs, "relay.attrs.LayerNormAttrs") {
TVM_ATTR_FIELD(axis).set_default(-1)
.describe("Specify which shape axis denotes the channel.");
TVM_ATTR_FIELD(epsilon).set_default(1e-5)
.describe("Small float added to variance to avoid dividing by zero");
TVM_ATTR_FIELD(center).set_default(true)
.describe("If true, add offset of beta to normalized tensor; "
"otherwise, beta is ignored.");
TVM_ATTR_FIELD(scale).set_default(true)
.describe("If true, multiply by gamma; otherwise, gamma is ignored.");
}
}; // struct LayerNormAttrs
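
These four attributes fully determine the computation. As a point of reference, here is a minimal NumPy sketch of the semantics they describe (the helper name is illustrative, not part of this patch, and it assumes the default axis=-1 so gamma/beta broadcast along the last axis):

```python
import numpy as np

def layer_norm_ref(data, gamma, beta, epsilon=1e-5, center=True, scale=True):
    # Normalize along the last axis using per-slice mean and variance.
    mean = data.mean(axis=-1, keepdims=True)
    var = data.var(axis=-1, keepdims=True)
    out = (data - mean) / np.sqrt(var + epsilon)
    if scale:    # `scale` attr: multiply by gamma
        out = out * gamma
    if center:   # `center` attr: add offset of beta
        out = out + beta
    return out
```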


/*! \brief Attributes for LRN operator */
struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> {
int size;
64 changes: 64 additions & 0 deletions include/tvm/relay/attrs/reduce.h
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file tvm/relay/attrs/reduce.h
* \brief Auxiliary attributes for reduce operators.
*/
#ifndef TVM_RELAY_ATTRS_REDUCE_H_
#define TVM_RELAY_ATTRS_REDUCE_H_

#include <tvm/attrs.h>
#include <string>

namespace tvm {
namespace relay {

/*! \brief Attributes for Reduce operators */
struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
Array<Integer> axis;
bool keepdims;
bool exclude;

TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") {
TVM_ATTR_FIELD(axis).set_default(NullValue<Array<Integer>>())
.describe(R"code(The axis or axes along which to perform the reduction.
The default, `axis=()`, will compute over all elements into a
scalar array with shape `(1,)`.
If `axis` is int, a reduction is performed on a particular axis.
If `axis` is a tuple of ints, a reduction is performed on all the axes
specified in the tuple.
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.)code");

TVM_ATTR_FIELD(keepdims).set_default(false)
.describe("If this is set to `True`, the reduced axes are left "
"in the result as dimension with size one.");
TVM_ATTR_FIELD(exclude).set_default(false)
.describe("Whether to perform reduction on axis that are NOT in axis instead.");
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_REDUCE_H_
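
The `exclude` flag inverts the axis set before the reduction runs. A small NumPy sketch of that behavior (the helper name is illustrative):

```python
import numpy as np

def resolve_axes(ndim, axis, exclude=False):
    """Return the axes a reduction actually runs over."""
    axes = {a % ndim for a in axis}
    if exclude:
        axes = set(range(ndim)) - axes
    return tuple(sorted(axes))

x = np.ones((2, 3, 4), dtype="float32")
# axis=(1,) with exclude=True reduces over axes 0 and 2, leaving shape (3,).
print(x.sum(axis=resolve_axes(x.ndim, (1,), exclude=True)).shape)  # (3,)
```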
14 changes: 13 additions & 1 deletion python/tvm/relay/frontend/mxnet.py
@@ -224,10 +224,21 @@ def _mx_batch_norm(inputs, attrs):
new_attrs["axis"] = attrs.get_int("axis", 1)
new_attrs["epsilon"] = attrs.get_float("eps", 0.001)
new_attrs["center"] = True
new_attrs["scale"] = not attrs.get_bool("fix_gamma", False)
new_attrs["scale"] = not attrs.get_bool("fix_gamma", True)
return _op.nn.batch_norm(*inputs, **new_attrs)


def _mx_layer_norm(inputs, attrs):
assert len(inputs) == 3
if attrs.get_bool("output_mean_var", False):
raise tvm.error.OpAttributeUnimplemented(
'Attribute "output_mean_var" is not supported for operator Layer Norm.')
new_attrs = {}
new_attrs["axis"] = attrs.get_int("axis", -1)
new_attrs["epsilon"] = attrs.get_float("eps", 1e-5)
return _op.nn.layer_norm(*inputs, **new_attrs)
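
For context, a hedged sketch of how this converter gets exercised through the MXNet frontend (assumes an MXNet install; the symbol graph and shape are illustrative):

```python
import mxnet as mx
from tvm import relay

data = mx.sym.var("data")
net = mx.sym.LayerNorm(data, axis=-1, eps=1e-5, name="ln")
# gamma/beta become free relay variables; shapes are inferred from `shape`.
mod, params = relay.frontend.from_mxnet(net, shape={"data": (2, 8)})
```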


def _mx_slice(inputs, attrs):
new_attrs = {}
begin = attrs.get_int_tuple('begin', None)
@@ -997,6 +1008,7 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
"Dropout" : _mx_dropout,
"BatchNorm" : _mx_batch_norm,
"BatchNorm_v1" : _mx_batch_norm,
"LayerNorm" : _mx_layer_norm,
"LRN" : _mx_lrn,
"L2Normalization" : _mx_l2_normalize,
"slice" : _mx_slice,
1 change: 1 addition & 0 deletions python/tvm/relay/op/_reduce.py
@@ -35,3 +35,4 @@ def _schedule_reduce(_, outs, target):
_reg.register_schedule("min", _schedule_reduce)
_reg.register_schedule("prod", _schedule_reduce)
_reg.register_schedule("mean", _schedule_reduce)
_reg.register_schedule("variance", _schedule_reduce)
60 changes: 59 additions & 1 deletion python/tvm/relay/op/nn/nn.py
@@ -867,7 +867,7 @@ def batch_norm(data,
Specify along which shape axis the channel is specified.
epsilon : double, optional, default=1e-5
-        Small float added to variance to avoid diving by zero.
+        Small float added to variance to avoid dividing by zero.
center : boolean, optional, default=True
If True, add offset of beta to normalized tensor, If False,
@@ -897,6 +897,64 @@
return TupleWrapper(result, 3)


def layer_norm(data,
gamma,
beta,
axis=-1,
epsilon=1e-5,
center=True,
scale=True):
r"""
Layer normalization (Lei Ba and et al., 2016).
Applies layer normalization to the n-dimensional input array.
This operator takes an n-dimensional input array and normalizes
the input using the given axis:
.. math::
out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis)+\epsilon}}
* gamma + beta
Unlike batch normalization, the mean and var are computed along the channel dimension.
Assume the input has size k on axis 1, then both gamma and beta have shape (k,).
.. note::
This operator can be optimized away for inference.
Parameters
----------
data : tvm.relay.Expr
Input to which batch_norm will be applied.
gamma : tvm.relay.Expr
The gamma scale factor.
beta : tvm.relay.Expr
The beta offset factor.
axis : int, optional, default=-1
The axis that should be normalized, typically the axis of the channels.
epsilon : double, optional, default=1e-5
Small float added to variance to avoid dividing by zero.
center : boolean, optional, default=True
If True, add offset of beta to normalized tensor, If False,
beta is ignored.
scale : boolean, optional, default=True
If True, multiply by gamma. If False, gamma is not used.
Returns
-------
result : tvm.relay.Expr
The normalized data.
"""
return _make.layer_norm(data, gamma, beta, axis, epsilon, center, scale)
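
A short usage sketch (not part of the diff): building a layer_norm expression over illustrative shapes and printing the resulting function.

```python
from tvm import relay

x = relay.var("x", shape=(2, 4), dtype="float32")
gamma = relay.var("gamma", shape=(4,), dtype="float32")
beta = relay.var("beta", shape=(4,), dtype="float32")
out = relay.nn.layer_norm(x, gamma, beta, axis=-1, epsilon=1e-5)
func = relay.Function([x, gamma, beta], out)
print(func)  # type inference should report a (2, 4) float32 result
```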


def batch_matmul(x, y):
r"""
Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data
141 changes: 139 additions & 2 deletions python/tvm/relay/op/reduce.py
@@ -18,6 +18,9 @@
# pylint: disable=redefined-builtin

from . import _make
from .tensor import sqrt
from .transform import squeeze
from ..expr import Tuple, TupleWrapper

def argmax(data, axis=None, keepdims=False, exclude=False):
"""Returns the indices of the maximum values along an axis.
@@ -236,8 +239,8 @@ def mean(data, axis=None, keepdims=False, exclude=False):
axis : None or int or tuple of int
Axis or axes along which a mean operation is performed.
-        The default, axis=None, will find the indices of minimum element all of the elements of
-        the input array. If axis is negative it counts from the last to the first axis.
+        The default, axis=None, will compute the mean of all elements in the input array.
+        If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
@@ -257,6 +260,140 @@
return _make.mean(data, axis, keepdims, exclude)


def variance(data, axis=None, keepdims=False, exclude=False):
"""Computes the variance of data over given axes.
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a variance operation is performed.
The default, axis=None, will compute the variance of all elements in the input array.
If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""
axis = [axis] if isinstance(axis, int) else axis
m = mean(data, axis, True, exclude)
return _make._variance(data, m, axis, keepdims, exclude)
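
A hedged numerical cross-check against NumPy's population variance (ddof=0), run through the debug interpreter; it assumes a CPU build of TVM.

```python
import numpy as np
from tvm import relay

x = relay.var("x", shape=(3, 5), dtype="float32")
func = relay.Function([x], relay.variance(x, axis=1))
data = np.random.rand(3, 5).astype("float32")
res = relay.create_executor(kind="debug").evaluate(func)(data)
# relay.variance is the biased (population) variance, like NumPy's ddof=0.
np.testing.assert_allclose(res.asnumpy(), data.var(axis=1), rtol=1e-5)
```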


def std(data, axis=None, keepdims=False, exclude=False):
"""Computes the standard deviation of data over given axes.
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a standard deviation operation is performed.
The default, axis=None, will compute the standard deviation of all elements in the
input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""
axis = [axis] if isinstance(axis, int) else axis
m = mean(data, axis, True, exclude)
return sqrt(_make._variance(data, m, axis, keepdims, exclude))


def mean_variance(data, axis=None, keepdims=False, exclude=False):
"""Computes the mean and variance of data over given axes.
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a mean and variance operation is performed.
The default, axis=None, will compute the mean and variance of all elements in
the input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""
axis = [axis] if isinstance(axis, int) else axis
m = mean(data, axis, True, exclude)
var = _make._variance(data, m, axis, keepdims, exclude)
if not keepdims:
m = squeeze(m)
return TupleWrapper(Tuple((m, var)), 2)
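
Because the result is a TupleWrapper, the two values are retrieved by index. A small usage sketch (shapes illustrative):

```python
from tvm import relay

x = relay.var("x", shape=(4, 8), dtype="float32")
mv = relay.mean_variance(x, axis=1)
m, v = mv[0], mv[1]  # TupleWrapper supports indexing its fields
# With keepdims=False the mean is squeezed (see above), so both `m`
# and `v` are (4,) float32 expressions.
```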


def mean_std(data, axis=None, keepdims=False, exclude=False):
"""Computes the mean and standard deviation of data over given axes.
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a mean and standard deviation operation is performed.
The default, axis=None, will compute the mean and standard deviation of all elements in
the input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""
axis = [axis] if isinstance(axis, int) else axis
m = mean(data, axis, True, exclude)
s = sqrt(_make._variance(data, m, axis, keepdims, exclude))
if not keepdims:
m = squeeze(m)
return TupleWrapper(Tuple((m, s)), 2)


def prod(data, axis=None, keepdims=False, exclude=False):
"""Computes the products of array elements over given axes.
2 changes: 1 addition & 1 deletion python/tvm/relay/testing/dcgan.py
@@ -52,7 +52,7 @@ def deconv2d_bn_relu(data, prefix, **kwargs):
"""a block of deconv + batch norm + relu"""
eps = 1e-5 + 1e-12
net = deconv2d(data, name="%s_deconv" % prefix, **kwargs)
-    net = layers.batch_norm_infer(net, epsilon=eps, name="%s_batch_norm" % prefix)
+    net = layers.batch_norm_infer(net, epsilon=eps, scale=False, name="%s_batch_norm" % prefix)
net = relay.nn.relu(net)
return net
