Skip to content

Commit

Permalink
fix test_lookup_table_v2_bf16_op (#60332)
Browse files Browse the repository at this point in the history
  • Loading branch information
xingmingyyj committed Jan 2, 2024
1 parent b56d140 commit 7041276
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
22 changes: 10 additions & 12 deletions test/legacy_test/test_lookup_table_v2_bf16_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,26 @@
import unittest

import numpy as np
import test_lookup_table_bf16_op
from op_test import convert_uint16_to_float
from test_lookup_table_bf16_op import (
TestLookupTableBF16Op,
TestLookupTableBF16OpIds4D,
TestLookupTableBF16OpWIsSelectedRows,
TestLookupTableBF16OpWIsSelectedRows4DIds,
_lookup,
)

import paddle
from paddle import base
from paddle.base import core
from paddle.pir_utils import test_with_pir_api


class TestLookupTableV2BF16Op(TestLookupTableBF16Op):
class TestLookupTableV2BF16Op(test_lookup_table_bf16_op.TestLookupTableBF16Op):
def init_test(self):
self.op_type = "lookup_table_v2"
self.python_api = paddle.nn.functional.embedding
self.ids_shape = 4
self.mkldnn_data_type = "bfloat16"


class TestLookupTableV2BF16OpIds4D(TestLookupTableBF16OpIds4D):
class TestLookupTableV2BF16OpIds4D(
test_lookup_table_bf16_op.TestLookupTableBF16OpIds4D
):
def init_test(self):
self.op_type = "lookup_table_v2"
self.python_api = paddle.nn.functional.embedding
Expand All @@ -47,7 +43,7 @@ def init_test(self):


class TestLookupTableV2BF16OpWIsSelectedRows(
TestLookupTableBF16OpWIsSelectedRows
test_lookup_table_bf16_op.TestLookupTableBF16OpWIsSelectedRows
):
def init_test(self):
self.op_type = "lookup_table_v2"
Expand All @@ -56,7 +52,7 @@ def init_test(self):


class TestLookupTableV2BF16OpWIsSelectedRows4DIds(
TestLookupTableBF16OpWIsSelectedRows4DIds
test_lookup_table_bf16_op.TestLookupTableBF16OpWIsSelectedRows4DIds
):
def init_test(self):
self.op_type = "lookup_table_v2"
Expand Down Expand Up @@ -134,7 +130,9 @@ def test_embedding_weights(self):
@test_with_pir_api
def test_lookup_results(self):
lookup_result = convert_uint16_to_float(self.result[1])
lookup_ref = _lookup(self.w_fp32, self.ids, self.flat_ids, self.op_type)
lookup_ref = test_lookup_table_bf16_op._lookup(
self.w_fp32, self.ids, self.flat_ids, self.op_type
)
np.testing.assert_array_equal(lookup_result, lookup_ref)


Expand Down
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_white_list
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ test_logcumsumexp_op
test_logit_op
test_logspace
test_logsumexp
test_lookup_table_v2_bf16_op
test_lookup_table_v2_op
test_lookup_table_v2_op_static_build
test_lrn_mkldnn_op
Expand Down

0 comments on commit 7041276

Please sign in to comment.