diff --git a/python/paddle/base/variable_index.py b/python/paddle/base/variable_index.py
index c1f6ea09900e25..49f1147b72cfb2 100644
--- a/python/paddle/base/variable_index.py
+++ b/python/paddle/base/variable_index.py
@@ -82,12 +82,10 @@ def replace_none(item):
     return new_item, none_axes


-def is_integer_or_scalar_tensor(ele):
+def is_scalar_tensor(ele):
     from .framework import Variable

-    if type(ele) is int:
-        return True
-    elif isinstance(ele, Variable):
+    if isinstance(ele, Variable):
         if len(ele.shape) == 0 and ele.dtype != paddle.bool:
             return True
     elif isinstance(ele, paddle.pir.Value):
@@ -284,10 +282,9 @@ def parse_index(x, indices):
     dim = 0
     for i, slice_item in enumerate(indices):
         start, end, step = None, None, None
-        if is_integer_or_scalar_tensor(slice_item):
+        if type(slice_item) is int:
             if (
                 not is_tensor_array
-                and isinstance(slice_item, int)
                 and x.shape[dim] is not None
                 and x.shape[dim] >= 0
                 and slice_item >= x.shape[dim]
@@ -308,6 +305,13 @@ def parse_index(x, indices):
                 step = 1
                 end = slice_item + 1 if slice_item != -1 else MAX_INTEGER
             dim += 1
+        elif is_scalar_tensor(slice_item):
+            # Do not evaluate the index eagerly; let the slice OP handle it.
+            decrease_axes.append(dim)
+            start = slice_item
+            step = 1
+            end = slice_item + 1
+            dim += 1
         elif isinstance(slice_item, bool):
             # single bool is advanced-indexing
             none_axes.append(dim)
diff --git a/python/paddle/decomposition/decomp.py b/python/paddle/decomposition/decomp.py
index ed158eae331b3b..4a8fa8bd439e57 100644
--- a/python/paddle/decomposition/decomp.py
+++ b/python/paddle/decomposition/decomp.py
@@ -84,7 +84,7 @@ def _prepare_python_api_arguments(op):
     inputs = []
     for x in op.operands():
         input = x.source()
-        if input and input.initialized():
+        if input.initialized():
             prev_op = input.get_defining_op()
             if (
                 isinstance(prev_op, Operation)
@@ -111,7 +111,7 @@ def _check_prim_dynamic(op):
     inputs = []
     for x in op.operands():
         input = x.source()
-        if input and input.initialized():
+        if input.initialized():
             prev_op = input.get_defining_op()
             if (
                 isinstance(prev_op, Operation)
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index f4e44a0dd9cc0c..98605b04415dc3 100644
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -3733,7 +3733,13 @@ def triplet_margin_with_distance_loss(
         swap_dist = distance_function(positive, negative)
         negative_dist = paddle.minimum(negative_dist, swap_dist)

-    if not paddle.all(positive_dist > 0) or not paddle.all(negative_dist > 0):
+    if (
+        not isinstance(positive_dist, paddle.pir.Value)
+        and not paddle.all(positive_dist > 0)
+    ) or (
+        not isinstance(negative_dist, paddle.pir.Value)
+        and not paddle.all(negative_dist > 0)
+    ):
         raise ValueError(
             "The positive distance or negative distance should be greater than 0, "
             "The distance functions should be checked."
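A minimal sketch of the two indexing paths the variable_index.py hunks now separate (assumes a Paddle build with this patch; `paddle.to_tensor(1)` produces the 0-D scalar-tensor case):

```python
import paddle

x = paddle.arange(12).reshape([3, 4])

# Plain Python int: handled by the `type(slice_item) is int` branch,
# which keeps the eager out-of-bounds check in parse_index.
row = x[1]

# 0-D scalar tensor: handled by the new `is_scalar_tensor` branch and
# forwarded to the slice OP without being evaluated in Python.
i = paddle.to_tensor(1)
row_t = x[i]

print(row.numpy(), row_t.numpy())  # both select the second row
```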
diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py
index 6077c60221eb1d..55b333630a745e 100644
--- a/python/paddle/optimizer/adamw.py
+++ b/python/paddle/optimizer/adamw.py
@@ -175,11 +175,11 @@ def __init__(
         assert beta1 is not None
         assert beta2 is not None
         assert epsilon is not None
-        if not 0 <= beta1 < 1:
+        if not isinstance(beta1, Value) and not 0 <= beta1 < 1:
             raise ValueError("Invaild value of beta1, expect beta1 in [0,1).")
-        if not 0 <= beta2 < 1:
+        if not isinstance(beta2, Value) and not 0 <= beta2 < 1:
             raise ValueError("Invaild value of beta2, expect beta2 in [0,1).")
-        if not 0 <= epsilon:
+        if not isinstance(epsilon, Value) and not 0 <= epsilon:
             raise ValueError("Invaild value of epsilon, expect epsilon >= 0.")
         if not isinstance(weight_decay, float) and not isinstance(
             weight_decay, (framework.Variable, Value)
diff --git a/python/paddle/pir/math_op_patch.py b/python/paddle/pir/math_op_patch.py
index 9bd3ec29781df3..9ac8a55145aa5d 100644
--- a/python/paddle/pir/math_op_patch.py
+++ b/python/paddle/pir/math_op_patch.py
@@ -448,6 +448,13 @@ def _float_(self):
             "2. If you want to run it in full graph mode, you need use Value directly, and do not use float(Value)."
         )

+    def _bool_(self):
+        raise TypeError(
+            "bool(Value) is not supported in static graph mode. If you are using @to_static, you can try this:\n"
+            "1. If you want to get the value of Value, you can switch to non-fullgraph mode by setting @to_static(full_graph=True).\n"
+            "2. If you want to run it in full graph mode, you need use Value.astype(paddle.bool), and do not use bool(Value)."
+        )
+
     def clone(self):
         """
         Returns a new static Value, which is the clone of the original static
@@ -669,6 +676,7 @@ def value_hash(self):
         ),
         ('__float__', _float_),
         ('__int__', _int_),
+        ('__bool__', _bool_),
     ]

     global _already_patch_value
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index d62771bb8ed534..8547ab0a26286b 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -5802,7 +5802,7 @@ def take_along_axis(arr, indices, axis, broadcast=True):
             )
         axis_max_size = arr.shape[axis]
-        if not (indices < axis_max_size).all():
+        if in_dynamic_mode() and not (indices < axis_max_size).all():
             raise RuntimeError(
                 "one of element of indices is out of bounds for dimension {} with size {}".format(
                     axis, axis_max_size
                 )
@@ -5958,7 +5958,7 @@ def put_along_axis(
         if elements == 1:  # paddle.pir.Value has no attribute 'size'
             values = paddle.broadcast_to(values, indices.shape)
         axis_max_size = arr.shape[axis]
-        if not (indices < axis_max_size).all():
+        if in_dynamic_mode() and not (indices < axis_max_size).all():
             raise RuntimeError(
                 "one of element of indices is out of bounds for dimension {} with size {}".format(
                     axis, axis_max_size
diff --git a/test/dygraph_to_static/test_tensor_shape.py b/test/dygraph_to_static/test_tensor_shape.py
index 1bd2c6ffbebd5f..88b2425f45d120 100644
--- a/test/dygraph_to_static/test_tensor_shape.py
+++ b/test/dygraph_to_static/test_tensor_shape.py
@@ -748,7 +748,7 @@ def _set_test_func(self):
         self.dygraph_func = dyfunc_with_for_1

     def _set_expected_op_num(self):
-        self.expected_op_num = 29
+        self.expected_op_num = 27
         self.expected_shape_op_num = 2
         self.expected_slice_op_num = 3
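For the manipulation.py guards above, a small sketch of the resulting behavior split (assumes this patch): in dynamic mode the indices are concrete, so the eager bounds check still fires; in static graph mode `(indices < axis_max_size).all()` is a lazy graph value with no defined truthiness (see the `__bool__` patch above), so the Python-level check is skipped and invalid indices surface at execution time instead.

```python
import paddle

arr = paddle.arange(12, dtype='float32').reshape([3, 4])
idx = paddle.to_tensor([[0, 1, 2, 0]], dtype='int64')

# Dynamic mode: in_dynamic_mode() is True, so the bounds check runs eagerly.
print(paddle.take_along_axis(arr, idx, axis=0))

bad = paddle.to_tensor([[7, 0, 0, 0]], dtype='int64')
try:
    paddle.take_along_axis(arr, bad, axis=0)  # 7 is out of bounds for size 3
except RuntimeError as e:
    print('caught:', e)
```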
diff --git a/test/legacy_test/test_math_op_patch_pir.py b/test/legacy_test/test_math_op_patch_pir.py
index 78cdb68910e1b5..b76dbda83224a8 100644
--- a/test/legacy_test/test_math_op_patch_pir.py
+++ b/test/legacy_test/test_math_op_patch_pir.py
@@ -533,6 +533,8 @@ def test_builtin_type_conversion(self):
                 int(x)
             with self.assertRaises(TypeError):
                 float(x)
+            with self.assertRaises(TypeError):
+                bool(x)

     def test_math_exists(self):
         with paddle.pir_utils.IrGuard():
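A quick way to reproduce what the new test asserts, sketched under the same `IrGuard` pattern the test file uses (assumes the `__bool__` patch above is in place):

```python
import paddle

with paddle.pir_utils.IrGuard():
    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        x = paddle.static.data('x', [2, 3], 'float32')
        try:
            bool(x)  # dispatched to the new _bool_, raises TypeError
        except TypeError as e:
            print('raised as expected:', e)
```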