From cdbfb33a7345bf7cfb9e4e15c5323da85063d3bb Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Wed, 7 Aug 2019 04:34:53 -0700 Subject: [PATCH] [Relay/TOPI][Op] Add variance and layer norm op (#3700) * Add LayerNorm op * update * fix * Add mean_std and mean_variance * add std and update doc * add license * x * lint * x * fix * fix doc --- docs/langref/relay_op.rst | 8 ++ include/tvm/relay/attrs/nn.h | 21 +++ include/tvm/relay/attrs/reduce.h | 64 +++++++++ python/tvm/relay/frontend/mxnet.py | 14 +- python/tvm/relay/op/_reduce.py | 1 + python/tvm/relay/op/nn/nn.py | 60 ++++++++- python/tvm/relay/op/reduce.py | 141 +++++++++++++++++++- python/tvm/relay/testing/dcgan.py | 2 +- python/tvm/relay/testing/inception_v3.py | 3 +- src/relay/op/nn/nn.cc | 47 +++++++ src/relay/op/tensor/reduce.cc | 109 +++++++++++---- src/relay/pass/pattern_util.h | 21 +++ src/relay/pass/simplify_inference.cc | 35 ++++- tests/python/frontend/mxnet/test_forward.py | 52 ++++++++ tests/python/relay/test_op_level4.py | 40 ++++++ 15 files changed, 583 insertions(+), 35 deletions(-) create mode 100644 include/tvm/relay/attrs/reduce.h diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 61c9b36e1ffd..757fdac32b81 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -150,6 +150,10 @@ This level enables additional math and transform operators. tvm.relay.max tvm.relay.min tvm.relay.mean + tvm.relay.variance + tvm.relay.std + tvm.relay.mean_variance + tvm.relay.mean_std tvm.relay.prod tvm.relay.strided_slice tvm.relay.broadcast_to @@ -297,6 +301,10 @@ Level 4 Definitions .. autofunction:: tvm.relay.max .. autofunction:: tvm.relay.min .. autofunction:: tvm.relay.mean +.. autofunction:: tvm.relay.variance +.. autofunction:: tvm.relay.std +.. autofunction:: tvm.relay.mean_variance +.. autofunction:: tvm.relay.mean_std .. autofunction:: tvm.relay.prod .. autofunction:: tvm.relay.strided_slice .. autofunction:: tvm.relay.broadcast_to diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 58c4bba30cbc..085ad3175d16 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -470,6 +470,27 @@ struct BatchNormAttrs : public tvm::AttrsNode { }; // struct BatchNormAttrs +/*! \brief Attributes used in layer_norm operator */ +struct LayerNormAttrs : public tvm::AttrsNode { + int axis; + double epsilon; + bool center; + bool scale; + + TVM_DECLARE_ATTRS(LayerNormAttrs, "relay.attrs.LayerNormAttrs") { + TVM_ATTR_FIELD(axis).set_default(-1) + .describe("Specify which shape axis denotes the channel."); + TVM_ATTR_FIELD(epsilon).set_default(1e-5) + .describe("Small float added to variance to avoid dividing by zero"); + TVM_ATTR_FIELD(center).set_default(true) + .describe("If true, add offset of beta to normalized tensor; " + "otherwise, beta is ignored."); + TVM_ATTR_FIELD(scale).set_default(true) + .describe("If true, multiply by gamma; otherwise, gamma is ignored."); + } +}; // struct LayerNormAttrs + + /*! \brief Attributes for LRN operator */ struct LRNAttrs : public tvm::AttrsNode { int size; diff --git a/include/tvm/relay/attrs/reduce.h b/include/tvm/relay/attrs/reduce.h new file mode 100644 index 000000000000..e86e89c161e3 --- /dev/null +++ b/include/tvm/relay/attrs/reduce.h @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file tvm/relay/attrs/reduce.h + * \brief Auxiliary attributes for reduce operators. + */ +#ifndef TVM_RELAY_ATTRS_REDUCE_H_ +#define TVM_RELAY_ATTRS_REDUCE_H_ + +#include +#include + +namespace tvm { +namespace relay { + +/*! \brief Attributes for Reduce operators */ +struct ReduceAttrs : public tvm::AttrsNode { + Array axis; + bool keepdims; + bool exclude; + + TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") { + TVM_ATTR_FIELD(axis).set_default(NullValue>()) + .describe(R"code(The axis or axes along which to perform the reduction. + + The default, `axis=()`, will compute over all elements into a + scalar array with shape `(1,)`. + + If `axis` is int, a reduction is performed on a particular axis. + + If `axis` is a tuple of ints, a reduction is performed on all the axes + specified in the tuple. + + If `exclude` is true, reduction will be performed on the axes that are + NOT in axis instead.)code"); + + TVM_ATTR_FIELD(keepdims).set_default(false) + .describe("If this is set to `True`, the reduced axes are left " + "in the result as dimension with size one."); + TVM_ATTR_FIELD(exclude).set_default(false) + .describe("Whether to perform reduction on axis that are NOT in axis instead."); + } +}; +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_ATTRS_REDUCE_H_ diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 3486263252b7..9d82671e5534 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -224,10 +224,21 @@ def _mx_batch_norm(inputs, attrs): new_attrs["axis"] = attrs.get_int("axis", 1) new_attrs["epsilon"] = attrs.get_float("eps", 0.001) new_attrs["center"] = True - new_attrs["scale"] = not attrs.get_bool("fix_gamma", False) + new_attrs["scale"] = not attrs.get_bool("fix_gamma", True) return _op.nn.batch_norm(*inputs, **new_attrs) +def _mx_layer_norm(inputs, attrs): + assert len(inputs) == 3 + if attrs.get_bool("output_mean_var", False): + raise tvm.error.OpAttributeUnimplemented( + 'Attribute "output_mean_var" is not supported for operator Layer Norm.') + new_attrs = {} + new_attrs["axis"] = attrs.get_int("axis", -1) + new_attrs["epsilon"] = attrs.get_float("eps", 1e-5) + return _op.nn.layer_norm(*inputs, **new_attrs) + + def _mx_slice(inputs, attrs): new_attrs = {} begin = attrs.get_int_tuple('begin', None) @@ -997,6 +1008,7 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): "Dropout" : _mx_dropout, "BatchNorm" : _mx_batch_norm, "BatchNorm_v1" : _mx_batch_norm, + "LayerNorm" : _mx_layer_norm, "LRN" : _mx_lrn, "L2Normalization" : _mx_l2_normalize, "slice" : _mx_slice, diff --git a/python/tvm/relay/op/_reduce.py b/python/tvm/relay/op/_reduce.py index b7c9a79a8ad9..b6c05b1077d2 100644 --- a/python/tvm/relay/op/_reduce.py +++ b/python/tvm/relay/op/_reduce.py @@ -35,3 +35,4 @@ def _schedule_reduce(_, outs, target): 
_reg.register_schedule("min", _schedule_reduce) _reg.register_schedule("prod", _schedule_reduce) _reg.register_schedule("mean", _schedule_reduce) +_reg.register_schedule("variance", _schedule_reduce) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 4a83ef233c24..229689d9e548 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -867,7 +867,7 @@ def batch_norm(data, Specify along which shape axis the channel is specified. epsilon : double, optional, default=1e-5 - Small float added to variance to avoid diving by zero. + Small float added to variance to avoid dividing by zero. center : boolean, optional, default=True If True, add offset of beta to normalized tensor, If False, @@ -897,6 +897,64 @@ def batch_norm(data, return TupleWrapper(result, 3) +def layer_norm(data, + gamma, + beta, + axis=-1, + epsilon=1e-5, + center=True, + scale=True): + r""" + Layer normalization (Lei Ba and et al., 2016). + Applies layer normalization to the n-dimensional input array. + This operator takes an n-dimensional input array and normalizes + the input using the given axis: + + .. math:: + + out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis)+\epsilon}} + * gamma + beta + + Unlike batch normalization, the mean and var are computed along the channel dimension. + + Assume the input has size k on axis 1, then both gamma and beta have shape (k,). + + .. note:: + + This operator can be optimized away for inference. + + Parameters + ---------- + data : tvm.relay.Expr + Input to which batch_norm will be applied. + + gamma : tvm.relay.Expr + The gamma scale factor. + + beta : tvm.relay.Expr + The beta offset factor. + + axis : int, optional, default=-1 + The axis that should be normalized, typically the axis of the channels. + + epsilon : double, optional, default=1e-5 + Small float added to variance to avoid dividing by zero. + + center : boolean, optional, default=True + If True, add offset of beta to normalized tensor, If False, + beta is ignored. + + scale : boolean, optional, default=True + If True, multiply by gamma. If False, gamma is not used. + + Returns + ------- + result : tvm.relay.Expr + The normalized data. + """ + return _make.layer_norm(data, gamma, beta, axis, epsilon, center, scale) + + def batch_matmul(x, y): r""" Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py index 41e1fc041cce..49193fd4b5c6 100644 --- a/python/tvm/relay/op/reduce.py +++ b/python/tvm/relay/op/reduce.py @@ -18,6 +18,9 @@ # pylint: disable=redefined-builtin from . import _make +from .tensor import sqrt +from .transform import squeeze +from ..expr import Tuple, TupleWrapper def argmax(data, axis=None, keepdims=False, exclude=False): """Returns the indices of the maximum values along an axis. @@ -236,8 +239,8 @@ def mean(data, axis=None, keepdims=False, exclude=False): axis : None or int or tuple of int Axis or axes along which a mean operation is performed. - The default, axis=None, will find the indices of minimum element all of the elements of - the input array. If axis is negative it counts from the last to the first axis. + The default, axis=None, will compute the mean of all elements in the input array. + If axis is negative it counts from the last to the first axis. 
keepdims : bool If this is set to True, the axes which are reduced are left in the result as dimensions @@ -257,6 +260,140 @@ def mean(data, axis=None, keepdims=False, exclude=False): return _make.mean(data, axis, keepdims, exclude) +def variance(data, axis=None, keepdims=False, exclude=False): + """Computes the variance of data over given axes. + + Parameters + ---------- + data : relay.Expr + The input data + + axis : None or int or tuple of int + Axis or axes along which a variance operation is performed. + The default, axis=None, will compute the variance of all elements in the input array. + If axis is negative it counts from the last to the first axis. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input array. + + exclude : bool + If `exclude` is true, reduction will be performed on the axes that are + NOT in axis instead. + + Returns + ------- + result : relay.Expr + The computed result. + """ + axis = [axis] if isinstance(axis, int) else axis + m = mean(data, axis, True, exclude) + return _make._variance(data, m, axis, keepdims, exclude) + + +def std(data, axis=None, keepdims=False, exclude=False): + """Computes the standard deviation of data over given axes. + + Parameters + ---------- + data : relay.Expr + The input data + + axis : None or int or tuple of int + Axis or axes along which a standard deviation operation is performed. + The default, axis=None, will compute the standard deviation of all elements in the + input array. If axis is negative it counts from the last to the first axis. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input array. + + exclude : bool + If `exclude` is true, reduction will be performed on the axes that are + NOT in axis instead. + + Returns + ------- + result : relay.Expr + The computed result. + """ + axis = [axis] if isinstance(axis, int) else axis + m = mean(data, axis, True, exclude) + return sqrt(_make._variance(data, m, axis, keepdims, exclude)) + + +def mean_variance(data, axis=None, keepdims=False, exclude=False): + """Computes the mean and variance of data over given axes. + + Parameters + ---------- + data : relay.Expr + The input data + + axis : None or int or tuple of int + Axis or axes along which a mean and variance operation is performed. + The default, axis=None, will compute the mean and variance of all elements in + the input array. If axis is negative it counts from the last to the first axis. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input array. + + exclude : bool + If `exclude` is true, reduction will be performed on the axes that are + NOT in axis instead. + + Returns + ------- + result : relay.Expr + The computed result. + """ + axis = [axis] if isinstance(axis, int) else axis + m = mean(data, axis, True, exclude) + var = _make._variance(data, m, axis, keepdims, exclude) + if not keepdims: + m = squeeze(m) + return TupleWrapper(Tuple((m, var)), 2) + + +def mean_std(data, axis=None, keepdims=False, exclude=False): + """Computes the mean and standard deviation of data over given axes. 
+ + Parameters + ---------- + data : relay.Expr + The input data + + axis : None or int or tuple of int + Axis or axes along which a mean and standard deviation operation is performed. + The default, axis=None, will compute the mean and standard deviation of all elements in + the input array. If axis is negative it counts from the last to the first axis. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input array. + + exclude : bool + If `exclude` is true, reduction will be performed on the axes that are + NOT in axis instead. + + Returns + ------- + result : relay.Expr + The computed result. + """ + axis = [axis] if isinstance(axis, int) else axis + m = mean(data, axis, True, exclude) + s = sqrt(_make._variance(data, m, axis, keepdims, exclude)) + if not keepdims: + m = squeeze(m) + return TupleWrapper(Tuple((m, s)), 2) + + def prod(data, axis=None, keepdims=False, exclude=False): """Computes the products of array elements over given axes. diff --git a/python/tvm/relay/testing/dcgan.py b/python/tvm/relay/testing/dcgan.py index c6b258badb5b..6907eb01c88c 100644 --- a/python/tvm/relay/testing/dcgan.py +++ b/python/tvm/relay/testing/dcgan.py @@ -52,7 +52,7 @@ def deconv2d_bn_relu(data, prefix, **kwargs): """a block of deconv + batch norm + relu""" eps = 1e-5 + 1e-12 net = deconv2d(data, name="%s_deconv" % prefix, **kwargs) - net = layers.batch_norm_infer(net, epsilon=eps, name="%s_batch_norm" % prefix) + net = layers.batch_norm_infer(net, epsilon=eps, scale=False, name="%s_batch_norm" % prefix) net = relay.nn.relu(net) return net diff --git a/python/tvm/relay/testing/inception_v3.py b/python/tvm/relay/testing/inception_v3.py index 4da543257c31..fa4233d67b31 100644 --- a/python/tvm/relay/testing/inception_v3.py +++ b/python/tvm/relay/testing/inception_v3.py @@ -38,7 +38,8 @@ def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None, padding=pad, name='%s%s_conv1' % (name, suffix)) - bn = layers.batch_norm_infer(data=conv, epsilon=2e-5, name='%s%s_bn' % (name, suffix)) + bn = layers.batch_norm_infer(data=conv, epsilon=2e-5, scale=False, + name='%s%s_bn' % (name, suffix)) act = relay.nn.relu(data=bn) return act diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index c0f36bfa2915..2c03bbac70d7 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -678,6 +678,53 @@ axis to be the last item in the input shape. .add_type_rel("BatchNorm", BatchNormRel); +// layer_norm +TVM_REGISTER_NODE_TYPE(LayerNormAttrs); + +bool LayerNormRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 4); + const auto* data = types[0].as(); + if (data == nullptr) return false; + const LayerNormAttrs* param = attrs.as(); + int axis = param->axis >= 0 ? 
param->axis : param->axis + data->shape.size(); + CHECK(axis >= 0 && axis < (int)data->shape.size()); + reporter->Assign(types[1], TensorTypeNode::make({data->shape[axis]}, data->dtype)); + reporter->Assign(types[2], TensorTypeNode::make({data->shape[axis]}, data->dtype)); + reporter->Assign(types[3], TensorTypeNode::make(data->shape, data->dtype)); + + return true; +} + +Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, + bool center, bool scale) { + auto attrs = make_node(); + attrs->axis = axis; + attrs->epsilon = epsilon; + attrs->center = center; + attrs->scale = scale; + static const Op& op = Op::Get("nn.layer_norm"); + return CallNode::make(op, {data, gamma, beta}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op.nn._make.layer_norm") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeLayerNorm, args, rv); + }); + +RELAY_REGISTER_OP("nn.layer_norm") +.describe(R"code( +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.LayerNormAttrs") +.set_num_inputs(3) +.add_argument("data", "Tensor", "Input to which layer_norm will be applied.") +.add_argument("gamma", "Tensor", "The gamma scale factor.") +.add_argument("beta", "Tensor", "The beta offset factor.") +.set_support_level(1) +.add_type_rel("LayerNorm", LayerNormRel); + // relay.nn.batch_matmul bool BatchMatmulRel(const Array& types, int num_inputs, diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index d655665f2083..a7be3ffa127f 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -24,6 +24,7 @@ */ #include #include +#include #include #include #include @@ -34,34 +35,7 @@ namespace tvm { namespace relay { -/*! \brief Attributes for Reduce operators */ -struct ReduceAttrs : public tvm::AttrsNode { - Array axis; - bool keepdims; - bool exclude; - - TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") { - TVM_ATTR_FIELD(axis).set_default(NullValue>()) - .describe(R"code(The axis or axes along which to perform the reduction. - - The default, `axis=()`, will compute over all elements into a - scalar array with shape `(1,)`. - - If `axis` is int, a reduction is performed on a particular axis. - - If `axis` is a tuple of ints, a reduction is performed on all the axes - specified in the tuple. - - If `exclude` is true, reduction will be performed on the axes that are - NOT in axis instead.)code"); - - TVM_ATTR_FIELD(keepdims).set_default(false) - .describe("If this is set to `True`, the reduced axes are left " - "in the result as dimension with size one."); - TVM_ATTR_FIELD(exclude).set_default(false) - .describe("Whether to perform reduction on axis that are NOT in axis instead."); - } -}; +TVM_REGISTER_NODE_TYPE(ReduceAttrs); /*! 
* \brief GetReduceAxes, get the new axis from indim and other arguments @@ -498,5 +472,84 @@ Example:: .add_type_rel("Reduce", ReduceRel) .set_attr("FTVMCompute", MeanCompute) .set_attr("TOpPattern", kCommReduce); + + +bool VarianceRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + if (data == nullptr) return false; + CHECK(static_cast(data->shape.size()) != 0); + const auto* mean = types[1].as(); + if (mean == nullptr) return false; + + std::vector in_shape(data->shape.begin(), data->shape.end()); + std::vector mean_shape(mean->shape.begin(), mean->shape.end()); + CHECK_EQ(in_shape.size(), mean_shape.size()); + + const ReduceAttrs* param = attrs.as(); + CHECK(param != nullptr); + + // assign output type and shape + auto oshape = ReduceShapeImpl(in_shape, param, reporter); + reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype)); + return true; +} + +Array VarianceCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + IndexExpr count = make_const(inputs[0]->dtype, 1); + const ReduceAttrs* param = attrs.as(); + CHECK(param != nullptr); + auto axes = param->axis; + auto data = inputs[0]; + auto mean = inputs[1]; + for (int64_t i : GetReduceAxes(data->shape.size(), + param->axis, + param->exclude)) { + count *= data->shape[i]; + } + std::vector expand_shape; + auto sq_diff = topi::power(topi::subtract(data, mean), 2); + auto var = topi::divide(ReduceCompute(attrs, {sq_diff}, out_type, target, topi::sum)[0], count); + + return {var}; +} + +Expr MakeVariance(Expr data, + Expr mean, + Array axis, + bool keepdims, + bool exclude) { + auto attrs = make_node(); + attrs->axis = std::move(axis); + attrs->keepdims = keepdims; + attrs->exclude = exclude; + static const Op& op = Op::Get("variance"); + return CallNode::make(op, {data, mean}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op._make._variance") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeVariance, args, rv); +}); + +RELAY_REGISTER_OP("variance") +.describe(R"code(Computes the variance of array elements over given axes. 
+ +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.ReduceAttrs") +.set_support_level(4) +.set_num_inputs(2) +.add_argument("data", "Tensor", "The input tensor.") +.add_argument("mean", "Tensor", "The mean tensor.") +.add_type_rel("Variance", VarianceRel) +.set_attr("FTVMCompute", VarianceCompute) +.set_attr("TOpPattern", kCommReduce); + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 5c303905968e..7dcfd5cb4b7f 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -33,7 +33,9 @@ #include #include #include +#include #include +#include namespace tvm { @@ -373,6 +375,25 @@ inline Expr Copy(Expr data) { } +inline Expr Mean(Expr data, Array axis, bool keepdims, bool exclude) { + auto attrs = make_node(); + attrs->axis = std::move(axis); + attrs->keepdims = keepdims; + attrs->exclude = exclude; + static const Op& op = Op::Get("mean"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +inline Expr Variance(Expr data, Expr mean, Array axis, bool keepdims, bool exclude) { + auto attrs = make_node(); + attrs->axis = std::move(axis); + attrs->keepdims = keepdims; + attrs->exclude = exclude; + static const Op& op = Op::Get("variance"); + return CallNode::make(op, {data, mean}, Attrs(attrs), {}); +} + + Expr MakeConcatenate(Expr data, int axis); Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides); diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc index daf48c44173e..3790dbf877f9 100644 --- a/src/relay/pass/simplify_inference.cc +++ b/src/relay/pass/simplify_inference.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include "./pattern_util.h" namespace tvm { @@ -54,8 +55,8 @@ Expr BatchNormToInferUnpack(const Attrs attrs, shift = Add(shift, beta); } - int axis = param->axis; auto ndim = ttype->shape.size(); + int axis = (param->axis < 0) ? param->axis + ndim : param->axis; scale = ExpandBiasToMatchAxis(scale, ndim, {axis}); shift = ExpandBiasToMatchAxis(shift, ndim, {axis}); @@ -64,6 +65,33 @@ Expr BatchNormToInferUnpack(const Attrs attrs, return out; } +Expr LayerNormToInferUnpack(const Attrs attrs, + Expr data, + Expr gamma, + Expr beta, + Type tdata) { + auto ttype = tdata.as(); + CHECK(ttype); + const auto param = attrs.as(); + CHECK(param); + + Expr epsilon = MakeConstantScalar(Float(32), static_cast(param->epsilon)); + Expr mean = Mean(data, {param->axis}, true, false); + Expr var = Variance(data, mean, {param->axis}, true, false); + Expr denom = Sqrt(Add(var, epsilon)); + Expr out = Divide(Subtract(data, mean), denom); + + size_t ndim = ttype->shape.size(); + int axis = (param->axis < 0) ? 
param->axis + ndim : param->axis; + if (param->scale) { + out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis})); + } + if (param->center) { + out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis})); + } + return out; +} + class InferenceSimplifier : public ExprMutator { public: Expr VisitExpr_(const TupleGetItemNode* n) final { @@ -88,9 +116,14 @@ class InferenceSimplifier : public ExprMutator { Expr VisitExpr_(const CallNode* n) { static const Op& batch_norm = Op::Get("nn.batch_norm"); + static const Op& layer_norm = Op::Get("nn.layer_norm"); auto new_n = ExprMutator::VisitExpr_(n); if (n->op.same_as(batch_norm)) { ty_map_[new_n.as()->args[0]] = n->args[0]->checked_type(); + } else if (n->op.same_as(layer_norm)) { + const auto* call = new_n.as(); + return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1], + call->args[2], n->args[0]->checked_type()); } return new_n; } diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 451679cf9e19..a4a514ea7474 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -728,6 +728,56 @@ def verify(shape): verify((3, 4)) verify((3, 4, 5)) +def test_forward_batch_norm(): + def verify(shape, axis=1, fix_gamma=False): + x = np.random.uniform(size=shape).astype("float32") + gamma = np.random.uniform(size=(shape[axis])).astype("float32") + beta = np.random.uniform(size=(shape[axis])).astype("float32") + moving_mean = np.random.uniform(size=(shape[axis])).astype("float32") + moving_var = np.random.uniform(size=(shape[axis])).astype("float32") + ref_res = mx.nd.BatchNorm(mx.nd.array(x), mx.nd.array(gamma), mx.nd.array(beta), + mx.nd.array(moving_mean), mx.nd.array(moving_var), + axis=axis, use_global_stats=True, fix_gamma=fix_gamma) + mx_sym = mx.sym.BatchNorm(mx.sym.var("x"), mx.sym.var("gamma"), + mx.sym.var("beta"), mx.sym.var("mean"), + mx.sym.var("var"), axis=axis, use_global_stats=True, + fix_gamma=fix_gamma) + + shape_dict = {"x": x.shape, "gamma": gamma.shape, "beta": beta.shape, + "mean": moving_mean.shape, "var": moving_var.shape} + mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) + #print(mod) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x, gamma, beta, moving_mean, moving_var) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3) + verify((2, 3, 4, 5)) + verify((2, 3, 4, 5), axis=0) + verify((2, 3, 4, 5), axis=-1) + verify((2, 3, 4, 5), fix_gamma=True) + + +def test_forward_layer_norm(): + def verify(shape, axis=-1): + x = np.random.uniform(size=shape).astype("float32") + gamma = np.random.uniform(size=(shape[axis])).astype("float32") + beta = np.random.uniform(size=(shape[axis])).astype("float32") + ref_res = mx.nd.LayerNorm(mx.nd.array(x), mx.nd.array(gamma), mx.nd.array(beta), + axis=axis) + mx_sym = mx.sym.LayerNorm(mx.sym.var("x"), mx.sym.var("gamma"), + mx.sym.var("beta"), axis=axis) + shape_dict = {"x": x.shape, "gamma": gamma.shape, "beta": beta.shape} + mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x, gamma, beta) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3) + verify((2, 5)) + verify((2, 5), axis=0) + verify((2, 5, 6)) + if __name__ == '__main__': 
test_forward_mlp() test_forward_vgg() @@ -773,3 +823,5 @@ def verify(shape): test_forward_topk() test_forward_sequence_mask() test_forward_contrib_div_sqrt_dim() + test_forward_batch_norm() + test_forward_layer_norm() diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index 69fd88b562b7..c34dddfd0fd7 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -203,6 +203,8 @@ def _wrapper(data, axis=None, keepdims=False): [relay.max, np.max], [relay.min, np.min], [relay.mean, np.mean], + [relay.variance, np.var], + [relay.std, np.std], [relay.prod, np.prod], [relay.all, np.all], [relay.argmin, _with_keepdims(np.argmin)], @@ -226,6 +228,43 @@ def _wrapper(data, axis=None, keepdims=False): verify_reduce(func, (128, 24, 128), (0, 2), True, False, (1, 24, 1)) +def verify_mean_var_std(funcs, shape, axis, keepdims): + test_func = funcs[0] + ref_func = funcs[1] + dtype = "float32" + + x = relay.var("x", relay.TensorType(shape, dtype)) + z = test_func(x, axis, keepdims) + func = relay.Function([x], z.astuple()) + x_data = np.random.uniform(size=shape).astype(dtype) + ref_mean = np.mean(x_data, axis=axis, dtype=dtype, keepdims=keepdims) + ref_res = ref_func(x_data, axis=axis, dtype=dtype, keepdims=keepdims) + + for target, ctx in ctx_list(): + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + intrp2 = relay.create_executor("debug", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(x_data) + tvm.testing.assert_allclose(op_res1[0].asnumpy(), ref_mean, rtol=1e-5) + tvm.testing.assert_allclose(op_res1[1].asnumpy(), ref_res, rtol=1e-5) + op_res2 = intrp2.evaluate(func)(x_data) + tvm.testing.assert_allclose(op_res2[0].asnumpy(), ref_mean, rtol=1e-5) + tvm.testing.assert_allclose(op_res2[1].asnumpy(), ref_res, rtol=1e-5) + +def test_mean_var_std(): + for func in [[relay.mean_variance, np.var], + [relay.mean_std, np.std]]: + verify_mean_var_std(func, (2, 3, 4), 1, True) + verify_mean_var_std(func, (2, 3, 4), (1,), True) + verify_mean_var_std(func, (2, 3, 4), -1, True) + verify_mean_var_std(func, (2, 3, 4), (0, 1, 2), False) + verify_mean_var_std(func, (4, 4, 3), None, False) + verify_mean_var_std(func, (4, 4, 3), (0, 2), False) + verify_mean_var_std(func, (128, 24, 128), (0, 1), False) + verify_mean_var_std(func, (128, 24, 128), (0, 2), False) + verify_mean_var_std(func, (128, 24, 128), (0, 1), True) + verify_mean_var_std(func, (128, 24, 128), (0, 2), True) + + def test_strided_slice(): def verify(dshape, begin, end, strides, output, test_ref=True): x = relay.var("x", relay.TensorType(dshape, "float32")) @@ -267,3 +306,4 @@ def verify(dshape, begin, end, strides, output, test_ref=True): test_binary_int_broadcast() test_where() test_reduce_functions() + test_mean_var_std()
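
A minimal usage sketch (not part of the patch): it exercises relay.nn.layer_norm and relay.mean_std as registered above, assuming a TVM build that already includes this change. The shapes, variable names, and tolerances are illustrative only and mirror the style of the tests in this patch.

import numpy as np
import tvm
from tvm import relay

dtype = "float32"
shape = (2, 5)

# layer_norm: normalize over the last axis, then scale/shift with gamma/beta.
data = relay.var("data", shape=shape, dtype=dtype)
gamma = relay.var("gamma", shape=(shape[-1],), dtype=dtype)
beta = relay.var("beta", shape=(shape[-1],), dtype=dtype)
ln_func = relay.Function(
    [data, gamma, beta],
    relay.nn.layer_norm(data, gamma, beta, axis=-1, epsilon=1e-5))

# mean_std: fused mean + standard deviation reduction, returned as a tuple.
x_var = relay.var("x", shape=shape, dtype=dtype)
ms_func = relay.Function([x_var], relay.mean_std(x_var, axis=1).astuple())

x = np.random.uniform(size=shape).astype(dtype)
g = np.random.uniform(size=(shape[-1],)).astype(dtype)
b = np.random.uniform(size=(shape[-1],)).astype(dtype)

# Evaluate with the debug executor on CPU and compare against NumPy references.
intrp = relay.create_executor("debug", ctx=tvm.cpu(0), target="llvm")

ln_out = intrp.evaluate(ln_func)(x, g, b)
ref_ln = (x - x.mean(axis=-1, keepdims=True)) \
    / np.sqrt(x.var(axis=-1, keepdims=True) + 1e-5) * g + b
tvm.testing.assert_allclose(ln_out.asnumpy(), ref_ln, rtol=1e-5)

ms_out = intrp.evaluate(ms_func)(x)
tvm.testing.assert_allclose(ms_out[0].asnumpy(), x.mean(axis=1), rtol=1e-5)
tvm.testing.assert_allclose(ms_out[1].asnumpy(), x.std(axis=1), rtol=1e-5)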