From ca6c9def7a8e16fa62d6debce7a14db7aec64c86 Mon Sep 17 00:00:00 2001 From: yaoxuefeng6 Date: Tue, 29 Sep 2020 00:18:21 +0800 Subject: [PATCH 1/3] fix bmm enforce equal batch --- python/paddle/tensor/linalg.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 15580b6618e6d..0b7ebee884cf7 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -850,6 +850,10 @@ def bmm(x, y, name=None): raise ValueError( "x's width must be equal with y's height. But received x's shape: {}, y's shape: {}". format(x_shape, y_shape)) + if x_shape[0] != y_shape[0]: + raise ValueError( + "x's batch (shape[0]) must be equal with y's batch (shape[0]). But received x's shape: {}, y's shape: {}". + format(x_shape, y_shape)) helper = LayerHelper('bmm', **locals()) if in_dygraph_mode(): return core.ops.bmm(x, y) From 394240a5ae65ff9397d52c579831c55987be2db9 Mon Sep 17 00:00:00 2001 From: yaoxuefeng6 Date: Tue, 29 Sep 2020 12:45:29 +0800 Subject: [PATCH 2/3] add ut --- python/paddle/fluid/tests/unittests/test_bmm_op.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_bmm_op.py b/python/paddle/fluid/tests/unittests/test_bmm_op.py index cb1b3ded53472..9d475c323991b 100644 --- a/python/paddle/fluid/tests/unittests/test_bmm_op.py +++ b/python/paddle/fluid/tests/unittests/test_bmm_op.py @@ -79,8 +79,10 @@ def test_api_error(self): y_data = np.arange(16, dtype='float32').reshape((2, 4, 2)) y_data_wrong1 = np.arange(16, dtype='float32').reshape((2, 2, 4)) y_data_wrong2 = np.arange(16, dtype='float32').reshape((2, 2, 2, 2)) + y_data_wrong3 = np.arange(24, dtype='float32').reshape((3, 2, 4)) self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong1) self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong2) + self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong3) if __name__ == "__main__": From 572e25f2551d04593e973aa6d8811265c0eeeed0 Mon Sep 17 00:00:00 2001 From: yaoxuefeng6 Date: Tue, 29 Sep 2020 17:58:49 +0800 Subject: [PATCH 3/3] fix ut --- python/paddle/fluid/tests/unittests/test_bmm_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_bmm_op.py b/python/paddle/fluid/tests/unittests/test_bmm_op.py index 9d475c323991b..a1c8266842087 100644 --- a/python/paddle/fluid/tests/unittests/test_bmm_op.py +++ b/python/paddle/fluid/tests/unittests/test_bmm_op.py @@ -79,7 +79,7 @@ def test_api_error(self): y_data = np.arange(16, dtype='float32').reshape((2, 4, 2)) y_data_wrong1 = np.arange(16, dtype='float32').reshape((2, 2, 4)) y_data_wrong2 = np.arange(16, dtype='float32').reshape((2, 2, 2, 2)) - y_data_wrong3 = np.arange(24, dtype='float32').reshape((3, 2, 4)) + y_data_wrong3 = np.arange(24, dtype='float32').reshape((3, 4, 2)) self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong1) self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong2) self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong3)