Skip to content

Commit

Permalink
[Relay][Op] Add instance norm op (#4004)
Browse files Browse the repository at this point in the history
* [Relay][Op] Add instance norm op

* mend

[Relay][Op] Add instance norm op
  • Loading branch information
bindog authored and vinx13 committed Oct 3, 2019
1 parent 36201fe commit 7d911f4
Show file tree
Hide file tree
Showing 8 changed files with 286 additions and 2 deletions.
23 changes: 23 additions & 0 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,29 @@ struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
}; // struct BatchNormAttrs


/*! \brief Attributes used in instance_norm operator */
struct InstanceNormAttrs : public tvm::AttrsNode<InstanceNormAttrs> {
int axis;
double epsilon;
bool center;
bool scale;

TVM_DECLARE_ATTRS(InstanceNormAttrs, "relay.attrs.InstanceNormAttrs") {
TVM_ATTR_FIELD(axis)
.describe("Specify which shape axis denotes the channel.")
.set_default(1);
TVM_ATTR_FIELD(epsilon)
.describe("Small float added to variance to avoid dividing by zero")
.set_default(1e-5);
TVM_ATTR_FIELD(center).set_default(true)
.describe("If true, add offset of beta to normalized tensor; "
"otherwise, beta is ignored.");
TVM_ATTR_FIELD(scale).set_default(true)
.describe("If true, multiply by gamma; otherwise, gamma is ignored.");
}
}; // struct InstanceNormAttrs


