Skip to content

Commit

Permalink
mend
Browse files Browse the repository at this point in the history
[Relay][Op] Add instance norm op
  • Loading branch information
bindog committed Sep 26, 2019
1 parent 79e392a commit c298b60
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
9 changes: 9 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,14 @@ def _mx_batch_norm(inputs, attrs):
return _op.nn.batch_norm(*inputs, **new_attrs)


def _mx_instance_norm(inputs, attrs):
assert len(inputs) == 3
new_attrs = {}
new_attrs["axis"] = attrs.get_int("axis", 1)
new_attrs["epsilon"] = attrs.get_float("eps", 1e-5)
return _op.nn.instance_norm(*inputs, **new_attrs)


def _mx_layer_norm(inputs, attrs):
assert len(inputs) == 3
if attrs.get_bool("output_mean_var", False):
Expand Down Expand Up @@ -1133,6 +1141,7 @@ def _mx_one_hot(inputs, attrs):
"Dropout" : _mx_dropout,
"BatchNorm" : _mx_batch_norm,
"BatchNorm_v1" : _mx_batch_norm,
"InstanceNorm" : _mx_instance_norm,
"LayerNorm" : _mx_layer_norm,
"LRN" : _mx_lrn,
"L2Normalization" : _mx_l2_normalize,
Expand Down
21 changes: 21 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,26 @@ def verify(shape, axis=1, fix_gamma=False):
verify((2, 3, 4, 5), fix_gamma=True)


def test_forward_instance_norm():
def verify(shape, axis=1, epsilon=1e-5):
x = np.random.uniform(size=shape).astype("float32")
gamma = np.random.uniform(size=(shape[axis])).astype("float32")
beta = np.random.uniform(size=(shape[axis])).astype("float32")
ref_res = mx.nd.InstanceNorm(mx.nd.array(x), mx.nd.array(gamma), mx.nd.array(beta), epsilon)
mx_sym = mx.sym.InstanceNorm(mx.sym.var("x"), mx.sym.var("gamma"), mx.sym.var("beta"), epsilon)
shape_dict = {"x": x.shape, "gamma": gamma.shape, "beta": beta.shape}
mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x, gamma, beta)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5)
verify((2, 3, 4, 5))
verify((32, 64, 80, 64))
verify((8, 6, 5))
verify((8, 7, 6, 5, 4))


def test_forward_layer_norm():
def verify(shape, axis=-1):
x = np.random.uniform(size=shape).astype("float32")
Expand Down Expand Up @@ -926,6 +946,7 @@ def verify(data_shape, kernel_size, stride, pad, num_filter):
test_forward_sequence_mask()
test_forward_contrib_div_sqrt_dim()
test_forward_batch_norm()
test_forward_instance_norm()
test_forward_layer_norm()
test_forward_one_hot()
test_forward_convolution()
Expand Down

0 comments on commit c298b60

Please sign in to comment.