Skip to content

Commit

Permalink
deg2rad test passed (#60619)
Browse files Browse the repository at this point in the history
  • Loading branch information
changeyoung98 committed Jan 9, 2024
1 parent 47ecd81 commit 9982819
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions test/legacy_test/test_deg2rad.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import paddle
from paddle import base
from paddle.base import core
from paddle.pir_utils import test_with_pir_api

paddle.enable_static()

Expand All @@ -32,10 +33,11 @@ def setUp(self):
self.x_shape = [6]
self.out_np = np.deg2rad(self.x_np)

@test_with_pir_api
def test_static_graph(self):
startup_program = base.Program()
train_program = base.Program()
with base.program_guard(startup_program, train_program):
startup_program = paddle.static.Program()
train_program = paddle.static.Program()
with paddle.static.program_guard(startup_program, train_program):
x = paddle.static.data(
name='input', dtype=self.x_dtype, shape=self.x_shape
)
Expand All @@ -48,11 +50,12 @@ def test_static_graph(self):
)
exe = base.Executor(place)
res = exe.run(
base.default_main_program(),
feed={'input': self.x_np},
fetch_list=[out],
)
self.assertTrue((np.array(out[0]) == self.out_np).all())
np.testing.assert_allclose(
np.array(res[0]), self.out_np, rtol=1e-05
)

def test_dygraph(self):
paddle.disable_static()
Expand All @@ -79,3 +82,7 @@ def test_dygraph(self):
np.testing.assert_allclose(np.pi, result2.numpy(), rtol=1e-05)

paddle.enable_static()


if __name__ == '__main__':
unittest.main()

0 comments on commit 9982819

Please sign in to comment.