[Dy2St] Make bool(Value) always throw error (PaddlePaddle#60902)
---------

Co-authored-by: zoooo0820 <zoooo0820@qq.com>
2 people authored and eee4017 committed Jan 30, 2024
1 parent 2efef93 commit 2b0cec7
Showing 8 changed files with 35 additions and 15 deletions.
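The behavioral core of the commit is the `__bool__` patch on `paddle.pir.Value` (python/paddle/pir/math_op_patch.py below); every other file either stops relying on Value truthiness or skips data-dependent checks for symbolic Values. A minimal sketch of what user code sees after this change, assuming a standard PIR static-graph setup via `paddle.pir_utils.IrGuard` (the same guard the new test uses):

    import paddle

    with paddle.pir_utils.IrGuard():
        x = paddle.static.data('x', shape=[1], dtype='float32')
        try:
            bool(x)  # also triggered implicitly by `if x:` or `x and y`
        except TypeError as e:
            print(e)  # bool(Value) is not supported in static graph mode ...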
16 changes: 10 additions & 6 deletions python/paddle/base/variable_index.py
@@ -82,12 +82,10 @@ def replace_none(item):
     return new_item, none_axes
 
 
-def is_integer_or_scalar_tensor(ele):
+def is_scalar_tensor(ele):
     from .framework import Variable
 
-    if type(ele) is int:
-        return True
-    elif isinstance(ele, Variable):
+    if isinstance(ele, Variable):
         if len(ele.shape) == 0 and ele.dtype != paddle.bool:
             return True
     elif isinstance(ele, paddle.pir.Value):
@@ -284,10 +282,9 @@ def parse_index(x, indices):
     dim = 0
     for i, slice_item in enumerate(indices):
         start, end, step = None, None, None
-        if is_integer_or_scalar_tensor(slice_item):
+        if type(slice_item) is int:
             if (
                 not is_tensor_array
-                and isinstance(slice_item, int)
                 and x.shape[dim] is not None
                 and x.shape[dim] >= 0
                 and slice_item >= x.shape[dim]
@@ -308,6 +305,13 @@ def parse_index(x, indices):
            step = 1
            end = slice_item + 1 if slice_item != -1 else MAX_INTEGER
            dim += 1
+        elif is_scalar_tensor(slice_item):
+            # not calculate result to reduce call times for slice OP.
+            decrease_axes.append(dim)
+            start = slice_item
+            step = 1
+            end = slice_item + 1
+            dim += 1
         elif isinstance(slice_item, bool):
             # single bool is advanced-indexing
             none_axes.append(dim)
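As a hedged illustration of the split above (not part of the commit): a plain Python int takes the `type(slice_item) is int` branch with its eager bounds check, while a 0-D integer tensor takes the new `is_scalar_tensor` branch and stays symbolic so the slice OP resolves it:

    import paddle

    t = paddle.arange(12).reshape([3, 4])
    i = paddle.full([], 1, dtype='int64')  # a 0-D scalar tensor

    print(t[1])  # Python int index: bounds-checked in parse_index
    print(t[i])  # scalar-tensor index: deferred to the slice OP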
4 changes: 2 additions & 2 deletions python/paddle/decomposition/decomp.py
@@ -84,7 +84,7 @@ def _prepare_python_api_arguments(op):
     inputs = []
     for x in op.operands():
         input = x.source()
-        if input and input.initialized():
+        if input.initialized():
             prev_op = input.get_defining_op()
             if (
                 isinstance(prev_op, Operation)
@@ -111,7 +111,7 @@ def _check_prim_dynamic(op):
     inputs = []
     for x in op.operands():
         input = x.source()
-        if input and input.initialized():
+        if input.initialized():
             prev_op = input.get_defining_op()
             if (
                 isinstance(prev_op, Operation)
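The decomp.py edits are forced by the new `__bool__`: in `input and input.initialized()`, the `and` operator evaluates `bool(input)` first, which now raises for a pir Value. A self-contained sketch with a stand-in class (hypothetical, mirroring the patched behavior):

    class FakeValue:
        def __bool__(self):
            raise TypeError("bool(Value) is not supported in static graph mode.")

        def initialized(self):
            return True

    v = FakeValue()
    print(v.initialized())         # fine: no truthiness involved
    try:
        _ = v and v.initialized()  # `and` calls bool(v) first -> TypeError
    except TypeError as e:
        print(e)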
8 changes: 7 additions & 1 deletion python/paddle/nn/functional/loss.py
@@ -3733,7 +3733,13 @@ def triplet_margin_with_distance_loss(
         swap_dist = distance_function(positive, negative)
         negative_dist = paddle.minimum(negative_dist, swap_dist)
 
-    if not paddle.all(positive_dist > 0) or not paddle.all(negative_dist > 0):
+    if (
+        not isinstance(positive_dist, paddle.pir.Value)
+        and not paddle.all(positive_dist > 0)
+    ) or (
+        not isinstance(negative_dist, paddle.pir.Value)
+        and not paddle.all(negative_dist > 0)
+    ):
         raise ValueError(
             "The positive distance or negative distance should be greater than 0, "
             "The distance functions should be checked."
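The guard above follows a general rule: `paddle.all(dist > 0)` is data-dependent and has no concrete answer while a PIR graph is being built, so the check only runs for eager tensors. A hedged sketch of the idiom (the `_check_positive` helper is illustrative, not part of the commit):

    import paddle

    def _check_positive(dist):
        # A symbolic pir Value has no data at graph-build time; skip it and
        # let the runtime see the real numbers.
        if isinstance(dist, paddle.pir.Value):
            return
        if not paddle.all(dist > 0):
            raise ValueError("distance should be greater than 0")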
6 changes: 3 additions & 3 deletions python/paddle/optimizer/adamw.py
@@ -175,11 +175,11 @@ def __init__(
         assert beta1 is not None
         assert beta2 is not None
         assert epsilon is not None
-        if not 0 <= beta1 < 1:
+        if not isinstance(beta1, Value) and not 0 <= beta1 < 1:
             raise ValueError("Invaild value of beta1, expect beta1 in [0,1).")
-        if not 0 <= beta2 < 1:
+        if not isinstance(beta2, Value) and not 0 <= beta2 < 1:
             raise ValueError("Invaild value of beta2, expect beta2 in [0,1).")
-        if not 0 <= epsilon:
+        if not isinstance(epsilon, Value) and not 0 <= epsilon:
             raise ValueError("Invaild value of epsilon, expect epsilon >= 0.")
         if not isinstance(weight_decay, float) and not isinstance(
             weight_decay, (framework.Variable, Value)
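The same idiom applies to hyperparameters: under PIR, `beta1`, `beta2`, and `epsilon` may be symbolic `Value`s rather than Python floats, so the numeric range check must only run on concrete numbers. A standalone sketch of the guard (the `_validate_beta` helper is illustrative):

    from paddle.pir import Value

    def _validate_beta(name, beta):
        # Range-check concrete numbers; a symbolic Value is unknowable here
        # and is validated, if at all, when the program runs.
        if not isinstance(beta, Value) and not 0 <= beta < 1:
            raise ValueError(f"Invalid value of {name}, expect {name} in [0,1).")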
8 changes: 8 additions & 0 deletions python/paddle/pir/math_op_patch.py
@@ -448,6 +448,13 @@ def _float_(self):
             "2. If you want to run it in full graph mode, you need use Value directly, and do not use float(Value)."
         )
 
+    def _bool_(self):
+        raise TypeError(
+            "bool(Value) is not supported in static graph mode. If you are using @to_static, you can try this:\n"
+            "1. If you want to get the value of Value, you can switch to non-fullgraph mode by setting @to_static(full_graph=True).\n"
+            "2. If you want to run it in full graph mode, you need use Value.astype(paddle.bool), and do not use bool(Value)."
+        )
+
     def clone(self):
         """
         Returns a new static Value, which is the clone of the original static
@@ -669,6 +676,7 @@ def value_hash(self):
         ),
         ('__float__', _float_),
         ('__int__', _int_),
+        ('__bool__', _bool_),
     ]
 
     global _already_patch_value
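For context, the `(name, method)` pairs at the bottom of math_op_patch.py are bound onto the `Value` class when the module initializes, which is how `bool(x)` reaches `_bool_`. A simplified, self-contained sketch of that patching mechanism (class and list names are illustrative, not Paddle's):

    class DemoValue:
        pass

    def _bool_(self):
        raise TypeError("bool(Value) is not supported in static graph mode.")

    # Bind dunder methods onto the class, as the surrounding patch function
    # does for Value.
    for name, method in [('__bool__', _bool_)]:
        setattr(DemoValue, name, method)

    bool(DemoValue())  # raises TypeError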
4 changes: 2 additions & 2 deletions python/paddle/tensor/manipulation.py
@@ -5802,7 +5802,7 @@ def take_along_axis(arr, indices, axis, broadcast=True):
         )
 
     axis_max_size = arr.shape[axis]
-    if not (indices < axis_max_size).all():
+    if in_dynamic_mode() and not (indices < axis_max_size).all():
         raise RuntimeError(
             "one of element of indices is out of bounds for dimension {} with size {}".format(
                 axis, axis_max_size
@@ -5958,7 +5958,7 @@ def put_along_axis(
     if elements == 1:  # paddle.pir.Value has no attribute 'size'
         values = paddle.broadcast_to(values, indices.shape)
     axis_max_size = arr.shape[axis]
-    if not (indices < axis_max_size).all():
+    if in_dynamic_mode() and not (indices < axis_max_size).all():
         raise RuntimeError(
             "one of element of indices is out of bounds for dimension {} with size {}".format(
                 axis, axis_max_size
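The `in_dynamic_mode()` guard mirrors the rest of the commit: eager tensors can be compared and reduced immediately, while a symbolic `indices` Value would make `(indices < axis_max_size).all()` itself symbolic, so bounds enforcement is left to the kernel. An isolated sketch of the guard (assuming the public `paddle.in_dynamic_mode` API):

    import paddle

    def _check_indices(indices, axis, axis_max_size):
        # Evaluate the bounds check only when indices hold concrete data.
        if paddle.in_dynamic_mode() and not (indices < axis_max_size).all():
            raise RuntimeError(
                f"one of element of indices is out of bounds for "
                f"dimension {axis} with size {axis_max_size}"
            )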
2 changes: 1 addition & 1 deletion test/dygraph_to_static/test_tensor_shape.py
@@ -748,7 +748,7 @@ def _set_test_func(self):
         self.dygraph_func = dyfunc_with_for_1
 
     def _set_expected_op_num(self):
-        self.expected_op_num = 29
+        self.expected_op_num = 27
         self.expected_shape_op_num = 2
         self.expected_slice_op_num = 3
2 changes: 2 additions & 0 deletions test/legacy_test/test_math_op_patch_pir.py
@@ -533,6 +533,8 @@ def test_builtin_type_conversion(self):
             int(x)
         with self.assertRaises(TypeError):
             float(x)
+        with self.assertRaises(TypeError):
+            bool(x)
 
     def test_math_exists(self):
         with paddle.pir_utils.IrGuard():