/*! \brief Attributes used in layer_norm operator */
struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> {
int axis;
Expand Down
9 changes: 9 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,14 @@ def _mx_batch_norm(inputs, attrs):
return _op.nn.batch_norm(*inputs, **new_attrs)


def _mx_instance_norm(inputs, attrs):
assert len(inputs) == 3
new_attrs = {}
new_attrs["axis"] = attrs.get_int("axis", 1)
new_attrs["epsilon"] = attrs.get_float("eps", 1e-5)
return _op.nn.instance_norm(*inputs, **new_attrs)


def _mx_layer_norm(inputs, attrs):
assert len(inputs) == 3
if attrs.get_bool("output_mean_var", False):
Expand Down Expand Up @@ -1133,6 +1141,7 @@ def _mx_one_hot(inputs, attrs):
"Dropout" : _mx_dropout,
"BatchNorm" : _mx_batch_norm,
"BatchNorm_v1" : _mx_batch_norm,
"InstanceNorm" : _mx_instance_norm,
"LayerNorm" : _mx_layer_norm,
"LRN" : _mx_lrn,
"L2Normalization" : _mx_l2_normalize,
Expand Down
11 changes: 10 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,15 @@ def _impl_v1(cls, inputs, attr, params):
return out[0]


class InstanceNorm(OnnxOpConverter):
""" Operator converter for BatchNorm.
"""

@classmethod
def _impl_v1(cls, inputs, attr, params):
return AttrCvt(op_name='instance_norm')(inputs, attr, params)


class Conv(OnnxOpConverter):
""" Operator converter for Conv.
"""
Expand Down Expand Up @@ -999,7 +1008,7 @@ def _get_convert_map(opset):
'GlobalAveragePool': Renamer('global_avg_pool2d'),
'GlobalMaxPool': Renamer('global_max_pool2d'),
'BatchNormalization': BatchNorm.get_converter(opset),
# 'InstanceNormalization'
'InstanceNormalization': InstanceNorm.get_converter(opset),
# 'LpNormalization'
'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
'Flatten': Flatten.get_converter(opset),
Expand Down
69 changes: 68 additions & 1 deletion python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,73 @@ def batch_norm(data,
return TupleWrapper(result, 3)


def instance_norm(data,
gamma,
beta,
axis=1,
epsilon=1e-5,
center=True,
scale=True):
r"""
Instance Normalization (Ulyanov and et al., 2016)
Applies instance normalization to the n-dimensional input array.
.. math::
out = \frac{data - mean(data)}{\sqrt{var(data)+\epsilon}}
* gamma + beta
The instance normalization is similar to batch normalization, but unlike
batch normalization, the mean and var are calculated per-dimension
separately for each object(instance) in a mini-batch, not over a batch.
And the same normalization is applied both at test and train time.
Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
have shape *(k,)*.
The parameter ``axis`` specifies which axis of the input shape denotes
the 'channel'. The default is 1. Specifying -1 sets the channel axis
to be the last item in the input shape.
.. note::
This operator can be optimized away for inference.
Parameters
----------
data : tvm.relay.Expr
Input to which instance_norm will be applied.
gamma : tvm.relay.Expr
The gamma scale factor.
beta : tvm.relay.Expr
The beta offset factor.
axis : int, optional, default=1
Specify along which shape axis the channel is specified.
epsilon : double, optional, default=1e-5
Small float added to variance to avoid dividing by zero.
center : boolean, optional, default=True
If True, add offset of beta to normalized tensor, If False,
beta is ignored.
scale : boolean, optional, default=True
If True, multiply by gamma. If False, gamma is not used.
Returns
-------
result : tvm.relay.Expr
The normalized data.
.. _`Instance Normalization: The Missing Ingredient for Fast Stylization`:
https://arxiv.org/abs/1607.08022
"""
return _make.instance_norm(data, gamma, beta, axis, epsilon, center, scale)


def layer_norm(data,
gamma,
beta,
Expand Down Expand Up @@ -964,7 +1031,7 @@ def layer_norm(data,
Parameters
----------
data : tvm.relay.Expr
Input to which batch_norm will be applied.
Input to which layer_norm will be applied.
gamma : tvm.relay.Expr
The gamma scale factor.
Expand Down
70 changes: 70 additions & 0 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,76 @@ axis to be the last item in the input shape.
.add_type_rel("BatchNorm", BatchNormRel);


// instance_norm
TVM_REGISTER_NODE_TYPE(InstanceNormAttrs);

bool InstanceNormRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
const InstanceNormAttrs* param = attrs.as<InstanceNormAttrs>();
int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size();
CHECK(axis >= 0 && axis < (int)data->shape.size());
reporter->Assign(types[1], TensorTypeNode::make({data->shape[axis]}, data->dtype));
reporter->Assign(types[2], TensorTypeNode::make({data->shape[axis]}, data->dtype));
reporter->Assign(types[3], TensorTypeNode::make(data->shape, data->dtype));

return true;
}

Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon,
bool center, bool scale) {
auto attrs = make_node<InstanceNormAttrs>();
attrs->axis = axis;
attrs->epsilon = epsilon;
attrs->center = center;
attrs->scale = scale;
static const Op& op = Op::Get("nn.instance_norm");
return CallNode::make(op, {data, gamma, beta}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op.nn._make.instance_norm")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 7>(MakeInstanceNorm, args, rv);
});

RELAY_REGISTER_OP("nn.instance_norm")
.describe(R"code(Instance Normalization (Ulyanov and et al., 2016)
Applies instance normalization to the n-dimensional input array.
.. math::
out = \frac{data - mean(data)}{\sqrt{var(data)+\epsilon}}
* gamma + beta
The instance normalization is similar to batch normalization, but unlike
batch normalization, the mean and var are calculated per-dimension
separately for each object(instance) in a mini-batch, not over a batch.
And the same normalization is applied both at test and train time.
Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
have shape *(k,)*.
The parameter ``axis`` specifies which axis of the input shape denotes
the 'channel'. The default is 1. Specifying -1 sets the channel axis
to be the last item in the input shape.
.. note::
This operator can be optimized away for inference.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.InstanceNormAttrs")
.set_num_inputs(3)
.add_argument("data", "Tensor", "Input to which instance_norm will be applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
.set_support_level(1)
.add_type_rel("InstanceNorm", InstanceNormRel);


// layer_norm
TVM_REGISTER_NODE_TYPE(LayerNormAttrs);

Expand Down
40 changes: 40 additions & 0 deletions src/relay/pass/simplify_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,41 @@ Expr LayerNormToInferUnpack(const Attrs attrs,
return out;
}


Expr InstanceNormToInferUnpack(const Attrs attrs,
Expr data,
Expr gamma,
Expr beta,
Type tdata) {
auto ttype = tdata.as<TensorTypeNode>();
CHECK(ttype);
const auto param = attrs.as<InstanceNormAttrs>();
CHECK(param);

int ndim = ttype->shape.size();
int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
Array<Integer> reduced_axes;
for (int i = 1; i < ndim; ++i) {
if (i != axis)
reduced_axes.push_back(i);
}

Expr epsilon = MakeConstantScalar(Float(32), static_cast<float>(param->epsilon));
Expr mean = Mean(data, reduced_axes, true, false);
Expr var = Variance(data, mean, reduced_axes, true, false);
Expr denom = Sqrt(Add(var, epsilon));
Expr out = Divide(Subtract(data, mean), denom);

if (param->scale) {
out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis}));
}
if (param->center) {
out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis}));
}
return out;
}


