[Relay][Op] Add instance norm op #4004

Merged 2 commits on Oct 3, 2019
23 changes: 23 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -492,6 +492,29 @@ struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
}; // struct BatchNormAttrs


/*! \brief Attributes used in instance_norm operator */
struct InstanceNormAttrs : public tvm::AttrsNode<InstanceNormAttrs> {
int axis;
double epsilon;
bool center;
bool scale;

TVM_DECLARE_ATTRS(InstanceNormAttrs, "relay.attrs.InstanceNormAttrs") {
TVM_ATTR_FIELD(axis)
.describe("Specify which shape axis denotes the channel.")
.set_default(1);
TVM_ATTR_FIELD(epsilon)
.describe("Small float added to variance to avoid dividing by zero")
.set_default(1e-5);
TVM_ATTR_FIELD(center).set_default(true)
.describe("If true, add offset of beta to normalized tensor; "
"otherwise, beta is ignored.");
TVM_ATTR_FIELD(scale).set_default(true)
.describe("If true, multiply by gamma; otherwise, gamma is ignored.");
}
}; // struct InstanceNormAttrs


/*! \brief Attributes used in layer_norm operator */
struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> {
int axis;
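The four attributes mirror those of batch_norm and layer_norm. A minimal NumPy sketch of their intended semantics (instance_norm_ref is a hypothetical helper for illustration only, assuming a non-negative axis; it is not part of this patch):

import numpy as np

def instance_norm_ref(x, gamma, beta, axis=1, epsilon=1e-5, center=True, scale=True):
    # Reduce over every axis except the batch axis (0) and the channel axis.
    reduce_axes = tuple(i for i in range(1, x.ndim) if i != axis)
    mean = x.mean(axis=reduce_axes, keepdims=True)
    var = x.var(axis=reduce_axes, keepdims=True)
    out = (x - mean) / np.sqrt(var + epsilon)
    # gamma and beta have shape (k,) and broadcast along the channel axis.
    bshape = [1] * x.ndim
    bshape[axis] = x.shape[axis]
    if scale:
        out = out * gamma.reshape(bshape)
    if center:
        out = out + beta.reshape(bshape)
    return out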
9 changes: 9 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
@@ -324,6 +324,14 @@ def _mx_batch_norm(inputs, attrs):
return _op.nn.batch_norm(*inputs, **new_attrs)


def _mx_instance_norm(inputs, attrs):
assert len(inputs) == 3
new_attrs = {}
new_attrs["axis"] = attrs.get_int("axis", 1)
new_attrs["epsilon"] = attrs.get_float("eps", 1e-5)
return _op.nn.instance_norm(*inputs, **new_attrs)


def _mx_layer_norm(inputs, attrs):
assert len(inputs) == 3
if attrs.get_bool("output_mean_var", False):
@@ -1133,6 +1141,7 @@ def _mx_one_hot(inputs, attrs):
"Dropout" : _mx_dropout,
"BatchNorm" : _mx_batch_norm,
"BatchNorm_v1" : _mx_batch_norm,
"InstanceNorm" : _mx_instance_norm,
"LayerNorm" : _mx_layer_norm,
"LRN" : _mx_lrn,
"L2Normalization" : _mx_l2_normalize,
11 changes: 10 additions & 1 deletion python/tvm/relay/frontend/onnx.py
@@ -176,6 +176,15 @@ def _impl_v1(cls, inputs, attr, params):
return out[0]


class InstanceNorm(OnnxOpConverter):
""" Operator converter for BatchNorm.
"""

@classmethod
def _impl_v1(cls, inputs, attr, params):
return AttrCvt(op_name='instance_norm')(inputs, attr, params)


class Conv(OnnxOpConverter):
""" Operator converter for Conv.
"""
@@ -989,7 +998,7 @@ def _get_convert_map(opset):
'GlobalAveragePool': Renamer('global_avg_pool2d'),
'GlobalMaxPool': Renamer('global_max_pool2d'),
'BatchNormalization': BatchNorm.get_converter(opset),
# 'InstanceNormalization'
'InstanceNormalization': InstanceNorm.get_converter(opset),
# 'LpNormalization'
'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
'Flatten': Flatten.get_converter(opset),
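The converter only renames the ONNX node; the effect is roughly equivalent to the sketch below (_convert_instance_norm is a hypothetical helper for illustration; the real path goes through AttrCvt, and epsilon falls back to the relay default of 1e-5 when the node omits it):

def _convert_instance_norm(inputs, attr):
    # inputs = [data, gamma, beta]; the ONNX node carries an optional epsilon attribute.
    data, gamma, beta = inputs
    return _op.nn.instance_norm(data, gamma, beta, epsilon=attr.get('epsilon', 1e-5))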
69 changes: 68 additions & 1 deletion python/tvm/relay/op/nn/nn.py
@@ -935,6 +935,73 @@ def batch_norm(data,
return TupleWrapper(result, 3)


def instance_norm(data,
gamma,
beta,
axis=1,
epsilon=1e-5,
center=True,
scale=True):
r"""
Instance Normalization (Ulyanov et al., 2016)
Applies instance normalization to the n-dimensional input array.

.. math::

out = \frac{data - mean(data)}{\sqrt{var(data)+\epsilon}}
* gamma + beta

Instance normalization is similar to batch normalization, but the mean and
variance are computed per channel, separately for each object (instance) in
the mini-batch rather than across the whole batch. The same normalization is
applied at both training and inference time.

Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
have shape *(k,)*.

The parameter ``axis`` specifies which axis of the input shape denotes
the 'channel'. The default is 1. Specifying -1 sets the channel axis
to be the last item in the input shape.

.. note::

This operator can be optimized away for inference.

Parameters
----------
data : tvm.relay.Expr
Input to which instance_norm will be applied.

gamma : tvm.relay.Expr
The gamma scale factor.

beta : tvm.relay.Expr
The beta offset factor.

axis : int, optional, default=1
Specify along which shape axis the channel is specified.

epsilon : double, optional, default=1e-5
Small float added to variance to avoid dividing by zero.

center : boolean, optional, default=True
If True, add the offset of beta to the normalized tensor. If False,
beta is ignored.

scale : boolean, optional, default=True
If True, multiply by gamma. If False, gamma is not used.

Returns
-------
result : tvm.relay.Expr
The normalized data.

.. _`Instance Normalization: The Missing Ingredient for Fast Stylization`:
https://arxiv.org/abs/1607.08022
"""
return _make.instance_norm(data, gamma, beta, axis, epsilon, center, scale)


def layer_norm(data,
gamma,
beta,
@@ -964,7 +1031,7 @@ def layer_norm(data,
Parameters
----------
data : tvm.relay.Expr
Input to which batch_norm will be applied.
Input to which layer_norm will be applied.

gamma : tvm.relay.Expr
The gamma scale factor.
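A minimal usage sketch of the new Python API (assuming this TVM version's relay API; with axis=1, gamma and beta must have shape (k,) where k is the channel size):

import tvm
from tvm import relay

data = relay.var("data", shape=(2, 3, 4, 5))
gamma = relay.var("gamma", shape=(3,))
beta = relay.var("beta", shape=(3,))
out = relay.nn.instance_norm(data, gamma, beta, axis=1, epsilon=1e-5)
func = relay.Function([data, gamma, beta], out)
print(func)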
70 changes: 70 additions & 0 deletions src/relay/op/nn/nn.cc
@@ -640,6 +640,76 @@ axis to be the last item in the input shape.
.add_type_rel("BatchNorm", BatchNormRel);


// instance_norm
TVM_REGISTER_NODE_TYPE(InstanceNormAttrs);

bool InstanceNormRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
const InstanceNormAttrs* param = attrs.as<InstanceNormAttrs>();
int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size();
CHECK(axis >= 0 && axis < (int)data->shape.size());
reporter->Assign(types[1], TensorTypeNode::make({data->shape[axis]}, data->dtype));
reporter->Assign(types[2], TensorTypeNode::make({data->shape[axis]}, data->dtype));
reporter->Assign(types[3], TensorTypeNode::make(data->shape, data->dtype));

return true;
}

Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon,
bool center, bool scale) {
auto attrs = make_node<InstanceNormAttrs>();
attrs->axis = axis;
attrs->epsilon = epsilon;
attrs->center = center;
attrs->scale = scale;
static const Op& op = Op::Get("nn.instance_norm");
return CallNode::make(op, {data, gamma, beta}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op.nn._make.instance_norm")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 7>(MakeInstanceNorm, args, rv);
});

RELAY_REGISTER_OP("nn.instance_norm")
.describe(R"code(Instance Normalization (Ulyanov and et al., 2016)
Applies instance normalization to the n-dimensional input array.

.. math::

out = \frac{data - mean(data)}{\sqrt{var(data)+\epsilon}}
* gamma + beta

Instance normalization is similar to batch normalization, but the mean and
variance are computed per channel, separately for each object (instance) in
the mini-batch rather than across the whole batch. The same normalization is
applied at both training and inference time.

Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
have shape *(k,)*.

The parameter ``axis`` specifies which axis of the input shape denotes
the 'channel'. The default is 1. Specifying -1 sets the channel axis
to be the last item in the input shape.

.. note::

This operator can be optimized away for inference.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.InstanceNormAttrs")
.set_num_inputs(3)
.add_argument("data", "Tensor", "Input to which instance_norm will be applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
.set_support_level(1)
.add_type_rel("InstanceNorm", InstanceNormRel);


// layer_norm
TVM_REGISTER_NODE_TYPE(LayerNormAttrs);

40 changes: 40 additions & 0 deletions src/relay/pass/simplify_inference.cc
@@ -92,6 +92,41 @@ Expr LayerNormToInferUnpack(const Attrs attrs,
return out;
}


Expr InstanceNormToInferUnpack(const Attrs attrs,
Expr data,
Expr gamma,
Expr beta,
Type tdata) {
auto ttype = tdata.as<TensorTypeNode>();
CHECK(ttype);
const auto param = attrs.as<InstanceNormAttrs>();
CHECK(param);

int ndim = ttype->shape.size();
int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
Array<Integer> reduced_axes;
for (int i = 1; i < ndim; ++i) {
if (i != axis)
reduced_axes.push_back(i);
}

Expr epsilon = MakeConstantScalar(Float(32), static_cast<float>(param->epsilon));
Expr mean = Mean(data, reduced_axes, true, false);
Expr var = Variance(data, mean, reduced_axes, true, false);
Expr denom = Sqrt(Add(var, epsilon));
Expr out = Divide(Subtract(data, mean), denom);

if (param->scale) {
out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis}));
}
if (param->center) {
out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis}));
}
return out;
}


class InferenceSimplifier : public ExprMutator {
public:
Expr VisitExpr_(const TupleGetItemNode* n) final {
@@ -116,6 +151,7 @@ class InferenceSimplifier : public ExprMutator {

Expr VisitExpr_(const CallNode* n) {
static const Op& batch_norm = Op::Get("nn.batch_norm");
static const Op& instance_norm = Op::Get("nn.instance_norm");
static const Op& layer_norm = Op::Get("nn.layer_norm");
auto new_n = ExprMutator::VisitExpr_(n);
if (n->op.same_as(batch_norm)) {
@@ -124,6 +160,10 @@
const auto* call = new_n.as<CallNode>();
return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1],
call->args[2], n->args[0]->checked_type());
} else if (n->op.same_as(instance_norm)) {
const auto* call = new_n.as<CallNode>();
return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1],
call->args[2], n->args[0]->checked_type());
}
return new_n;
}
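As with batch_norm and layer_norm, the op survives only until this pass runs. A rough way to observe the rewrite (a sketch, assuming the relay.transform pass API available in this TVM version):

from tvm import relay

data = relay.var("data", shape=(2, 3, 4, 5))
gamma = relay.var("gamma", shape=(3,))
beta = relay.var("beta", shape=(3,))
net = relay.Function([data, gamma, beta], relay.nn.instance_norm(data, gamma, beta))
mod = relay.Module.from_expr(net)
# SimplifyInference unpacks instance_norm into mean/variance/sqrt/divide/multiply/add.
mod = relay.transform.InferType()(mod)
mod = relay.transform.SimplifyInference()(mod)
print(mod)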
21 changes: 21 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
@@ -758,6 +758,26 @@ def verify(shape, axis=1, fix_gamma=False):
verify((2, 3, 4, 5), fix_gamma=True)


def test_forward_instance_norm():
def verify(shape, axis=1, epsilon=1e-5):
x = np.random.uniform(size=shape).astype("float32")
gamma = np.random.uniform(size=(shape[axis])).astype("float32")
beta = np.random.uniform(size=(shape[axis])).astype("float32")
ref_res = mx.nd.InstanceNorm(mx.nd.array(x), mx.nd.array(gamma), mx.nd.array(beta), epsilon)
mx_sym = mx.sym.InstanceNorm(mx.sym.var("x"), mx.sym.var("gamma"), mx.sym.var("beta"), epsilon)
shape_dict = {"x": x.shape, "gamma": gamma.shape, "beta": beta.shape}
mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x, gamma, beta)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5)
verify((2, 3, 4, 5))
verify((32, 64, 80, 64))
verify((8, 6, 5))
verify((8, 7, 6, 5, 4))


def test_forward_layer_norm():
def verify(shape, axis=-1):
x = np.random.uniform(size=shape).astype("float32")
@@ -926,6 +946,7 @@ def verify(data_shape, kernel_size, stride, pad, num_filter):
test_forward_sequence_mask()
test_forward_contrib_div_sqrt_dim()
test_forward_batch_norm()
test_forward_instance_norm()
test_forward_layer_norm()
test_forward_one_hot()
test_forward_convolution()
45 changes: 45 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
@@ -416,6 +416,50 @@ def test_lrn():
verify_lrn((5, 5, 5, 5), 3, 'float32')
verify_lrn((5, 5, 5, 5), 3, 'float32', alpha=0.0002, beta=0.5, bias=2.0)


def verify_instance_norm(shape, axis=1):

def _get_python_instance_norm(x, gamma, beta, epsilon=1e-5):
dims_x = len(x.shape)
axis = tuple(range(2, dims_x))
mean = np.mean(x, axis=axis, keepdims=True)
var = np.var(x, axis=axis, keepdims=True)
dim_ones = (1,) * (dims_x - 2)
gamma = gamma.reshape(-1, *dim_ones)
beta = beta.reshape(-1, *dim_ones)
return gamma * (x - mean) / np.sqrt(var + epsilon) + beta

x = np.random.randn(*shape).astype(np.float32)
gamma = np.random.randn(shape[1]).astype(np.float32)
beta = np.random.randn(shape[1]).astype(np.float32)
epsilon = 1e-5
y = _get_python_instance_norm(x, gamma, beta, epsilon).astype(np.float32)

node = onnx.helper.make_node(
'InstanceNormalization',
inputs=['x', 'gamma', 'beta'],
outputs=['y'],
epsilon=epsilon,
)
graph = helper.make_graph([node],
"instance_norm_test",
inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(shape)),
helper.make_tensor_value_info("gamma", TensorProto.FLOAT, (shape[1],)),
helper.make_tensor_value_info("beta", TensorProto.FLOAT, (shape[1],))],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(shape))])
model = helper.make_model(graph, producer_name='instance_norm_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [x, gamma, beta], target, ctx, shape, 'float32')
tvm.testing.assert_allclose(y, tvm_out, rtol=1e-5, atol=1e-5)


def test_instance_norm():
verify_instance_norm((2, 3, 4, 5))
verify_instance_norm((32, 64, 80, 64))
verify_instance_norm((8, 6, 5))
verify_instance_norm((8, 7, 6, 5, 4))


def _test_upsample_nearest():
scale = 2
in_shape = (1, 1, 3, 3)
@@ -1258,6 +1302,7 @@ def test_erf():
test_matmul()
test_gather()
test_lrn()
test_instance_norm()
test_upsample()
test_forward_min()
test_forward_max()