Skip to content

Commit

Permalink
[Dy2Stat]Filter UserWarings while @to_static (PaddlePaddle#45754)
Browse files Browse the repository at this point in the history
* [Dy2Stat]Filter UserWarings while @to_static

* only filter DeprecationWarning

* fix unittest
  • Loading branch information
Aurelius84 authored and Caozhou1995 committed Sep 8, 2022
1 parent 2190212 commit 5f9e0b3
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
3 changes: 2 additions & 1 deletion python/paddle/fluid/dygraph/dygraph_to_static/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,8 @@ def _inject_import_statements():
import_statements = [
"import paddle", "from paddle import Tensor",
"import paddle.fluid as fluid", "import paddle.jit.dy2static as _jst",
"from typing import *", "import numpy as np"
"from typing import *", "import numpy as np", "import warnings",
"warnings.filterwarnings('ignore', category=DeprecationWarning)"
]
return '\n'.join(import_statements) + '\n'

Expand Down
3 changes: 2 additions & 1 deletion python/paddle/fluid/layers/math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,8 @@ def __impl__(self, other_var):
"If your code works well in the older versions but crashes in this version, try to use "
"%s(X, Y, axis=0) instead of %s. This transitional warning will be dropped in the future."
% (file_name, line_num, EXPRESSION_MAP[method_name],
op_type, op_type, EXPRESSION_MAP[method_name]))
op_type, op_type, EXPRESSION_MAP[method_name]),
category=DeprecationWarning)
current_block(self).append_op(type=op_type,
inputs={
'X': [self],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def set_test_func(self):
self.func = simple_func

def set_static_lineno(self):
self.static_abs_lineno_list = [7, 8, 9]
self.static_abs_lineno_list = [9, 10, 11]

def set_dygraph_info(self):
self.line_num = 3
Expand Down Expand Up @@ -149,7 +149,7 @@ def set_test_func(self):
self.func = nested_func

def set_static_lineno(self):
self.static_abs_lineno_list = [7, 9, 10, 11, 12]
self.static_abs_lineno_list = [9, 11, 12, 13, 14]

def set_dygraph_info(self):
self.line_num = 5
Expand All @@ -174,7 +174,7 @@ def set_test_func(self):
self.func = decorated_func

def set_static_lineno(self):
self.static_abs_lineno_list = [7, 8]
self.static_abs_lineno_list = [9, 10]

def set_dygraph_info(self):
self.line_num = 2
Expand Down Expand Up @@ -208,7 +208,7 @@ def set_test_func(self):
self.func = decorated_func2

def set_static_lineno(self):
self.static_abs_lineno_list = [7, 8]
self.static_abs_lineno_list = [9, 10]

def set_dygraph_info(self):
self.line_num = 2
Expand Down

0 comments on commit 5f9e0b3

Please sign in to comment.