diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 357bc2da62af..02840295b5ba 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -468,17 +468,18 @@ def _mx_roi_align(inputs, attrs): new_attrs["layout"] = "NCHW" return _op.vision.roi_align(inputs[0], inputs[1], **new_attrs) -def _mx_upsampling(inputs, attrs): +def _mx_resize(inputs, attrs): scale_height = attrs.get_float("scale_height", None) scale_width = attrs.get_float("scale_width", None) height = attrs.get_int("height", 1) width = attrs.get_int("width", 1) + shape = ir_pass.infer_type(inputs[0]).checked_type.shape if scale_height is not None: - height = scale_height * inputs[0].shape[2] + height = (scale_height * shape[2]).astype("int32") if scale_width is not None: - width = scale_width * inputs[0].shape[3] - size = (inputs[0].shape[0], inputs[0].shape[1], height, width) - return _op.image.resize(inputs[0], size) + width = (scale_width * shape[3]).astype("int32") + size = (height, width) + return _op.image.resize(inputs[0], size, align_corners=True) def _mx_roi_pooling(inputs, attrs): new_attrs = {} @@ -770,7 +771,7 @@ def _mx_deformable_convolution(inputs, attrs): "SoftmaxActivation" : _mx_softmax_activation, "smooth_l1" : _mx_smooth_l1, # vision - "_contrib_BilinearResize2D" : _mx_upsampling, + "_contrib_BilinearResize2D" : _mx_resize, "_contrib_MultiBoxPrior" : _mx_multibox_prior, "_contrib_MultiBoxDetection" : _mx_multibox_detection, "_contrib_ROIAlign" : _mx_roi_align, diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 797f8c025dcd..acf5a75a552e 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -505,6 +505,12 @@ def verify(xshape, yshape, y_data): verify((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]]) verify((2, 2, 2), (2, 2), [[0, 1], [1, 0]]) +def test_forward_bilinear_resize(): + # add tests including scale_height and scale_width when mxnet is updated to version 1.5 + data = mx.sym.var('data') + mx_sym = mx.sym.contrib.BilinearResize2D(data, height=5, width=10) + verify_mxnet_frontend_impl(mx_sym, (1, 2, 3, 4), (1, 2, 5, 10)) + if __name__ == '__main__': test_forward_mlp() @@ -543,3 +549,4 @@ def verify(xshape, yshape, y_data): test_forward_smooth_l1() test_forward_take() test_forward_gather_nd() + test_forward_bilinear_resize()