class InferenceSimplifier : public ExprMutator {
public:
Expr VisitExpr_(const TupleGetItemNode* n) final {
Expand All @@ -116,6 +151,7 @@ class InferenceSimplifier : public ExprMutator {

Expr VisitExpr_(const CallNode* n) {
static const Op& batch_norm = Op::Get("nn.batch_norm");
static const Op& instance_norm = Op::Get("nn.instance_norm");
static const Op& layer_norm = Op::Get("nn.layer_norm");
auto new_n = ExprMutator::VisitExpr_(n);
if (n->op.same_as(batch_norm)) {
Expand All @@ -124,6 +160,10 @@ class InferenceSimplifier : public ExprMutator {
const auto* call = new_n.as<CallNode>();
return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1],
call->args[2], n->args[0]->checked_type());
} else if (n->op.same_as(instance_norm)) {
const auto* call = new_n.as<CallNode>();
return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1],
call->args[2], n->args[0]->checked_type());
}
return new_n;
}
Expand Down
21 changes: 21 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,26 @@ def verify(shape, axis=1, fix_gamma=False):
verify((2, 3, 4, 5), fix_gamma=True)


def test_forward_instance_norm():
def verify(shape, axis=1, epsilon=1e-5):
x = np.random.uniform(size=shape).astype("float32")
gamma = np.random.uniform(size=(shape[axis])).astype("float32")
beta = np.random.uniform(size=(shape[axis])).astype("float32")
ref_res = mx.nd.InstanceNorm(mx.nd.array(x), mx.nd.array(gamma), mx.nd.array(beta), epsilon)
mx_sym = mx.sym.InstanceNorm(mx.sym.var("x"), mx.sym.var("gamma"), mx.sym.var("beta"), epsilon)
shape_dict = {"x": x.shape, "gamma": gamma.shape, "beta": beta.shape}
mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x, gamma, beta)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5)
verify((2, 3, 4, 5))
verify((32, 64, 80, 64))
verify((8, 6, 5))
verify((8, 7, 6, 5, 4))


def test_forward_layer_norm():
def verify(shape, axis=-1):
x = np.random.uniform(size=shape).astype("float32")
Expand Down Expand Up @@ -938,6 +958,7 @@ def verify(data_shape, kernel_size, stride, pad, num_filter):
test_forward_sequence_mask()
test_forward_contrib_div_sqrt_dim()
test_forward_batch_norm()
test_forward_instance_norm()
test_forward_layer_norm()
test_forward_one_hot()
test_forward_convolution()
Expand Down
45 changes: 45 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,50 @@ def test_lrn():
verify_lrn((5, 5, 5, 5), 3, 'float32')
verify_lrn((5, 5, 5, 5), 3, 'float32', alpha=0.0002, beta=0.5, bias=2.0)


def verify_instance_norm(shape, axis=1):

def _get_python_instance_norm(x, gamma, beta, epsilon=1e-5):
dims_x = len(x.shape)
axis = tuple(range(2, dims_x))
mean = np.mean(x, axis=axis, keepdims=True)
var = np.var(x, axis=axis, keepdims=True)
dim_ones = (1,) * (dims_x - 2)
gamma = gamma.reshape(-1, *dim_ones)
beta = beta.reshape(-1, *dim_ones)
return gamma * (x - mean) / np.sqrt(var + epsilon) + beta

x = np.random.randn(*shape).astype(np.float32)
gamma = np.random.randn(shape[1]).astype(np.float32)
beta = np.random.randn(shape[1]).astype(np.float32)
epsilon = 1e-5
y = _get_python_instance_norm(x, gamma, beta, epsilon).astype(np.float32)

node = onnx.helper.make_node(
'InstanceNormalization',
inputs=['x', 'gamma', 'beta'],
outputs=['y'],
epsilon=epsilon,
)
graph = helper.make_graph([node],
"instance_norm_test",
inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(shape)),
helper.make_tensor_value_info("gamma", TensorProto.FLOAT, (shape[1],)),
helper.make_tensor_value_info("beta", TensorProto.FLOAT, (shape[1],))],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(shape))])
model = helper.make_model(graph, producer_name='instance_norm_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [x, gamma, beta], target, ctx, shape, 'float32')
tvm.testing.assert_allclose(y, tvm_out, rtol=1e-5, atol=1e-5)


def test_instance_norm():
verify_instance_norm((2, 3, 4, 5))
verify_instance_norm((32, 64, 80, 64))
verify_instance_norm((8, 6, 5))
verify_instance_norm((8, 7, 6, 5, 4))


def _test_upsample_nearest():
scale = 2
in_shape = (1, 1, 3, 3)
Expand Down Expand Up @@ -1270,6 +1314,7 @@ def test_erf():
test_matmul()
test_gather()
test_lrn()
test_instance_norm()
test_upsample()
test_forward_min()
test_forward_max()
Expand Down

0 comments on commit 7d911f4

Please sign in to comment.