Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2St] Make bool(Value) always throw error #60902

Merged
16 changes: 10 additions & 6 deletions python/paddle/base/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,10 @@ def replace_none(item):
return new_item, none_axes


def is_integer_or_scalar_tensor(ele):
def is_scalar_tensor(ele):
from .framework import Variable

if type(ele) is int:
return True
elif isinstance(ele, Variable):
if isinstance(ele, Variable):
if len(ele.shape) == 0 and ele.dtype != paddle.bool:
return True
elif isinstance(ele, paddle.pir.Value):
Expand Down Expand Up @@ -285,10 +283,9 @@ def parse_index(x, indices):
dim = 0
for i, slice_item in enumerate(indices):
start, end, step = None, None, None
if is_integer_or_scalar_tensor(slice_item):
if type(slice_item) is int:
if (
not is_tensor_array
and isinstance(slice_item, int)
and x.shape[dim] is not None
and x.shape[dim] >= 0
and slice_item >= x.shape[dim]
Expand All @@ -309,6 +306,13 @@ def parse_index(x, indices):
step = 1
end = slice_item + 1 if slice_item != -1 else MAX_INTEGER
dim += 1
elif is_scalar_tensor(slice_item):
# not calculate result to reduce call times for slice OP.
decrease_axes.append(dim)
start = slice_item
step = 1
end = slice_item + 1
dim += 1
elif isinstance(slice_item, bool):
# single bool is advanced-indexing
none_axes.append(dim)
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/decomposition/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _prepare_python_api_arguments(op):
inputs = []
for x in op.operands():
input = x.source()
if input and input.initialized():
if input.initialized():
prev_op = input.get_defining_op()
if (
isinstance(prev_op, Operation)
Expand All @@ -111,7 +111,7 @@ def _check_prim_dynamic(op):
inputs = []
for x in op.operands():
input = x.source()
if input and input.initialized():
if input.initialized():
prev_op = input.get_defining_op()
if (
isinstance(prev_op, Operation)
Expand Down
8 changes: 7 additions & 1 deletion python/paddle/nn/functional/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3733,7 +3733,13 @@ def triplet_margin_with_distance_loss(
swap_dist = distance_function(positive, negative)
negative_dist = paddle.minimum(negative_dist, swap_dist)

if not paddle.all(positive_dist > 0) or not paddle.all(negative_dist > 0):
if (
not isinstance(positive_dist, paddle.pir.Value)
and not paddle.all(positive_dist > 0)
) or (
not isinstance(negative_dist, paddle.pir.Value)
and not paddle.all(negative_dist > 0)
):
raise ValueError(
"The positive distance or negative distance should be greater than 0, "
"The distance functions should be checked."
Expand Down
6 changes: 3 additions & 3 deletions python/paddle/optimizer/adamw.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,11 @@ def __init__(
assert beta1 is not None
assert beta2 is not None
assert epsilon is not None
if not 0 <= beta1 < 1:
if not isinstance(beta1, Value) and not 0 <= beta1 < 1:
raise ValueError("Invaild value of beta1, expect beta1 in [0,1).")
if not 0 <= beta2 < 1:
if not isinstance(beta1, Value) and not 0 <= beta2 < 1:
raise ValueError("Invaild value of beta2, expect beta2 in [0,1).")
if not 0 <= epsilon:
if not isinstance(beta1, Value) and not 0 <= epsilon:
raise ValueError("Invaild value of epsilon, expect epsilon >= 0.")
if not isinstance(weight_decay, float) and not isinstance(
weight_decay, (framework.Variable, Value)
Expand Down
8 changes: 8 additions & 0 deletions python/paddle/pir/math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,13 @@ def _float_(self):
"2. If you want to run it in full graph mode, you need use Value directly, and do not use float(Value)."
)

def _bool_(self):
raise TypeError(
"bool(Value) is not supported in static graph mode. If you are using @to_static, you can try this:\n"
"1. If you want to get the value of Value, you can switch to non-fullgraph mode by setting @to_static(full_graph=True).\n"
"2. If you want to run it in full graph mode, you need use Value.astype(paddle.bool), and do not use bool(Value)."
)

def clone(self):
"""
Returns a new static Value, which is the clone of the original static
Expand Down Expand Up @@ -669,6 +676,7 @@ def value_hash(self):
),
('__float__', _float_),
('__int__', _int_),
('__bool__', _bool_),
]

global _already_patch_value
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5802,7 +5802,7 @@ def take_along_axis(arr, indices, axis, broadcast=True):
)

axis_max_size = arr.shape[axis]
if not (indices < axis_max_size).all():
if in_dynamic_mode() and not (indices < axis_max_size).all():
raise RuntimeError(
"one of element of indices is out of bounds for dimension {} with size {}".format(
axis, axis_max_size
Expand Down Expand Up @@ -5958,7 +5958,7 @@ def put_along_axis(
if elements == 1: # paddle.pir.Value has no attribute 'size'
values = paddle.broadcast_to(values, indices.shape)
axis_max_size = arr.shape[axis]
if not (indices < axis_max_size).all():
if in_dynamic_mode() and not (indices < axis_max_size).all():
raise RuntimeError(
"one of element of indices is out of bounds for dimension {} with size {}".format(
axis, axis_max_size
Expand Down
2 changes: 1 addition & 1 deletion test/dygraph_to_static/test_tensor_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ def _set_test_func(self):
self.dygraph_func = dyfunc_with_for_1

def _set_expected_op_num(self):
self.expected_op_num = 29
self.expected_op_num = 27
self.expected_shape_op_num = 2
self.expected_slice_op_num = 3

Expand Down
2 changes: 2 additions & 0 deletions test/legacy_test/test_math_op_patch_pir.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,8 @@ def test_builtin_type_conversion(self):
int(x)
with self.assertRaises(TypeError):
float(x)
with self.assertRaises(TypeError):
bool(x)

def test_math_exists(self):
with paddle.pir_utils.IrGuard():
Expand Down