diff --git a/test/legacy_test/test_lookup_table_v2_bf16_op.py b/test/legacy_test/test_lookup_table_v2_bf16_op.py index 44a2f1881b086..ff9fcebef4f65 100644 --- a/test/legacy_test/test_lookup_table_v2_bf16_op.py +++ b/test/legacy_test/test_lookup_table_v2_bf16_op.py @@ -15,14 +15,8 @@ import unittest import numpy as np +import test_lookup_table_bf16_op from op_test import convert_uint16_to_float -from test_lookup_table_bf16_op import ( - TestLookupTableBF16Op, - TestLookupTableBF16OpIds4D, - TestLookupTableBF16OpWIsSelectedRows, - TestLookupTableBF16OpWIsSelectedRows4DIds, - _lookup, -) import paddle from paddle import base @@ -30,7 +24,7 @@ from paddle.pir_utils import test_with_pir_api -class TestLookupTableV2BF16Op(TestLookupTableBF16Op): +class TestLookupTableV2BF16Op(test_lookup_table_bf16_op.TestLookupTableBF16Op): def init_test(self): self.op_type = "lookup_table_v2" self.python_api = paddle.nn.functional.embedding @@ -38,7 +32,9 @@ def init_test(self): self.mkldnn_data_type = "bfloat16" -class TestLookupTableV2BF16OpIds4D(TestLookupTableBF16OpIds4D): +class TestLookupTableV2BF16OpIds4D( + test_lookup_table_bf16_op.TestLookupTableBF16OpIds4D +): def init_test(self): self.op_type = "lookup_table_v2" self.python_api = paddle.nn.functional.embedding @@ -47,7 +43,7 @@ def init_test(self): class TestLookupTableV2BF16OpWIsSelectedRows( - TestLookupTableBF16OpWIsSelectedRows + test_lookup_table_bf16_op.TestLookupTableBF16OpWIsSelectedRows ): def init_test(self): self.op_type = "lookup_table_v2" @@ -56,7 +52,7 @@ def init_test(self): class TestLookupTableV2BF16OpWIsSelectedRows4DIds( - TestLookupTableBF16OpWIsSelectedRows4DIds + test_lookup_table_bf16_op.TestLookupTableBF16OpWIsSelectedRows4DIds ): def init_test(self): self.op_type = "lookup_table_v2" @@ -134,7 +130,9 @@ def test_embedding_weights(self): @test_with_pir_api def test_lookup_results(self): lookup_result = convert_uint16_to_float(self.result[1]) - lookup_ref = _lookup(self.w_fp32, self.ids, self.flat_ids, self.op_type) + lookup_ref = test_lookup_table_bf16_op._lookup( + self.w_fp32, self.ids, self.flat_ids, self.op_type + ) np.testing.assert_array_equal(lookup_result, lookup_ref) diff --git a/test/white_list/pir_op_test_white_list b/test/white_list/pir_op_test_white_list index 9e4de5ccffcfc..8082c48194a8f 100644 --- a/test/white_list/pir_op_test_white_list +++ b/test/white_list/pir_op_test_white_list @@ -184,6 +184,7 @@ test_logcumsumexp_op test_logit_op test_logspace test_logsumexp +test_lookup_table_v2_bf16_op test_lookup_table_v2_op test_lookup_table_v2_op_static_build test_lrn_mkldnn_op