Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PIR OpTest Fix No.16】 fix test_lookup_table_v2_bf16_op #60332

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions test/legacy_test/test_lookup_table_v2_bf16_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,26 @@
import unittest

import numpy as np
import test_lookup_table_bf16_op
from op_test import convert_uint16_to_float
from test_lookup_table_bf16_op import (
TestLookupTableBF16Op,
TestLookupTableBF16OpIds4D,
TestLookupTableBF16OpWIsSelectedRows,
TestLookupTableBF16OpWIsSelectedRows4DIds,
_lookup,
)

import paddle
from paddle import base
from paddle.base import core
from paddle.pir_utils import test_with_pir_api


class TestLookupTableV2BF16Op(TestLookupTableBF16Op):
class TestLookupTableV2BF16Op(test_lookup_table_bf16_op.TestLookupTableBF16Op):
def init_test(self):
self.op_type = "lookup_table_v2"
self.python_api = paddle.nn.functional.embedding
self.ids_shape = 4
self.mkldnn_data_type = "bfloat16"


class TestLookupTableV2BF16OpIds4D(TestLookupTableBF16OpIds4D):
class TestLookupTableV2BF16OpIds4D(
test_lookup_table_bf16_op.TestLookupTableBF16OpIds4D
):
def init_test(self):
self.op_type = "lookup_table_v2"
self.python_api = paddle.nn.functional.embedding
Expand All @@ -47,7 +43,7 @@ def init_test(self):


class TestLookupTableV2BF16OpWIsSelectedRows(
TestLookupTableBF16OpWIsSelectedRows
test_lookup_table_bf16_op.TestLookupTableBF16OpWIsSelectedRows
):
def init_test(self):
self.op_type = "lookup_table_v2"
Expand All @@ -56,7 +52,7 @@ def init_test(self):


class TestLookupTableV2BF16OpWIsSelectedRows4DIds(
TestLookupTableBF16OpWIsSelectedRows4DIds
test_lookup_table_bf16_op.TestLookupTableBF16OpWIsSelectedRows4DIds
):
def init_test(self):
self.op_type = "lookup_table_v2"
Expand Down Expand Up @@ -134,7 +130,9 @@ def test_embedding_weights(self):
@test_with_pir_api
def test_lookup_results(self):
lookup_result = convert_uint16_to_float(self.result[1])
lookup_ref = _lookup(self.w_fp32, self.ids, self.flat_ids, self.op_type)
lookup_ref = test_lookup_table_bf16_op._lookup(
self.w_fp32, self.ids, self.flat_ids, self.op_type
)
np.testing.assert_array_equal(lookup_result, lookup_ref)


Expand Down
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_white_list
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ test_logcumsumexp_op
test_logit_op
test_logspace
test_logsumexp
test_lookup_table_v2_bf16_op
test_lookup_table_v2_op
test_lookup_table_v2_op_static_build
test_lrn_mkldnn_op
Expand Down