From b6630715c22999cb9fe5224988dec6e409622a72 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Thu, 25 May 2023 10:36:36 +0000 Subject: [PATCH 01/27] WIP: start writing combined indexing get --- python/paddle/fluid/variable_index.py | 50 ++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index dae0deb135c3d..818e18f17bc61 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -293,7 +293,7 @@ def is_integer_or_scalar_tensor(ele): "1-D Tensor will be treat as advanced indexing in future version. Currently, 1-D Tensor means a scalar, not vector, and please modify it to 0-D Tensor. If advanced indexing is needed, please use `export FLAGS_set_to_1d=False` to set the flag." ) return True - if len(ele.shape) == 0: + if len(ele.shape) == 0 and ele.dtype != paddle.bool: return True return False @@ -858,3 +858,51 @@ def idx_not_empty(var, item, value): cond(item.any(), lambda: idx_not_empty(var, item, value)) return var + + +def _getitem_iter(x, indices): + """ + [WIP]: support __getitem__ by iteration strategy. combined indexing will be support by this. + + Args: + x(Tensor): Tensor to be indexing. + indices(int|slice|None|Tensor|List|Tuple...): Indices, used to indicate the position of the element to be fetched. + """ + advanced_index = [] + + # for slice/stride slice OP + decrease_axes = [] + axes = [] + starts = [] + ends = [] + steps = [] + + indices = replace_ndarray(indices) + indices = replace_ellipsis(x, indices) + indices, none_axes = replace_none(indices) + + if not isinstance(indices, tuple): + indices = (indices,) + + for dim, slice_item in enumerate(indices): + if is_integer_or_scalar_tensor(slice_item): + # not calculate result to reduce call times for slice OP. + decrease_axes.append(dim) + start = slice_item + step = 1 + end = slice_item + 1 if slice_item != -1 else MAX_INTEGER + elif isinstance(slice_item, slice): + start = slice_item.start + end = slice_item.stop + step = slice_item.step + + if start is None and end is None and step is None: + continue + + if start is None: + start = 0 if step > 0 else MAX_INTEGER + if end is None: + end = MAX_INTEGER if step > 0 else -1 + step = 1 if step is None else step + elif isinstance(slice_item, list): + pass From 1bcf667b27d6904841e1a2dca755fbe754780707 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Mon, 5 Jun 2023 08:28:07 +0000 Subject: [PATCH 02/27] list/tuple/Variable --- python/paddle/fluid/variable_index.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 818e18f17bc61..7a2201715b52a 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -868,7 +868,7 @@ def _getitem_iter(x, indices): x(Tensor): Tensor to be indexing. indices(int|slice|None|Tensor|List|Tuple...): Indices, used to indicate the position of the element to be fetched. """ - advanced_index = [] + advanced_index = [] # content is (dim, index) # for slice/stride slice OP decrease_axes = [] @@ -891,6 +891,10 @@ def _getitem_iter(x, indices): start = slice_item step = 1 end = slice_item + 1 if slice_item != -1 else MAX_INTEGER + elif isinstance(slice_item, bool): + # single bool is advanced-indexing + none_axes.append(dim) + advanced_index.append((dim, paddle.to_tensor(slice_item))) elif isinstance(slice_item, slice): start = slice_item.start end = slice_item.stop @@ -904,5 +908,7 @@ def _getitem_iter(x, indices): if end is None: end = MAX_INTEGER if step > 0 else -1 step = 1 if step is None else step - elif isinstance(slice_item, list): - pass + elif isinstance(slice_item, (list, tuple)): + advanced_index.append((dim, paddle.to_tensor(slice_item))) + elif isinstance(slice_item, paddle.fluid.Variable): + advanced_index.append((dim, slice_item)) From 2d00290a4baab7eeff7b90f2758793a233dba0f6 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Wed, 14 Jun 2023 03:55:58 +0000 Subject: [PATCH 03/27] getitem 80% --- python/paddle/fluid/variable_index.py | 140 +++++++++++++++++++++++++- 1 file changed, 138 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 7a2201715b52a..e410feabd7303 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -18,6 +18,7 @@ from . import core import paddle import warnings +from .framework import default_main_program, Variable MAX_INTEGER = 2**31 - 1 @@ -370,7 +371,6 @@ def _getitem_impl_(var, item): Returns: Sliced variable """ - from .framework import default_main_program, Variable if isinstance(item, list): if not is_one_dim_list(item, int): @@ -644,7 +644,6 @@ def _setitem_for_tensor_array(var, item, value): def _setitem_impl_(var, item, value): - from .framework import default_main_program, Variable from paddle.fluid import core if var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY: @@ -860,6 +859,23 @@ def idx_not_empty(var, item, value): return var +def deal_advanced_index(ori_tensor, indices): + """ + Transpose origin Tensor and indices to the front. + """ + transed_dim = [] + transed_index = [] + for i, indice in enumerate(indices): + if indice is not None: + transed_dim.append(i) + transed_index.append(indice[1]) + for i, indice in enumerate(indices): + if indice is None: + transed_dim.append(i) + transed_tensor = ori_tensor.transpose(transed_dim) + return transed_tensor, transed_index + + def _getitem_iter(x, indices): """ [WIP]: support __getitem__ by iteration strategy. combined indexing will be support by this. @@ -876,7 +892,11 @@ def _getitem_iter(x, indices): starts = [] ends = [] steps = [] + use_strided_slice = False + has_advanced_index = False + # step1 : replace ndarray / None / ellipsis to normal elemement + # and wrap multiple into one tuple indices = replace_ndarray(indices) indices = replace_ellipsis(x, indices) indices, none_axes = replace_none(indices) @@ -884,17 +904,21 @@ def _getitem_iter(x, indices): if not isinstance(indices, tuple): indices = (indices,) + # step2: Traverse index elements and record them. for dim, slice_item in enumerate(indices): + start, end, step = None, None, None if is_integer_or_scalar_tensor(slice_item): # not calculate result to reduce call times for slice OP. decrease_axes.append(dim) start = slice_item step = 1 end = slice_item + 1 if slice_item != -1 else MAX_INTEGER + advanced_index.append(None) elif isinstance(slice_item, bool): # single bool is advanced-indexing none_axes.append(dim) advanced_index.append((dim, paddle.to_tensor(slice_item))) + has_advanced_index = True elif isinstance(slice_item, slice): start = slice_item.start end = slice_item.stop @@ -908,7 +932,119 @@ def _getitem_iter(x, indices): if end is None: end = MAX_INTEGER if step > 0 else -1 step = 1 if step is None else step + advanced_index.append(None) elif isinstance(slice_item, (list, tuple)): advanced_index.append((dim, paddle.to_tensor(slice_item))) + has_advanced_index = True elif isinstance(slice_item, paddle.fluid.Variable): + # In this case, the Variable is not 0-dim Tensor and will be treated as advanced-indexing. advanced_index.append((dim, slice_item)) + has_advanced_index = True + else: + raise IndexError( + "Valid index accept int / bool / slice / ellipsis / list / Tuple / Ndarray / Tensor, but received {}.".format( + slice_item + ) + ) + if not (start is None or end is None or step is None): + starts.append(start) + ends.append(end) + steps.append(step) + axes.append(dim) + use_strided_slice = True if step != 1 else use_strided_slice + + # step3: Dealing with basic indexing + if len(axes) > 0: + op_type = "strided_slice" if use_strided_slice else "slice" + inputs = {'X': [x]} if use_strided_slice else {'Input': [x]} + attrs = { + 'axes': axes, + 'starts': [], + 'ends': [], + 'decrease_axis': decrease_axes, + } + if use_strided_slice: + attrs['strides'] = [] + infer_flags = [1] * len(axes) + deal_attrs( + attrs, starts, "starts", "StartsTensorList", inputs, infer_flags + ) + deal_attrs(attrs, ends, "ends", "EndsTensorList", inputs, infer_flags) + deal_attrs( + attrs, steps, "strides", "StridesTensorList", inputs, infer_flags + ) + attrs['infer_flags'] = infer_flags + + if paddle.in_dynamic_mode(): + if "StartsTensorList" in inputs.keys(): + st = inputs['StartsTensorList'] + else: + st = attrs['starts'] + if "EndsTensorList" in inputs.keys(): + end = inputs['EndsTensorList'] + else: + end = attrs['ends'] + if "StridesTensorList" in inputs.keys(): + stride = inputs['StridesTensorList'] + if use_strided_slice: + out = paddle._C_ops.strided_slice(x, axes, st, end, stride) + else: + out = paddle._C_ops.slice( + x, + axes, + st, + end, + attrs['infer_flags'], + attrs['decrease_axis'], + ) + else: + target_block = default_main_program().current_block() + + slice_out_var = target_block.create_var( + name=unique_name.generate_with_ignorable_key( + x.name + "_" + op_type + ), + dtype=x.dtype, + ) + target_block.append_op( + type=op_type, + inputs=inputs, + outputs={'Out': [slice_out_var]}, + attrs=attrs, + ) + out = slice_out_var + else: + out = x + + # NOTE(zoooo0820): When all axes are decreased, the output will be 1-D + # with FLAGS_set_to_1d=True. In this case, one `None` should be pop out, + # otherwise the output shape will be not correct. + set_to_1d = paddle.get_flags('FLAGS_set_to_1d')['FLAGS_set_to_1d'] + if set_to_1d and len(decrease_axes) == len(x.shape): + warnings.warn( + "Warning: In Tensor '__getitem__', if the number of scalar elements in the index is equal to the rank of the Tensor, the output should be 0-D. In order to be consistent with the behavior of previous versions, it will be processed to 1-D. But it is not correct and will be removed in release 2.6. If 1-D is still wanted, please modify the index element from scalar to slice (e.g. 'x[i]' => 'x[i:i+1]')." + ) + none_axes = none_axes[1:] + + if len(none_axes) > 0: + # Deal with cases that decrease_axes is not empty + # For example: + # # x.shape: (2,3,4) + # out = x[0, 0:2, None] # out.shape : (2, 1, 4) + for idx, axis in enumerate(none_axes): + l = len([i for i in decrease_axes if i < axis]) + new_axis = axis - l + none_axes[idx] = new_axis + + out = paddle.unsqueeze(out, axis=none_axes) + + # step4: Dealing with advanced indexing + if has_advanced_index: + transed_tensor, adjusted_advanced_index = deal_advanced_index( + out, advanced_index + ) + + # TODO(zooooo0820): Replacing gather_nd to another advanded OP for handling of mixed indexes more efficiently + out = paddle.gather_nd(transed_tensor, adjusted_advanced_index) + + return out From 6dfd5be2cc288ea8fd159baab1908fab7850c728 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Thu, 15 Jun 2023 06:40:17 +0000 Subject: [PATCH 04/27] add setitem --- python/paddle/fluid/variable_index.py | 249 +++++++++++++++++++++++--- 1 file changed, 227 insertions(+), 22 deletions(-) diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index e410feabd7303..663873d7cb5f7 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -859,7 +859,7 @@ def idx_not_empty(var, item, value): return var -def deal_advanced_index(ori_tensor, indices): +def deal_advanced_index(ori_tensor, indices, with_transback=False): """ Transpose origin Tensor and indices to the front. """ @@ -873,20 +873,13 @@ def deal_advanced_index(ori_tensor, indices): if indice is None: transed_dim.append(i) transed_tensor = ori_tensor.transpose(transed_dim) - return transed_tensor, transed_index + trans_back_dim = np.argsort(transed_dim) if with_transback else [] + return transed_tensor, transed_index, trans_back_dim -def _getitem_iter(x, indices): - """ - [WIP]: support __getitem__ by iteration strategy. combined indexing will be support by this. - - Args: - x(Tensor): Tensor to be indexing. - indices(int|slice|None|Tensor|List|Tuple...): Indices, used to indicate the position of the element to be fetched. - """ +def parse_index(indices): advanced_index = [] # content is (dim, index) - - # for slice/stride slice OP + # for set_value / slice / strided_slice OP decrease_axes = [] axes = [] starts = [] @@ -895,8 +888,6 @@ def _getitem_iter(x, indices): use_strided_slice = False has_advanced_index = False - # step1 : replace ndarray / None / ellipsis to normal elemement - # and wrap multiple into one tuple indices = replace_ndarray(indices) indices = replace_ellipsis(x, indices) indices, none_axes = replace_none(indices) @@ -904,7 +895,6 @@ def _getitem_iter(x, indices): if not isinstance(indices, tuple): indices = (indices,) - # step2: Traverse index elements and record them. for dim, slice_item in enumerate(indices): start, end, step = None, None, None if is_integer_or_scalar_tensor(slice_item): @@ -927,12 +917,26 @@ def _getitem_iter(x, indices): if start is None and end is None and step is None: continue + if not isinstance(step, Variable) and step == 0: + raise ValueError( + "When assign a value to a paddle.Tensor, step can not be 0, " + "but received step is {}.".format(step) + ) + + if isinstance(step, Variable) and (start is None or end is None): + raise ValueError( + "When assign a value to a paddle.Tensor, it's not supported that " + "the start or end is None when the type of step is paddle.Tensor." + ) + if start is None: start = 0 if step > 0 else MAX_INTEGER if end is None: end = MAX_INTEGER if step > 0 else -1 step = 1 if step is None else step + advanced_index.append(None) + elif isinstance(slice_item, (list, tuple)): advanced_index.append((dim, paddle.to_tensor(slice_item))) has_advanced_index = True @@ -952,9 +956,177 @@ def _getitem_iter(x, indices): steps.append(step) axes.append(dim) use_strided_slice = True if step != 1 else use_strided_slice + return ( + starts, + ends, + steps, + axes, + none_axes, + decrease_axes, + advanced_index, + has_advanced_index, + use_strided_slice, + ) - # step3: Dealing with basic indexing - if len(axes) > 0: + +def _setitem_static(x, indices, values): + """ + [WIP]: support __setitem__ by iteration strategy. combined indexing will be support by this. + In dynamic mode, this function will modify the value at input tensor, returning same Tensor as input. + But it will return a new Tensor with assigned value in static mode. + + Args: + x(Tensor): Tensor to be set value. + indices(int|slice|None|Tensor|List|Tuple...): Indices, used to indicate the position of the element to be fetched. + values(Tensor|Number|Ndarray): values to be assigned to the x. + """ + if x.type == paddle.fluid.core.VarDesc.VarType.LOD_TENSOR_ARRAY: + return _setitem_for_tensor_array(x, indices, values) + + # step1: parsing the index and recording them + ( + starts, + ends, + steps, + axes, + none_axes, + decrease_axes, + advanced_index, + has_advanced_index, + use_strided_slice, + ) = parse_index(indices) + + inputs = {'Input': x} + attrs = { + 'axes': axes, + 'starts': starts, + 'ends': ends, + 'steps': steps, + 'decrease_axes': decrease_axes, + 'none_axes': none_axes, + } + if paddle.utils._contain_var(starts): + inputs['StartsTensorList'] = paddle.utils._convert_to_tensor_list( + starts + ) + del attrs['starts'] + if paddle.utils._contain_var(ends): + inputs['EndsTensorList'] = paddle.utils._convert_to_tensor_list(ends) + del attrs['ends'] + if paddle.utils._contain_var(steps): + inputs['StepsTensorList'] = paddle.utils._convert_to_tensor_list(steps) + del attrs['steps'] + + if not has_advanced_index: + # step2. Parse values + dtype = x.dtype + attrs['dtype'] = dtype + + from .data_feeder import convert_dtype + + if isinstance(values, (bool, int, float, complex)): + values = np.array([values]).astype(convert_dtype(dtype)) + + if isinstance(values, np.ndarray): + shape = list(values.shape) + values = values.ravel().tolist() + attrs["values"] = values + attrs["shape"] = shape + + elif isinstance(values, Variable): + inputs["ValueTensor"] = values + else: + raise TypeError( + "Only support to assign an Number, numpy.ndarray or " + "paddle.Tensor to a paddle.Tensor, but received {}".format( + type(values) + ) + ) + + # step3.1: Only basic indexing, use OP set_value to set value. + if paddle.in_dynamic_mode(): + x._bump_inplace_version() + out = x + else: + helper = paddle.fluid.layer_helper.LayerHelper( + 'set_value', **locals() + ) + out = helper.create_variable_for_type_inference(dtype=dtype) + cur_block = default_main_program().current_block() + cur_block.append_op( + type="set_value", + inputs=inputs, + outputs={'Out': out}, + attrs=attrs, + inplace_map={"Input": "Out"}, + ) + return out + else: + # step3.2: Case for there are advanced indexing. + # 1. get __getitem__ result of basic indexing; + # 2. transpose original tensor so that the axis with advanced indexing will come to the first; + # 3. assign values to the sliced result by index_put OP; + # 4. transpose back and assign the result to original tensor by set_value OP. + + sub_tensor = get_tensor_with_basic_indexing( + x, + axes, + starts, + ends, + steps, + decrease_axes, + none_axes, + use_strided_slice, + ) + ( + transed_sub_tensor, + adjusted_advanced_index, + transback_dim, + ) = deal_advanced_index(sub_tensor, advanced_index) + if not isinstance(values, Variable): + values = paddle.assign(values) + transed_sub_tensor = transed_sub_tensor.index_put( + adjusted_advanced_index, values + ) + + # NOTE(zoooo0820): now basic indexing of __getitem__ will return a new Tensor both in dynamic and static mode + # After strided is ready and basic indexing returns view of Tensor in dynamic mode. The code shoule be changed + # for dynamic mode. + if paddle.in_dynamic_mode(): + transed_sub_tensor.index_put_(adjusted_advanced_index, values) + else: + transed_sub_tensor = transed_sub_tensor.index_put( + adjusted_advanced_index, values + ) + + transback_sub_tensor = transed_sub_tensor.transpose(transback_dim) + + inputs["ValueTensor"] = transback_sub_tensor + if paddle.in_dynamic_mode(): + x._bump_inplace_version() + out = x + else: + helper = paddle.fluid.layer_helper.LayerHelper( + 'set_value', **locals() + ) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + cur_block = default_main_program().current_block() + cur_block.append_op( + type="set_value", + inputs=inputs, + outputs={'Out': out}, + attrs=attrs, + inplace_map={"Input": "Out"}, + ) + return out + + +def get_tensor_with_basic_indexing( + x, axes, starts, ends, steps, decrease_axes, none_axes, use_strided_slice +): + if len(axes) == 0: + out = x + else: op_type = "strided_slice" if use_strided_slice else "slice" inputs = {'X': [x]} if use_strided_slice else {'Input': [x]} attrs = { @@ -1013,9 +1185,6 @@ def _getitem_iter(x, indices): attrs=attrs, ) out = slice_out_var - else: - out = x - # NOTE(zoooo0820): When all axes are decreased, the output will be 1-D # with FLAGS_set_to_1d=True. In this case, one `None` should be pop out, # otherwise the output shape will be not correct. @@ -1038,9 +1207,45 @@ def _getitem_iter(x, indices): out = paddle.unsqueeze(out, axis=none_axes) - # step4: Dealing with advanced indexing + return out + + +def _getitem_static(x, indices): + """ + [WIP]: support __getitem__ by iteration strategy. combined indexing will be support by this. + + Args: + x(Tensor): Tensor to be indexing. + indices(int|slice|None|Tensor|List|Tuple...): Indices, used to indicate the position of the element to be fetched. + """ + # step1: parsing the index and recording them + ( + starts, + ends, + steps, + axes, + none_axes, + decrease_axes, + advanced_index, + has_advanced_index, + use_strided_slice, + ) = parse_index(indices) + + # step2: Dealing with basic indexing + out = get_tensor_with_basic_indexing( + x, + axes, + starts, + ends, + steps, + decrease_axes, + none_axes, + use_strided_slice, + ) + + # step3: Dealing with advanced indexing if has_advanced_index: - transed_tensor, adjusted_advanced_index = deal_advanced_index( + transed_tensor, adjusted_advanced_index, _ = deal_advanced_index( out, advanced_index ) From 1aeb04e7845ce6887105e670944f7d8f2d1da1d0 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Thu, 15 Jun 2023 07:43:43 +0000 Subject: [PATCH 05/27] add some unittest for setitem --- test/CMakeLists.txt | 1 + test/indexing/CMakeLists.txt | 9 ++++ test/indexing/test_setitem.py | 96 +++++++++++++++++++++++++++++++++++ 3 files changed, 106 insertions(+) create mode 100644 test/indexing/CMakeLists.txt create mode 100644 test/indexing/test_setitem.py diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 975446b6002ae..06f461fd00264 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -119,6 +119,7 @@ if(WITH_TESTING) add_subdirectory(ipu) endif() add_subdirectory(ir) + add_subdirectory(indexing) add_subdirectory(legacy_test) if(WITH_MKLDNN) add_subdirectory(mkldnn) diff --git a/test/indexing/CMakeLists.txt b/test/indexing/CMakeLists.txt new file mode 100644 index 0000000000000..95739040ef4af --- /dev/null +++ b/test/indexing/CMakeLists.txt @@ -0,0 +1,9 @@ +file( + GLOB TEST_OPS + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") + +foreach(TEST_OP ${TEST_OPS}) + py_test_modules(${TEST_OP} MODULES ${TEST_OP}) +endforeach() diff --git a/test/indexing/test_setitem.py b/test/indexing/test_setitem.py new file mode 100644 index 0000000000000..3f96e2048467c --- /dev/null +++ b/test/indexing/test_setitem.py @@ -0,0 +1,96 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle.fluid.variable_index import _setitem_static + + +class TestSetitemInDygraph(unittest.TestCase): + def setUp(self): + paddle.disable_static() + + def test_combined_index_1(self): + np_data = np.zeros((3, 4, 5, 6), dtype='float32') + x = paddle.to_tensor(np_data) + + np_data[[0, 1], :, [1, 2]] = 10.0 + x[[0, 1], :, [1, 2]] = 10.0 + + np.testing.assert_allclose(x.numpy(), np_data) + + def test_combined_index_2(self): + np_data = np.ones((3, 4, 5, 6), dtype='float32') + x = paddle.to_tensor(np_data) + + np_data[:, 1, [1, 2], 0] = 10.0 + x[:, 1, [1, 2], 0] = 10.0 + + np.testing.assert_allclose(x.numpy(), np_data) + + def test_combined_index_3(self): + np_data = np.ones((3, 4, 5, 6), dtype='int32') + x = paddle.to_tensor(np_data) + + np_data[:, [True, False, True, False], [1, 4]] = 10 + x[:, [True, False, True, False], [1, 4]] = 10 + + np.testing.assert_allclose(x.numpy(), np_data) + + +class TestSetitemInStatic(unittest.TestCase): + def setUp(self): + paddle.enable_static() + self.exe = paddle.static.Executor() + + def test_combined_index_1(self): + np_data = np.zeros((3, 4, 5, 6), dtype='float32') + np_data[[0, 1], :, [1, 2]] = 10.0 + + x = paddle.zeros((3, 4, 5, 6), dtype='float32') + y = _setitem_static(x, ([0, 1], slice(None, None, None), [1, 2]), 10.0) + + program = paddle.static.default_startup_program() + res = self.exe.run(program, fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_data) + + def test_combined_index_2(self): + np_data = np.ones((3, 4, 5, 6), dtype='float32') + np_data[:, 1, [1, 2], 0] = 10.0 + + x = paddle.ones((3, 4, 5, 6), dtype='float32') + y = _setitem_static(x, (slice(None, None, None), 1, [1, 2], 0), 10.0) + + program = paddle.static.default_startup_program() + res = self.exe.run(program, fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_data) + + def test_combined_index_3(self): + np_data = np.ones((3, 4, 5, 6), dtype='int32') + np_data[:, [True, False, True, False], [1, 4]] = 10 + + x = paddle.ones((3, 4, 5, 6), dtype='int32') + y = _setitem_static( + x, (slice(None, None, None), [True, False, True, False], [1, 4]), 10 + ) + + program = paddle.static.default_startup_program() + res = self.exe.run(program, fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_data) From 7e03261af900969e4c3ee0fa3def7fc934ae806e Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Fri, 16 Jun 2023 06:24:45 +0000 Subject: [PATCH 06/27] lazy import --- python/paddle/fluid/variable_index.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 663873d7cb5f7..4678c8e4af2fe 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -18,7 +18,7 @@ from . import core import paddle import warnings -from .framework import default_main_program, Variable + MAX_INTEGER = 2**31 - 1 @@ -371,6 +371,7 @@ def _getitem_impl_(var, item): Returns: Sliced variable """ + from .framework import default_main_program, Variable if isinstance(item, list): if not is_one_dim_list(item, int): @@ -645,6 +646,7 @@ def _setitem_for_tensor_array(var, item, value): def _setitem_impl_(var, item, value): from paddle.fluid import core + from .framework import default_main_program, Variable if var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY: return _setitem_for_tensor_array(var, item, value) @@ -980,6 +982,8 @@ def _setitem_static(x, indices, values): indices(int|slice|None|Tensor|List|Tuple...): Indices, used to indicate the position of the element to be fetched. values(Tensor|Number|Ndarray): values to be assigned to the x. """ + from .framework import default_main_program, Variable + if x.type == paddle.fluid.core.VarDesc.VarType.LOD_TENSOR_ARRAY: return _setitem_for_tensor_array(x, indices, values) @@ -1170,6 +1174,8 @@ def get_tensor_with_basic_indexing( attrs['decrease_axis'], ) else: + from .framework import default_main_program + target_block = default_main_program().current_block() slice_out_var = target_block.create_var( From e47ed101cbcda8f832f71160d93f282d7da77d62 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Fri, 16 Jun 2023 10:04:50 +0000 Subject: [PATCH 07/27] fix some setitem error --- python/paddle/fluid/variable_index.py | 33 ++++++++++--------- test/indexing/test_setitem.py | 46 +++++++++++++++------------ 2 files changed, 42 insertions(+), 37 deletions(-) diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 4678c8e4af2fe..186ae4b5e0b88 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -875,12 +875,12 @@ def deal_advanced_index(ori_tensor, indices, with_transback=False): if indice is None: transed_dim.append(i) transed_tensor = ori_tensor.transpose(transed_dim) - trans_back_dim = np.argsort(transed_dim) if with_transback else [] + trans_back_dim = np.argsort(transed_dim).tolist() if with_transback else [] return transed_tensor, transed_index, trans_back_dim -def parse_index(indices): - advanced_index = [] # content is (dim, index) +def parse_index(x, indices): + advanced_index = [None] * len(x.shape) # content is (dim, index) # for set_value / slice / strided_slice OP decrease_axes = [] axes = [] @@ -890,13 +890,13 @@ def parse_index(indices): use_strided_slice = False has_advanced_index = False + if not isinstance(indices, tuple): + indices = (indices,) + indices = replace_ndarray(indices) indices = replace_ellipsis(x, indices) indices, none_axes = replace_none(indices) - if not isinstance(indices, tuple): - indices = (indices,) - for dim, slice_item in enumerate(indices): start, end, step = None, None, None if is_integer_or_scalar_tensor(slice_item): @@ -905,11 +905,10 @@ def parse_index(indices): start = slice_item step = 1 end = slice_item + 1 if slice_item != -1 else MAX_INTEGER - advanced_index.append(None) elif isinstance(slice_item, bool): # single bool is advanced-indexing none_axes.append(dim) - advanced_index.append((dim, paddle.to_tensor(slice_item))) + advanced_index[dim] = (dim, paddle.to_tensor(slice_item)) has_advanced_index = True elif isinstance(slice_item, slice): start = slice_item.start @@ -919,13 +918,15 @@ def parse_index(indices): if start is None and end is None and step is None: continue - if not isinstance(step, Variable) and step == 0: + if not isinstance(step, paddle.fluid.Variable) and step == 0: raise ValueError( "When assign a value to a paddle.Tensor, step can not be 0, " "but received step is {}.".format(step) ) - if isinstance(step, Variable) and (start is None or end is None): + if isinstance(step, paddle.fluid.Variable) and ( + start is None or end is None + ): raise ValueError( "When assign a value to a paddle.Tensor, it's not supported that " "the start or end is None when the type of step is paddle.Tensor." @@ -937,14 +938,12 @@ def parse_index(indices): end = MAX_INTEGER if step > 0 else -1 step = 1 if step is None else step - advanced_index.append(None) - elif isinstance(slice_item, (list, tuple)): - advanced_index.append((dim, paddle.to_tensor(slice_item))) + advanced_index[dim] = (dim, paddle.to_tensor(slice_item)) has_advanced_index = True elif isinstance(slice_item, paddle.fluid.Variable): # In this case, the Variable is not 0-dim Tensor and will be treated as advanced-indexing. - advanced_index.append((dim, slice_item)) + advanced_index[dim] = (dim, slice_item) has_advanced_index = True else: raise IndexError( @@ -998,7 +997,7 @@ def _setitem_static(x, indices, values): advanced_index, has_advanced_index, use_strided_slice, - ) = parse_index(indices) + ) = parse_index(x, indices) inputs = {'Input': x} attrs = { @@ -1086,7 +1085,7 @@ def _setitem_static(x, indices, values): transed_sub_tensor, adjusted_advanced_index, transback_dim, - ) = deal_advanced_index(sub_tensor, advanced_index) + ) = deal_advanced_index(sub_tensor, advanced_index, True) if not isinstance(values, Variable): values = paddle.assign(values) transed_sub_tensor = transed_sub_tensor.index_put( @@ -1235,7 +1234,7 @@ def _getitem_static(x, indices): advanced_index, has_advanced_index, use_strided_slice, - ) = parse_index(indices) + ) = parse_index(x, indices) # step2: Dealing with basic indexing out = get_tensor_with_basic_indexing( diff --git a/test/indexing/test_setitem.py b/test/indexing/test_setitem.py index 3f96e2048467c..5c6671c18248b 100644 --- a/test/indexing/test_setitem.py +++ b/test/indexing/test_setitem.py @@ -60,37 +60,43 @@ def setUp(self): def test_combined_index_1(self): np_data = np.zeros((3, 4, 5, 6), dtype='float32') np_data[[0, 1], :, [1, 2]] = 10.0 - - x = paddle.zeros((3, 4, 5, 6), dtype='float32') - y = _setitem_static(x, ([0, 1], slice(None, None, None), [1, 2]), 10.0) - - program = paddle.static.default_startup_program() - res = self.exe.run(program, fetch_list=[y.name]) + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.zeros((3, 4, 5, 6), dtype='float32') + y = _setitem_static( + x, ([0, 1], slice(None, None, None), [1, 2]), 10.0 + ) + res = self.exe.run(fetch_list=[y.name]) np.testing.assert_allclose(res[0], np_data) def test_combined_index_2(self): np_data = np.ones((3, 4, 5, 6), dtype='float32') np_data[:, 1, [1, 2], 0] = 10.0 - - x = paddle.ones((3, 4, 5, 6), dtype='float32') - y = _setitem_static(x, (slice(None, None, None), 1, [1, 2], 0), 10.0) - - program = paddle.static.default_startup_program() - res = self.exe.run(program, fetch_list=[y.name]) + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.ones((3, 4, 5, 6), dtype='float32') + y = _setitem_static( + x, (slice(None, None, None), 1, [1, 2], 0), 10.0 + ) + res = self.exe.run(fetch_list=[y.name]) np.testing.assert_allclose(res[0], np_data) def test_combined_index_3(self): np_data = np.ones((3, 4, 5, 6), dtype='int32') np_data[:, [True, False, True, False], [1, 4]] = 10 - - x = paddle.ones((3, 4, 5, 6), dtype='int32') - y = _setitem_static( - x, (slice(None, None, None), [True, False, True, False], [1, 4]), 10 - ) - - program = paddle.static.default_startup_program() - res = self.exe.run(program, fetch_list=[y.name]) + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.ones((3, 4, 5, 6), dtype='int32') + y = _setitem_static( + x, + (slice(None, None, None), [True, False, True, False], [1, 4]), + 10, + ) + res = self.exe.run(fetch_list=[y.name]) np.testing.assert_allclose(res[0], np_data) From d1e3fc24c34d69b33de7c735f90a6dadea9bcd13 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Sun, 25 Jun 2023 08:05:21 +0000 Subject: [PATCH 08/27] fix advance indexing with decreasing axes; fix strided_slice input name --- python/paddle/fluid/variable_index.py | 27 +++++++++++++------ test/indexing/test_setitem.py | 37 +++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 186ae4b5e0b88..68da806fb2f60 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -871,8 +871,8 @@ def deal_advanced_index(ori_tensor, indices, with_transback=False): if indice is not None: transed_dim.append(i) transed_index.append(indice[1]) - for i, indice in enumerate(indices): - if indice is None: + for i in range(ori_tensor.ndim): + if indices[i] is None: transed_dim.append(i) transed_tensor = ori_tensor.transpose(transed_dim) trans_back_dim = np.argsort(transed_dim).tolist() if with_transback else [] @@ -880,7 +880,7 @@ def deal_advanced_index(ori_tensor, indices, with_transback=False): def parse_index(x, indices): - advanced_index = [None] * len(x.shape) # content is (dim, index) + advanced_index = [None] * 2 * len(x.shape) # content is (dim, index) # for set_value / slice / strided_slice OP decrease_axes = [] axes = [] @@ -897,6 +897,7 @@ def parse_index(x, indices): indices = replace_ellipsis(x, indices) indices, none_axes = replace_none(indices) + estimated_dim = 0 for dim, slice_item in enumerate(indices): start, end, step = None, None, None if is_integer_or_scalar_tensor(slice_item): @@ -908,12 +909,17 @@ def parse_index(x, indices): elif isinstance(slice_item, bool): # single bool is advanced-indexing none_axes.append(dim) - advanced_index[dim] = (dim, paddle.to_tensor(slice_item)) + estimated_dim += 1 + advanced_index[estimated_dim] = ( + estimated_dim, + paddle.to_tensor(slice_item), + ) has_advanced_index = True elif isinstance(slice_item, slice): start = slice_item.start end = slice_item.stop step = slice_item.step + estimated_dim += 1 if start is None and end is None and step is None: continue @@ -939,12 +945,17 @@ def parse_index(x, indices): step = 1 if step is None else step elif isinstance(slice_item, (list, tuple)): - advanced_index[dim] = (dim, paddle.to_tensor(slice_item)) + advanced_index[estimated_dim] = ( + estimated_dim, + paddle.to_tensor(slice_item), + ) has_advanced_index = True + estimated_dim += 1 elif isinstance(slice_item, paddle.fluid.Variable): # In this case, the Variable is not 0-dim Tensor and will be treated as advanced-indexing. - advanced_index[dim] = (dim, slice_item) + advanced_index[estimated_dim] = (estimated_dim, slice_item) has_advanced_index = True + estimated_dim += 1 else: raise IndexError( "Valid index accept int / bool / slice / ellipsis / list / Tuple / Ndarray / Tensor, but received {}.".format( @@ -1087,7 +1098,7 @@ def _setitem_static(x, indices, values): transback_dim, ) = deal_advanced_index(sub_tensor, advanced_index, True) if not isinstance(values, Variable): - values = paddle.assign(values) + values = paddle.assign(values).astype(transed_sub_tensor.dtype) transed_sub_tensor = transed_sub_tensor.index_put( adjusted_advanced_index, values ) @@ -1131,7 +1142,7 @@ def get_tensor_with_basic_indexing( out = x else: op_type = "strided_slice" if use_strided_slice else "slice" - inputs = {'X': [x]} if use_strided_slice else {'Input': [x]} + inputs = {'Input': [x]} attrs = { 'axes': axes, 'starts': [], diff --git a/test/indexing/test_setitem.py b/test/indexing/test_setitem.py index 5c6671c18248b..f412247339fb6 100644 --- a/test/indexing/test_setitem.py +++ b/test/indexing/test_setitem.py @@ -58,6 +58,7 @@ def setUp(self): self.exe = paddle.static.Executor() def test_combined_index_1(self): + # int tensor + slice (without decreasing axes) np_data = np.zeros((3, 4, 5, 6), dtype='float32') np_data[[0, 1], :, [1, 2]] = 10.0 with paddle.static.program_guard( @@ -72,6 +73,7 @@ def test_combined_index_1(self): np.testing.assert_allclose(res[0], np_data) def test_combined_index_2(self): + # int tensor + slice (with decreasing axes) np_data = np.ones((3, 4, 5, 6), dtype='float32') np_data[:, 1, [1, 2], 0] = 10.0 with paddle.static.program_guard( @@ -86,6 +88,7 @@ def test_combined_index_2(self): np.testing.assert_allclose(res[0], np_data) def test_combined_index_3(self): + # int tensor + bool tensor + slice (without decreasing axes) np_data = np.ones((3, 4, 5, 6), dtype='int32') np_data[:, [True, False, True, False], [1, 4]] = 10 with paddle.static.program_guard( @@ -100,3 +103,37 @@ def test_combined_index_3(self): res = self.exe.run(fetch_list=[y.name]) np.testing.assert_allclose(res[0], np_data) + + def test_combined_index_4(self): + # int tensor (with ranks > 1) + bool tensor + slice (with decreasing axes) + np_data = np.ones((3, 4, 5, 6), dtype='int32') + np_data[[0, 0], [True, False, True, False], [[0, 2], [1, 4]], 4] = 16 + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.ones((3, 4, 5, 6), dtype='int32') + y = _setitem_static( + x, + ([0, 0], [True, False, True, False], [[0, 2], [1, 4]], 4), + 16, + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_data) + + def test_combined_index_5(self): + # int tensor + slice + Ellipsis + np_data = np.ones((3, 4, 5, 6), dtype='int32') + np_data[..., [1, 4, 3], ::2] = 5 + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.ones((3, 4, 5, 6), dtype='int32') + y = _setitem_static( + x, + (..., [1, 4, 3], slice(None, None, 2)), + 5, + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_data) From bb022b29c20da4b8907ee850161f1e9932cc2d14 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Mon, 26 Jun 2023 06:57:55 +0000 Subject: [PATCH 09/27] combine int-tensor getitem is ok (without boolean support & broadcast); add getitem unittest for static --- python/paddle/fluid/variable_index.py | 60 ++++++- test/indexing/test_getitem.py | 237 ++++++++++++++++++++++++++ 2 files changed, 290 insertions(+), 7 deletions(-) create mode 100644 test/indexing/test_getitem.py diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 68da806fb2f60..621f67b8c24b6 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -861,22 +861,53 @@ def idx_not_empty(var, item, value): return var -def deal_advanced_index(ori_tensor, indices, with_transback=False): +def deal_advanced_index(ori_tensor, indices, is_for_setitem): """ - Transpose origin Tensor and indices to the front. + Transpose origin Tensor and advanced indices to the front. + + Returns: + transed_tensor (Tensor): transposed tensor, corresbonding with advanced indices + transed_index (List): advanced indices transed to the front + trans_back_dim (List): order of axes to transpose back to original order. Only used in __setitem__. + pos_of_new_dim (int): axis of new dim in the result. Only used in __getitem__. + rank_of_new_dim (int): rank of new dim in the result. Only used in __getitem__. """ transed_dim = [] transed_index = [] + + # These flags indicates whether the result get by gather_nd requires a second transpose. + # Only used in __getitem__. + pos_of_new_dim = MAX_INTEGER + rank_of_new_dim = 1 + for i, indice in enumerate(indices): if indice is not None: + if not is_for_setitem: + if i == 0: + # case 1: advanced indices at axis 0, the new dim will be at first. + pos_of_new_dim = 0 + if i > 0 and len(transed_dim) > 0 and transed_dim[-1] != i - 1: + # case 2: there are not adjacent advanced indices, the new dim will be at first. + pos_of_new_dim = 0 + else: + pos_of_new_dim = min(pos_of_new_dim, i) + rank_of_new_dim = max(rank_of_new_dim, indice[1].ndim) transed_dim.append(i) transed_index.append(indice[1]) for i in range(ori_tensor.ndim): if indices[i] is None: transed_dim.append(i) transed_tensor = ori_tensor.transpose(transed_dim) - trans_back_dim = np.argsort(transed_dim).tolist() if with_transback else [] - return transed_tensor, transed_index, trans_back_dim + + trans_back_dim = np.argsort(transed_dim).tolist() if is_for_setitem else [] + + return ( + transed_tensor, + transed_index, + trans_back_dim, + pos_of_new_dim, + rank_of_new_dim, + ) def parse_index(x, indices): @@ -1096,6 +1127,8 @@ def _setitem_static(x, indices, values): transed_sub_tensor, adjusted_advanced_index, transback_dim, + _, + _, ) = deal_advanced_index(sub_tensor, advanced_index, True) if not isinstance(values, Variable): values = paddle.assign(values).astype(transed_sub_tensor.dtype) @@ -1261,11 +1294,24 @@ def _getitem_static(x, indices): # step3: Dealing with advanced indexing if has_advanced_index: - transed_tensor, adjusted_advanced_index, _ = deal_advanced_index( - out, advanced_index - ) + ( + transed_tensor, + adjusted_advanced_index, + _, + pos_of_new_dim, + rank_of_new_dim, + ) = deal_advanced_index(out, advanced_index, False) # TODO(zooooo0820): Replacing gather_nd to another advanded OP for handling of mixed indexes more efficiently + adjusted_advanced_index = paddle.stack(adjusted_advanced_index, axis=-1) out = paddle.gather_nd(transed_tensor, adjusted_advanced_index) + if pos_of_new_dim != 0: + perm = ( + list(range(pos_of_new_dim, pos_of_new_dim + rank_of_new_dim)) + + list(range(0, pos_of_new_dim)) + + list(range(pos_of_new_dim + rank_of_new_dim, out.ndim)) + ) + out = out.transpose(perm) + return out diff --git a/test/indexing/test_getitem.py b/test/indexing/test_getitem.py new file mode 100644 index 0000000000000..6afb351346b4d --- /dev/null +++ b/test/indexing/test_getitem.py @@ -0,0 +1,237 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle.fluid.variable_index import _getitem_static + +# class TestGetitemInDygraph(unittest.TestCase): +# def setUp(self): +# paddle.disable_static() + +# def test_combined_index_1(self): +# np_data = np.zeros((3, 4, 5, 6), dtype='float32') +# x = paddle.to_tensor(np_data) + +# np_res = np_data[[0, 1], :, [1, 2]] +# y = x[[0, 1], :, [1, 2]] + +# np.testing.assert_allclose(y.numpy(), np_res) + +# def test_combined_index_2(self): +# np_data = np.ones((3, 4, 5, 6), dtype='float32') +# x = paddle.to_tensor(np_data) + +# np_res = np_data[:, 1, [1, 2], 0] +# y = x[:, 1, [1, 2], 0] + +# np.testing.assert_allclose(y.numpy(), np_res) + +# def test_combined_index_3(self): +# np_data = np.ones((3, 4, 5, 6), dtype='int32') +# x = paddle.to_tensor(np_data) + +# np_res = np_data[:, [True, False, True, False], [1, 4]] +# y = x[:, [True, False, True, False], [1, 4]] + +# np.testing.assert_allclose(y.numpy(), np_res) + + +class TestGetitemInStatic(unittest.TestCase): + def setUp(self): + paddle.enable_static() + self.exe = paddle.static.Executor() + + def test_combined_index_1(self): + # int tensor + slice (without decreasing axes) + np_data = np.random.randn(3, 4, 5, 6) + np_res = np_data[[0, 1], :, [1, 2]] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static(x, ([0, 1], slice(None, None, None), [1, 2])) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_res) + + def test_combined_index_2(self): + # int tensor + slice (with decreasing axes) + np_data = np.random.randn(3, 4, 5, 6) + np_res = np_data[:, 1, [1, 2], 0] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static(x, (slice(None, None, None), 1, [1, 2], 0)) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_res) + + def test_combined_index_3(self): + # multiple int tensors, with one int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[[1, 0], :, [1, 4], 1:5:2, 4] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static( + x, ([1, 0], slice(None, None, None), [1, 4], slice(1, 5, 2), 4) + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_res) + + def test_combined_index_4(self): + # multiple not adjacent int tensors, with no int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[:, [1, 0], 0:4:2, [2, 3], 4] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static( + x, (slice(None, None, None), [1, 0], slice(0, 4, 2), [2, 3], 4) + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_res) + + def test_combined_index_5(self): + # multiple adjacent int tensors, with no int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[::2, [1, 0], [2, 3], 0:4:2] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static( + x, (slice(None, None, 2), [1, 0], [2, 3], slice(0, 4, 2)) + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_res) + + def test_combined_index_6(self): + # multiple adjacent and not adjacent int tensors, with no int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[::2, [1, 0], [2, 3], 0:4:2, [4, 6]] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static( + x, + (slice(None, None, 2), [1, 0], [2, 3], slice(0, 4, 2), [4, 6]), + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_res) + + def test_combined_index_7(self): + # multiple adjacent and not adjacent int tensors (rank > 1d), with no int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[::2, [[1, 0]], [[2, 3]], 0:4:2, [[4, 6]]] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static( + x, + ( + slice(None, None, 2), + [[1, 0]], + [[2, 3]], + slice(0, 4, 2), + [[4, 6]], + ), + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_res) + + def test_combined_index_8(self): + # multiple adjacent and not adjacent int tensors (rank > 1d), with int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[ + [[1, 0], [0, 1]], [[2, 3], [1, 0]], 0:4:2, [[3, 5], [4, 2]] + ] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static( + x, + ( + [[1, 0], [0, 1]], + [[2, 3], [1, 0]], + slice(0, 4, 2), + [[3, 5], [4, 2]], + ), + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_res) + + # def test_combined_index_7(self): + # # int tensor + bool tensor + slice (without decreasing axes) + # np_data = np.random.randn(3, 4, 5, 6) + # np_res = np_data[:, [True, False, True, False], [1, 4]] + # with paddle.static.program_guard( + # paddle.static.Program(), paddle.static.Program() + # ): + # x = paddle.to_tensor(np_data) + # y = _getitem_static( + # x, + # (slice(None, None, None), [True, False, True, False], [1, 4]) + # ) + # res = self.exe.run(fetch_list=[y.name]) + + # np.testing.assert_allclose(res[0], np_res) + + # def test_combined_index_4(self): + # # int tensor (with ranks > 1) + bool tensor + slice (with decreasing axes) + # np_data = np.arange(3*4*5*6).reshape((3, 4, 5, 6)) + # np_res = np_data[[0, 0], [True, False, True, False], [[0, 2], [1, 4]], 4] + # with paddle.static.program_guard( + # paddle.static.Program(), paddle.static.Program() + # ): + # x = paddle.to_tensor(np_data) + # y = _getitem_static( + # x, + # ([0, 0], [True, False, True, False], [[0, 2], [1, 4]], 4) + # ) + # res = self.exe.run(fetch_list=[y.name]) + + # np.testing.assert_allclose(res[0], np_res) + + # def test_combined_index_5(self): + # # int tensor + slice + Ellipsis + # np_data = np.arange(3*4*5*6).reshape((3, 4, 5, 6)) + # np_res = np_data[..., [1, 4, 3], ::2] + # with paddle.static.program_guard( + # paddle.static.Program(), paddle.static.Program() + # ): + # x = paddle.to_tensor(np_data) + # y = _getitem_static( + # x, + # (..., [1, 4, 3], slice(None, None, 2)), + # ) + # res = self.exe.run(fetch_list=[y.name]) + + # np.testing.assert_allclose(res[0], np_res) From e25a4bc405bda2addba9ebd3efcb946f56af5dcc Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Wed, 28 Jun 2023 12:41:41 +0000 Subject: [PATCH 10/27] add broadcast & parse bool tensor for __getitem --- python/paddle/fluid/variable_index.py | 52 ++++++++++++- test/indexing/test_getitem.py | 106 ++++++++++++++------------ 2 files changed, 109 insertions(+), 49 deletions(-) diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 621f67b8c24b6..5afacef3d3e16 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -980,13 +980,36 @@ def parse_index(x, indices): estimated_dim, paddle.to_tensor(slice_item), ) + + if ( + advanced_index[estimated_dim][1].dtype == paddle.bool + and len(slice_item) != x.shape[dim] + ): + raise IndexError( + "The shape of boolean index {} did not match indexed tensor {} along axis {}".format( + len(slice_item), x.shape[dim], dim + ) + ) + has_advanced_index = True estimated_dim += 1 + elif isinstance(slice_item, paddle.fluid.Variable): # In this case, the Variable is not 0-dim Tensor and will be treated as advanced-indexing. + if ( + slice_item.dtype == paddle.bool + and len(slice_item) != x.shape[dim] + ): + raise IndexError( + "The shape of boolean index {} did not match indexed tensor {} along axis {}".format( + len(slice_item), x.shape[dim], dim + ) + ) + advanced_index[estimated_dim] = (estimated_dim, slice_item) has_advanced_index = True estimated_dim += 1 + else: raise IndexError( "Valid index accept int / bool / slice / ellipsis / list / Tuple / Ndarray / Tensor, but received {}.".format( @@ -1303,8 +1326,22 @@ def _getitem_static(x, indices): ) = deal_advanced_index(out, advanced_index, False) # TODO(zooooo0820): Replacing gather_nd to another advanded OP for handling of mixed indexes more efficiently - adjusted_advanced_index = paddle.stack(adjusted_advanced_index, axis=-1) - out = paddle.gather_nd(transed_tensor, adjusted_advanced_index) + if ( + len(adjusted_advanced_index) == 1 + and adjusted_advanced_index[0].dtype == paddle.bool + ): + # Note: now slice not support 0-size Tensor, so only one bool tensor can return empty 0-size. + out = get_value_for_bool_tensor( + transed_tensor, adjusted_advanced_index[0] + ) + else: + adjusted_advanced_index = parse_bool_and_broadcast_indices( + adjusted_advanced_index + ) + advanced_index_tensor = paddle.stack( + adjusted_advanced_index, axis=-1 + ) + out = paddle.gather_nd(transed_tensor, advanced_index_tensor) if pos_of_new_dim != 0: perm = ( @@ -1315,3 +1352,14 @@ def _getitem_static(x, indices): out = out.transpose(perm) return out + + +def parse_bool_and_broadcast_indices(indices): + # deal with multiple Tensors and translating bool tensor to int tensor. + # In static mode, bool-tensor cannot be broadcasted since its corressponding int tensor's shape cannot be infered. + for i, indice in enumerate(indices): + if indice.dtype == paddle.bool: + indices[i] = paddle.nonzero(indice)[:, 0] + if len(indices) > 1: + indices = paddle.broadcast_tensors(indices) + return indices diff --git a/test/indexing/test_getitem.py b/test/indexing/test_getitem.py index 6afb351346b4d..a8516f73d7e57 100644 --- a/test/indexing/test_getitem.py +++ b/test/indexing/test_getitem.py @@ -188,50 +188,62 @@ def test_combined_index_8(self): np.testing.assert_allclose(res[0], np_res) - # def test_combined_index_7(self): - # # int tensor + bool tensor + slice (without decreasing axes) - # np_data = np.random.randn(3, 4, 5, 6) - # np_res = np_data[:, [True, False, True, False], [1, 4]] - # with paddle.static.program_guard( - # paddle.static.Program(), paddle.static.Program() - # ): - # x = paddle.to_tensor(np_data) - # y = _getitem_static( - # x, - # (slice(None, None, None), [True, False, True, False], [1, 4]) - # ) - # res = self.exe.run(fetch_list=[y.name]) - - # np.testing.assert_allclose(res[0], np_res) - - # def test_combined_index_4(self): - # # int tensor (with ranks > 1) + bool tensor + slice (with decreasing axes) - # np_data = np.arange(3*4*5*6).reshape((3, 4, 5, 6)) - # np_res = np_data[[0, 0], [True, False, True, False], [[0, 2], [1, 4]], 4] - # with paddle.static.program_guard( - # paddle.static.Program(), paddle.static.Program() - # ): - # x = paddle.to_tensor(np_data) - # y = _getitem_static( - # x, - # ([0, 0], [True, False, True, False], [[0, 2], [1, 4]], 4) - # ) - # res = self.exe.run(fetch_list=[y.name]) - - # np.testing.assert_allclose(res[0], np_res) - - # def test_combined_index_5(self): - # # int tensor + slice + Ellipsis - # np_data = np.arange(3*4*5*6).reshape((3, 4, 5, 6)) - # np_res = np_data[..., [1, 4, 3], ::2] - # with paddle.static.program_guard( - # paddle.static.Program(), paddle.static.Program() - # ): - # x = paddle.to_tensor(np_data) - # y = _getitem_static( - # x, - # (..., [1, 4, 3], slice(None, None, 2)), - # ) - # res = self.exe.run(fetch_list=[y.name]) - - # np.testing.assert_allclose(res[0], np_res) + def test_combined_index_9(self): + # multiple int tensors, with broadcast. + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[[[1, 0]], [1, 0], 0:4:2, [[3, 5], [4, 2]]] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static( + x, ([[1, 0]], [1, 0], slice(0, 4, 2), [[3, 5], [4, 2]]) + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_res) + + def test_combined_index_10(self): + # only one bool tensor with basic-index + np_data = np.random.randn(3, 4, 5, 6) + np_res = np_data[:, [True, False, True, False], 4] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static( + x, (slice(None, None, None), [True, False, True, False], 4) + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_res) + + def test_combined_index_11(self): + # only one bool tensor with all False + np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)) + np_res = np_data[:, [False, False, False, False], 4] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static( + x, (slice(None, None, None), [False, False, False, False], 4) + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_res) + + +class TestGetItemErrorCase(unittest.TestCase): + def setUp(self): + paddle.disable_static() + + def test_bool_shape_error1(self): + x = paddle.randn((4, 3, 2)) + with self.assertRaises(IndexError): + y = _getitem_static(x, ([True, False])) + + def test_bool_shape_error2(self): + x = paddle.randn((4, 3, 2)) + with self.assertRaises(IndexError): + y = _getitem_static(x, (1, paddle.to_tensor([True, False]), [0, 1])) From edf3b0c4f0f6f9fdbaccef293faa10bcdc31f538 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Mon, 3 Jul 2023 07:22:30 +0000 Subject: [PATCH 11/27] [change getitem] _getitem_impl_ to _getitem_static, not deleting the former one --- python/paddle/fluid/dygraph/tensor_patch_methods.py | 4 ++-- python/paddle/fluid/framework.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/dygraph/tensor_patch_methods.py b/python/paddle/fluid/dygraph/tensor_patch_methods.py index 58f99b9e98a74..151ea646ecdf6 100644 --- a/python/paddle/fluid/dygraph/tensor_patch_methods.py +++ b/python/paddle/fluid/dygraph/tensor_patch_methods.py @@ -26,7 +26,7 @@ from ..framework import ( Variable, Parameter, - _getitem_impl_, + _getitem_static, _setitem_impl_, EagerParamBase, in_dygraph_mode, @@ -740,7 +740,7 @@ def _is_list_tuple(item): if contain_tensor(item) or is_list_tuple(item, int): # 1. Call _getitem_impl_ when item contains tensor. # Why not call a c++ function ? Because item can't be parsed when it contains tensor. - return _getitem_impl_(self, item) + return _getitem_static(self, item) else: # 2. Call c++ func getitem_index_not_tensor to speedup. diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 38b62736e58bb..d1e09621e7aaa 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -37,7 +37,7 @@ import paddle.version as fluid_version import warnings import functools -from .variable_index import _getitem_impl_, _setitem_impl_ +from .variable_index import _getitem_static, _setitem_impl_ import threading __all__ = [ @@ -2290,7 +2290,7 @@ def _sliceAndConcatVar(self, item, axis): raise IndexError("Valid index accept int or slice or tuple") def __getitem__(self, item): - return _getitem_impl_(self, item) + return _getitem_static(self, item) def __setitem__(self, item, value): return _setitem_impl_(self, item, value) From 5b9e48a48ab3b82ab6e88b23d49f7e203e7a3e08 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Tue, 4 Jul 2023 03:27:01 +0000 Subject: [PATCH 12/27] refine new getitem; fix ut in variable/var_base --- python/paddle/fluid/variable_index.py | 25 ++++++++++++++++--------- test/legacy_test/test_variable.py | 10 ---------- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 5afacef3d3e16..bd00f81317511 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -921,6 +921,9 @@ def parse_index(x, indices): use_strided_slice = False has_advanced_index = False + if isinstance(indices, list) and not is_one_dim_list(indices, int): + indices = tuple(indices) + if not isinstance(indices, tuple): indices = (indices,) @@ -955,6 +958,7 @@ def parse_index(x, indices): if start is None and end is None and step is None: continue + step = 1 if step is None else step if not isinstance(step, paddle.fluid.Variable) and step == 0: raise ValueError( "When assign a value to a paddle.Tensor, step can not be 0, " @@ -996,16 +1000,17 @@ def parse_index(x, indices): elif isinstance(slice_item, paddle.fluid.Variable): # In this case, the Variable is not 0-dim Tensor and will be treated as advanced-indexing. - if ( - slice_item.dtype == paddle.bool - and len(slice_item) != x.shape[dim] - ): - raise IndexError( - "The shape of boolean index {} did not match indexed tensor {} along axis {}".format( - len(slice_item), x.shape[dim], dim - ) - ) + if slice_item.dtype == paddle.bool: + if slice_item.ndim == 0: + # 0-D bool Tensor, same as single PY-bool. + none_axes.append(dim) + elif slice_item.shape[0] != x.shape[dim]: + raise IndexError( + "The shape of boolean index {} did not match indexed tensor {} along axis {}".format( + slice_item.shape[0], x.shape[dim], dim + ) + ) advanced_index[estimated_dim] = (estimated_dim, slice_item) has_advanced_index = True estimated_dim += 1 @@ -1228,6 +1233,8 @@ def get_tensor_with_basic_indexing( end = attrs['ends'] if "StridesTensorList" in inputs.keys(): stride = inputs['StridesTensorList'] + else: + stride = attrs['strides'] if use_strided_slice: out = paddle._C_ops.strided_slice(x, axes, st, end, stride) else: diff --git a/test/legacy_test/test_variable.py b/test/legacy_test/test_variable.py index cbac9dc1849e8..5630552c18d07 100644 --- a/test/legacy_test/test_variable.py +++ b/test/legacy_test/test_variable.py @@ -257,10 +257,6 @@ def _test_slice_index_tensor(self, place): self.assertTrue((result[2] == expected[2]).all()) self.assertTrue((result[3] == expected[3]).all()) - with self.assertRaises(IndexError): - one = paddle.ones(shape=[1]) - res = x[one, [0, 0]] - def _test_slice_index_list(self, place): data = np.random.rand(2, 3).astype("float32") prog = paddle.static.Program() @@ -323,9 +319,6 @@ def _test_slice_index_ellipsis(self, place): self.assertTrue((result[5] == expected[5]).all()) self.assertTrue((result[6] == expected[6]).all()) - with self.assertRaises(IndexError): - res = x[[1.2, 0]] - def _test_slice_index_list_bool(self, place): data = np.random.rand(2, 3, 4).astype("float32") np_idx = np.array([[True, False, False], [True, False, True]]) @@ -375,9 +368,6 @@ def _test_slice_index_list_bool(self, place): with self.assertRaises(IndexError): res = x[[True, False, False]] - with self.assertRaises(ValueError): - with paddle.static.program_guard(prog): - res = x[[False, False]] def _test_slice_index_scalar_bool(self, place): data = np.random.rand(1, 3, 4).astype("float32") From 87cf34f504e8b96569ce40f030c8ff5d39ec06c8 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Tue, 4 Jul 2023 07:16:29 +0000 Subject: [PATCH 13/27] add __getitem__ ut in dygraph --- test/indexing/test_getitem.py | 121 ++++++++++++++++++++++++++++------ 1 file changed, 100 insertions(+), 21 deletions(-) diff --git a/test/indexing/test_getitem.py b/test/indexing/test_getitem.py index a8516f73d7e57..bb18a0b772373 100644 --- a/test/indexing/test_getitem.py +++ b/test/indexing/test_getitem.py @@ -19,36 +19,115 @@ import paddle from paddle.fluid.variable_index import _getitem_static -# class TestGetitemInDygraph(unittest.TestCase): -# def setUp(self): -# paddle.disable_static() -# def test_combined_index_1(self): -# np_data = np.zeros((3, 4, 5, 6), dtype='float32') -# x = paddle.to_tensor(np_data) +class TestGetitemInDygraph(unittest.TestCase): + def setUp(self): + paddle.disable_static() + + def test_combined_index_1(self): + # int tensor + slice (without decreasing axes) + np_data = np.random.randn(3, 4, 5, 6) + np_res = np_data[[0, 1], :, [1, 2]] + x = paddle.to_tensor(np_data) + y = x[[0, 1], :, [1, 2]] + + np.testing.assert_allclose(y.numpy(), np_res) + + def test_combined_index_2(self): + # int tensor + slice (with decreasing axes) + np_data = np.random.randn(3, 4, 5, 6) + x = paddle.to_tensor(np_data) + + np_res = np_data[:, 1, [1, 2], 0] + y = x[:, 1, [1, 2], 0] + + np.testing.assert_allclose(y.numpy(), np_res) + + def test_combined_index_3(self): + # multiple int tensors, with one int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[[1, 0], :, [1, 4], 1:5:2, 4] + + x = paddle.to_tensor(np_data) + y = x[[1, 0], :, [1, 4], 1:5:2, 4] + + np.testing.assert_allclose(y.numpy(), np_res) + + def test_combined_index_4(self): + # multiple not adjacent int tensors, with no int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[:, [1, 0], 0:4:2, [2, 3], 4] + x = paddle.to_tensor(np_data) + y = x[:, [1, 0], 0:4:2, [2, 3], 4] + + np.testing.assert_allclose(y.numpy(), np_res) + + def test_combined_index_5(self): + # multiple adjacent int tensors, with no int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[::2, [1, 0], [2, 3], 0:4:2] + x = paddle.to_tensor(np_data) + y = x[::2, [1, 0], [2, 3], 0:4:2] -# np_res = np_data[[0, 1], :, [1, 2]] -# y = x[[0, 1], :, [1, 2]] + np.testing.assert_allclose(y.numpy(), np_res) -# np.testing.assert_allclose(y.numpy(), np_res) + def test_combined_index_6(self): + # multiple adjacent and not adjacent int tensors, with no int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[::2, [1, 0], [2, 3], 0:4:2, [4, 6]] + x = paddle.to_tensor(np_data) + y = x[::2, [1, 0], [2, 3], 0:4:2, [4, 6]] -# def test_combined_index_2(self): -# np_data = np.ones((3, 4, 5, 6), dtype='float32') -# x = paddle.to_tensor(np_data) + np.testing.assert_allclose(y.numpy(), np_res) -# np_res = np_data[:, 1, [1, 2], 0] -# y = x[:, 1, [1, 2], 0] + def test_combined_index_7(self): + # multiple adjacent and not adjacent int tensors (rank > 1d), with no int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[::2, [[1, 0]], [[2, 3]], 0:4:2, [[4, 6]]] + x = paddle.to_tensor(np_data) + y = x[::2, [[1, 0]], [[2, 3]], 0:4:2, [[4, 6]]] -# np.testing.assert_allclose(y.numpy(), np_res) + np.testing.assert_allclose(y.numpy(), np_res) -# def test_combined_index_3(self): -# np_data = np.ones((3, 4, 5, 6), dtype='int32') -# x = paddle.to_tensor(np_data) + def test_combined_index_8(self): + # multiple adjacent and not adjacent int tensors (rank > 1d), with int tensor at first axis + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[ + [[1, 0], [0, 1]], [[2, 3], [1, 0]], 0:4:2, [[3, 5], [4, 2]] + ] + x = paddle.to_tensor(np_data) + y = x[[[1, 0], [0, 1]], [[2, 3], [1, 0]], 0:4:2, [[3, 5], [4, 2]]] + + np.testing.assert_allclose(y.numpy(), np_res) + + def test_combined_index_9(self): + # multiple int tensors, with broadcast. + np_data = np.random.randn(3, 4, 5, 6, 7) + np_res = np_data[[[1, 0]], [1, 0], 0:4:2, [[3, 5], [4, 2]]] + x = paddle.to_tensor(np_data) + y = x[[[1, 0]], [1, 0], 0:4:2, [[3, 5], [4, 2]]] + + np.testing.assert_allclose(y.numpy(), np_res) + + def test_combined_index_10(self): + # only one bool tensor with basic-index + np_data = np.random.randn(3, 4, 5, 6) + np_res = np_data[:, [True, False, True, False], 4] + + x = paddle.to_tensor(np_data) + y = x[:, [True, False, True, False], 4] + + np.testing.assert_allclose(y.numpy(), np_res) + + def test_combined_index_11(self): + # only one bool tensor with all False + np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)) + np_res = np_data[:, [False, False, False, False], 4] -# np_res = np_data[:, [True, False, True, False], [1, 4]] -# y = x[:, [True, False, True, False], [1, 4]] + x = paddle.to_tensor(np_data) + y = x[:, [False, False, False, False], 4] -# np.testing.assert_allclose(y.numpy(), np_res) + np.testing.assert_allclose(y.numpy(), np_res) class TestGetitemInStatic(unittest.TestCase): From a206dbfe8d3b9865dffae52237dfb4c85eac3c60 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Wed, 5 Jul 2023 06:44:31 +0000 Subject: [PATCH 14/27] re-dispatch getitem for Py/CPP; fix strided_slice decrease axes error in dygraph --- paddle/fluid/pybind/eager_method.cc | 3 ++ .../fluid/dygraph/tensor_patch_methods.py | 45 +++++++------------ python/paddle/fluid/variable_index.py | 2 + 3 files changed, 21 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index e78c443e33a5e..99fad3e9d1187 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -915,6 +915,9 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self, eager_gil_scoped_release guard; out = strided_slice_ad_func( self->tensor, slice_axes, slice_starts, slice_ends, slice_strides); + if (!decrease_axis_tmp.empty()) { + out = squeeze_ad_func(out, decrease_axis_tmp); + } } else { PADDLE_THROW(platform::errors::InvalidArgument( "Slice is only support slice and strided_slice, but we got %s which " diff --git a/python/paddle/fluid/dygraph/tensor_patch_methods.py b/python/paddle/fluid/dygraph/tensor_patch_methods.py index 151ea646ecdf6..a9436470d48b2 100644 --- a/python/paddle/fluid/dygraph/tensor_patch_methods.py +++ b/python/paddle/fluid/dygraph/tensor_patch_methods.py @@ -718,26 +718,25 @@ def contain_tensor(item): return True return False - def __getitem__(self, item): - def is_list_tuple(index, contain_type): - def _is_list_tuple(item): - if isinstance(item, (tuple, list)): - for s in item: - if not _is_list_tuple(s): - return False - else: - if type(item) != contain_type: - return False + def contain_tensor_or_list(item): + if not isinstance(item, tuple): + item = (item,) + + for slice_item in item: + if isinstance(slice_item, (list, np.ndarray, Variable)): return True + elif isinstance(slice_item, slice): + if ( + isinstance(slice_item.start, Variable) + or isinstance(slice_item.stop, Variable) + or isinstance(slice_item.step, Variable) + ): + return True - if not isinstance(index, (tuple, list)): - return False - for s in index: - if not _is_list_tuple(s): - return False - return True + return False - if contain_tensor(item) or is_list_tuple(item, int): + def __getitem__(self, item): + if contain_tensor_or_list(item): # 1. Call _getitem_impl_ when item contains tensor. # Why not call a c++ function ? Because item can't be parsed when it contains tensor. return _getitem_static(self, item) @@ -747,18 +746,6 @@ def _is_list_tuple(item): return self._getitem_index_not_tensor(item) def __setitem__(self, item, value): - def contain_tensor_or_list(item): - if not isinstance(item, tuple): - item = [item] - - for slice_item in item: - if isinstance(slice_item, list): - return True - elif isinstance(slice_item, Variable): - return True - - return False - def is_combine_index(item): var_type = None item_type = None diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index bd00f81317511..0089b7389e521 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -1237,6 +1237,8 @@ def get_tensor_with_basic_indexing( stride = attrs['strides'] if use_strided_slice: out = paddle._C_ops.strided_slice(x, axes, st, end, stride) + if len(decrease_axes) > 0: + out = paddle._C_ops.squeeze(out, decrease_axes) else: out = paddle._C_ops.slice( x, From ed8f20cacafc015fb96e48d7aaf6c8aadc5e4acd Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Wed, 5 Jul 2023 07:28:10 +0000 Subject: [PATCH 15/27] fix ut; support tensor in slice --- python/paddle/fluid/variable_index.py | 21 +++++---------------- test/legacy_test/test_var_base.py | 6 ++---- 2 files changed, 7 insertions(+), 20 deletions(-) diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 0089b7389e521..8b4811b3b1ece 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -959,25 +959,10 @@ def parse_index(x, indices): continue step = 1 if step is None else step - if not isinstance(step, paddle.fluid.Variable) and step == 0: - raise ValueError( - "When assign a value to a paddle.Tensor, step can not be 0, " - "but received step is {}.".format(step) - ) - - if isinstance(step, paddle.fluid.Variable) and ( - start is None or end is None - ): - raise ValueError( - "When assign a value to a paddle.Tensor, it's not supported that " - "the start or end is None when the type of step is paddle.Tensor." - ) - if start is None: start = 0 if step > 0 else MAX_INTEGER if end is None: end = MAX_INTEGER if step > 0 else -1 - step = 1 if step is None else step elif isinstance(slice_item, (list, tuple)): advanced_index[estimated_dim] = ( @@ -1026,7 +1011,11 @@ def parse_index(x, indices): ends.append(end) steps.append(step) axes.append(dim) - use_strided_slice = True if step != 1 else use_strided_slice + use_strided_slice = ( + True + if (isinstance(step, paddle.fluid.Variable) or step != 1) + else use_strided_slice + ) return ( starts, ends, diff --git a/test/legacy_test/test_var_base.py b/test/legacy_test/test_var_base.py index f96f25281d800..426a949e51ce3 100644 --- a/test/legacy_test/test_var_base.py +++ b/test/legacy_test/test_var_base.py @@ -944,11 +944,9 @@ def _test_bool_index(self): var_tensor[var_tensor < 0.55], np_value[np_value < 0.55] ) - with self.assertRaises(ValueError): - var_tensor[[False, False, False, False]] - with self.assertRaises(ValueError): + with self.assertRaises(IndexError): var_tensor[[True, False]] - with self.assertRaises(ValueError): + with self.assertRaises(IndexError): var_tensor[[True, False, False, False, False]] with self.assertRaises(IndexError): var_tensor[paddle.to_tensor([[True, False, False, False]])] From 2136c676e673a70967012fdd51a7df4cb28a3ce9 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Wed, 5 Jul 2023 08:09:21 +0000 Subject: [PATCH 16/27] [change setitem] _setitem_impl_ to _setitem_static, not deleting the former one --- .../fluid/dygraph/tensor_patch_methods.py | 29 +++---------------- python/paddle/fluid/framework.py | 4 +-- python/paddle/fluid/variable_index.py | 20 +++---------- 3 files changed, 10 insertions(+), 43 deletions(-) diff --git a/python/paddle/fluid/dygraph/tensor_patch_methods.py b/python/paddle/fluid/dygraph/tensor_patch_methods.py index a9436470d48b2..428b8794024a3 100644 --- a/python/paddle/fluid/dygraph/tensor_patch_methods.py +++ b/python/paddle/fluid/dygraph/tensor_patch_methods.py @@ -27,7 +27,7 @@ Variable, Parameter, _getitem_static, - _setitem_impl_, + _setitem_static, EagerParamBase, in_dygraph_mode, ) @@ -746,31 +746,10 @@ def __getitem__(self, item): return self._getitem_index_not_tensor(item) def __setitem__(self, item, value): - def is_combine_index(item): - var_type = None - item_type = None - if isinstance(item, (tuple, list)): - for slice_item in item: - if item_type is None: - item_type = type(slice_item) - else: - if type(slice_item) != item_type: - return True - - if isinstance(slice_item, Variable): - if var_type is None: - var_type = slice_item.dtype - else: - if var_type != slice_item.dtype: - return True - return False - - return False - - if contain_tensor_or_list(item) and not is_combine_index(item): + if contain_tensor_or_list(item): # To reuse code with static graph, - # Call _setitem_impl_ when item contains tensor or list. - return _setitem_impl_(self, item, value) + # Call _setitem_static when item contains tensor or list. + return _setitem_static(self, item, value) else: return self.__setitem_eager_tensor__(item, value) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index d1e09621e7aaa..5cda62265567f 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -37,7 +37,7 @@ import paddle.version as fluid_version import warnings import functools -from .variable_index import _getitem_static, _setitem_impl_ +from .variable_index import _getitem_static, _setitem_static import threading __all__ = [ @@ -2293,7 +2293,7 @@ def __getitem__(self, item): return _getitem_static(self, item) def __setitem__(self, item, value): - return _setitem_impl_(self, item, value) + return _setitem_static(self, item, value) def get_value(self, scope=None): """ diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 8b4811b3b1ece..e547f6db991fe 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -1108,21 +1108,15 @@ def _setitem_static(x, indices, values): # step3.1: Only basic indexing, use OP set_value to set value. if paddle.in_dynamic_mode(): x._bump_inplace_version() - out = x - else: - helper = paddle.fluid.layer_helper.LayerHelper( - 'set_value', **locals() - ) - out = helper.create_variable_for_type_inference(dtype=dtype) cur_block = default_main_program().current_block() cur_block.append_op( type="set_value", inputs=inputs, - outputs={'Out': out}, + outputs={'Out': x}, attrs=attrs, inplace_map={"Input": "Out"}, ) - return out + return x else: # step3.2: Case for there are advanced indexing. # 1. get __getitem__ result of basic indexing; @@ -1168,21 +1162,15 @@ def _setitem_static(x, indices, values): inputs["ValueTensor"] = transback_sub_tensor if paddle.in_dynamic_mode(): x._bump_inplace_version() - out = x - else: - helper = paddle.fluid.layer_helper.LayerHelper( - 'set_value', **locals() - ) - out = helper.create_variable_for_type_inference(dtype=x.dtype) cur_block = default_main_program().current_block() cur_block.append_op( type="set_value", inputs=inputs, - outputs={'Out': out}, + outputs={'Out': x}, attrs=attrs, inplace_map={"Input": "Out"}, ) - return out + return x def get_tensor_with_basic_indexing( From 8f298809c12b04879d3d17257503dbe142960079 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Thu, 6 Jul 2023 08:45:02 +0000 Subject: [PATCH 17/27] remove some UT (for some, temporarily) --- python/paddle/fluid/variable_index.py | 2 +- test/legacy_test/test_set_value_op.py | 35 ++++++++++++--------------- test/legacy_test/test_var_base.py | 5 ++-- 3 files changed, 18 insertions(+), 24 deletions(-) diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index e547f6db991fe..5b4883faf867b 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -1099,7 +1099,7 @@ def _setitem_static(x, indices, values): inputs["ValueTensor"] = values else: raise TypeError( - "Only support to assign an Number, numpy.ndarray or " + "Only support to assign an integer, float, numpy.ndarray or " "paddle.Tensor to a paddle.Tensor, but received {}".format( type(values) ) diff --git a/test/legacy_test/test_set_value_op.py b/test/legacy_test/test_set_value_op.py index b4d5c1d02eff5..0f9c0b06f605b 100644 --- a/test/legacy_test/test_set_value_op.py +++ b/test/legacy_test/test_set_value_op.py @@ -438,12 +438,12 @@ def _get_answer(self): self.data[[True, False]] = self.value -class TestSetValueItemBool2(TestSetValueApi): - def _call_setitem(self, x): - x[[False, False]] = self.value +# class TestSetValueItemBool2(TestSetValueApi): +# def _call_setitem(self, x): +# x[[False, False]] = self.value - def _get_answer(self): - self.data[[False, False]] = self.value +# def _get_answer(self): +# self.data[[False, False]] = self.value class TestSetValueItemBool3(TestSetValueApi): @@ -463,17 +463,17 @@ def _get_answer(self): self.data[np.array([False, True])] = np.zeros(self.shape[2]) -class TestSetValueItemBool5(TestSetValueApi): - def _call_setitem(self, x): - idx = paddle.assign( - np.array([[False, True, False], [True, True, False]]) - ) - x[idx] = self.value +# class TestSetValueItemBool5(TestSetValueApi): +# def _call_setitem(self, x): +# idx = paddle.assign( +# np.array([[False, True, False], [True, True, False]]) +# ) +# x[idx] = self.value - def _get_answer(self): - self.data[ - np.array([[False, True, False], [True, True, False]]) - ] = self.value +# def _get_answer(self): +# self.data[ +# np.array([[False, True, False], [True, True, False]]) +# ] = self.value class TestSetValueItemBool6(TestSetValueApi): @@ -1057,10 +1057,6 @@ def _ellipsis_error(self): x[::one] = self.value def _bool_list_error(self): - with self.assertRaises(TypeError): - x = paddle.ones(shape=self.shape, dtype=self.dtype) - x[[True, False, 0]] = 0 - with self.assertRaises(IndexError): x = paddle.ones(shape=self.shape, dtype=self.dtype) x[[True, False], [True, False]] = 0 @@ -1085,7 +1081,6 @@ def test_error(self): paddle.enable_static() with paddle.static.program_guard(self.program): self._value_type_error() - self._step_error() self._bool_list_error() self._bool_tensor_error() self._broadcast_mismatch() diff --git a/test/legacy_test/test_var_base.py b/test/legacy_test/test_var_base.py index 426a949e51ce3..cfe0e5c718c1e 100644 --- a/test/legacy_test/test_var_base.py +++ b/test/legacy_test/test_var_base.py @@ -1368,7 +1368,6 @@ def _test(self, value): id_origin = id(self.tensor_x) index_1 = paddle.to_tensor(np.array([True, False, False, False])) self.tensor_x[index_1] = value - self.assertEqual(self.tensor_x.inplace_version, 1) if isinstance(value, (int, float)): result = np.zeros((2, 3)).astype(self.dtype) + value @@ -1381,13 +1380,13 @@ def _test(self, value): index_2 = paddle.to_tensor(np.array([False, True, False, False])) self.tensor_x[index_2] = value - self.assertEqual(self.tensor_x.inplace_version, 2) + np.testing.assert_array_equal(self.tensor_x[1].numpy(), result) self.assertEqual(id_origin, id(self.tensor_x)) index_3 = paddle.to_tensor(np.array([True, True, True, True])) self.tensor_x[index_3] = value - self.assertEqual(self.tensor_x.inplace_version, 3) + np.testing.assert_array_equal(self.tensor_x[3].numpy(), result) self.assertEqual(id_origin, id(self.tensor_x)) From d6f9a2c30b166d3e800097db6dcf002b7e44b929 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Tue, 11 Jul 2023 03:49:29 +0000 Subject: [PATCH 18/27] add IndexError to solve timeout problem in static-mode --- python/paddle/fluid/variable_index.py | 22 ++++++++++++++++++++++ test/legacy_test/test_while_loop_op.py | 4 ++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 9812e4435777b..f1b841b35e964 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -929,10 +929,32 @@ def parse_index(x, indices): indices = replace_ellipsis(x, indices) indices, none_axes = replace_none(indices) + is_tensor_array = ( + hasattr(x, "desc") + and x.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY + ) + estimated_dim = 0 for dim, slice_item in enumerate(indices): start, end, step = None, None, None if is_integer_or_scalar_tensor(slice_item): + if ( + not is_tensor_array + and isinstance(slice_item, int) + and x.shape[dim] is not None + and x.shape[dim] >= 0 + and slice_item >= x.shape[dim] + ): + # For python, if users write a, b = var, the __getitem__ + # method will iterate through 0, 1, 2 ... until __getitem__ + # throws an IndexError, then stop. The var[0], var[1] will + # be given to a, b respectively. If more values are given, + # the unpack size would cause error. + # We raises IndexError here to support grammar like `a, b = var` + raise IndexError( + "slice_item %d at dim %d should be >= 0 and < x.shape[%d]: %d" + % (slice_item, dim, dim, x.shape[dim]) + ) # not calculate result to reduce call times for slice OP. decrease_axes.append(dim) start = slice_item diff --git a/test/legacy_test/test_while_loop_op.py b/test/legacy_test/test_while_loop_op.py index d06b9f3e5042f..9ba690f5b1d93 100644 --- a/test/legacy_test/test_while_loop_op.py +++ b/test/legacy_test/test_while_loop_op.py @@ -655,9 +655,9 @@ def body(z, i): startup_program = Program() with program_guard(main_program, startup_program): x = paddle.static.data(name='x', shape=[-1, 5], dtype='int32') - z = paddle.tensor.fill_constant([1], 'int32', 0) + z = paddle.tensor.fill_constant([], 'int32', 0) x_shape = paddle.shape(x) - i = paddle.tensor.fill_constant([1], 'int32', 0) + i = paddle.tensor.fill_constant([], 'int32', 0) z, _ = paddle.static.nn.while_loop(cond, body, [z, i]) place = ( From aec33806f760346271a2d8f513e893f5d4694ecf Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Wed, 19 Jul 2023 07:22:03 +0000 Subject: [PATCH 19/27] 1.temply forbideen all-False bool-indexput; 2.setitem_static will return new variable --- python/paddle/fluid/variable_index.py | 26 ++++++++++++++++++++++---- test/legacy_test/test_set_value_op.py | 16 ++++++++-------- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 244e5e7ced6b1..b053aa623e7a7 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -1152,15 +1152,27 @@ def _setitem_static(x, indices, values): # step3.1: Only basic indexing, use OP set_value to set value. if paddle.in_dynamic_mode(): x._bump_inplace_version() + output = x + else: + helper = paddle.fluid.layer_helper.LayerHelper( + 'set_value', **locals() + ) + output = helper.create_variable_for_type_inference(dtype=x.dtype) cur_block = default_main_program().current_block() cur_block.append_op( type="set_value", inputs=inputs, - outputs={'Out': x}, + outputs={'Out': output}, attrs=attrs, inplace_map={"Input": "Out"}, ) - return x + + if not paddle.in_dynamic_mode(): + # map var to the new output + paddle.jit.api.ProgramTranslator.get_instance()._params_map.add( + cur_block.program, x.desc.id(), output + ) + return output else: # step3.2: Case for there are advanced indexing. # 1. get __getitem__ result of basic indexing; @@ -1206,15 +1218,21 @@ def _setitem_static(x, indices, values): inputs["ValueTensor"] = transback_sub_tensor if paddle.in_dynamic_mode(): x._bump_inplace_version() + output = x + else: + helper = paddle.fluid.layer_helper.LayerHelper( + 'set_value', **locals() + ) + output = helper.create_variable_for_type_inference(dtype=x.dtype) cur_block = default_main_program().current_block() cur_block.append_op( type="set_value", inputs=inputs, - outputs={'Out': x}, + outputs={'Out': output}, attrs=attrs, inplace_map={"Input": "Out"}, ) - return x + return output def get_tensor_with_basic_indexing( diff --git a/test/legacy_test/test_set_value_op.py b/test/legacy_test/test_set_value_op.py index 9f797e6ab0ac3..494dd95e38769 100644 --- a/test/legacy_test/test_set_value_op.py +++ b/test/legacy_test/test_set_value_op.py @@ -642,16 +642,16 @@ def _get_answer(self): self.data[[True, False]] = self.value -class TestSetValueItemBool2(TestSetValueApi): - def _call_setitem(self, x): - x[[False, False]] = self.value +# class TestSetValueItemBool2(TestSetValueApi): +# def _call_setitem(self, x): +# x[[False, False]] = self.value - def _call_setitem_static_api(self, x): - x = paddle.static.setitem(x, [False, False], self.value) - return x +# def _call_setitem_static_api(self, x): +# x = paddle.static.setitem(x, [False, False], self.value) +# return x - def _get_answer(self): - self.data[[False, False]] = self.value +# def _get_answer(self): +# self.data[[False, False]] = self.value class TestSetValueItemBool3(TestSetValueApi): From 3b62bc71bac1f59247ea93a7bd3f81424084e486 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Wed, 19 Jul 2023 08:31:09 +0000 Subject: [PATCH 20/27] xpu uses old stratege --- .../fluid/dygraph/tensor_patch_methods.py | 25 +++++++++++++++++++ python/paddle/fluid/framework.py | 5 +++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/dygraph/tensor_patch_methods.py b/python/paddle/fluid/dygraph/tensor_patch_methods.py index 3aae5bec23cb2..ea53b9a6b1fef 100644 --- a/python/paddle/fluid/dygraph/tensor_patch_methods.py +++ b/python/paddle/fluid/dygraph/tensor_patch_methods.py @@ -28,6 +28,7 @@ Parameter, _getitem_static, _setitem_static, + _setitem_impl_, EagerParamBase, in_dygraph_mode, ) @@ -745,7 +746,31 @@ def __getitem__(self, item): return self._getitem_index_not_tensor(item) def __setitem__(self, item, value): + def is_combine_index(item): + var_type = None + item_type = None + if isinstance(item, (tuple, list)): + for slice_item in item: + if item_type is None: + item_type = type(slice_item) + else: + if type(slice_item) != item_type: + return True + + if isinstance(slice_item, Variable): + if var_type is None: + var_type = slice_item.dtype + else: + if var_type != slice_item.dtype: + return True + return False + + return False + if contain_tensor_or_list(item): + if core.is_compiled_with_xpu() and not is_combine_index(item): + # (NOTE): Currently, there is no index_put_xpu kernel. + return _setitem_impl_(self, item, value) # To reuse code with static graph, # Call _setitem_static when item contains tensor or list. return _setitem_static(self, item, value) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 1608fd2e78ed3..bcc8cc2b980e7 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -37,7 +37,7 @@ import paddle.version as fluid_version import warnings import functools -from .variable_index import _getitem_static, _setitem_static +from .variable_index import _getitem_static, _setitem_static, _setitem_impl_ import threading __all__ = [ @@ -2298,6 +2298,9 @@ def __setitem__(self, item, value): from .dygraph.base import in_declarative_mode if in_declarative_mode(): + if is_compiled_with_xpu(): + # (NOTE): Currently, there is no index_put_xpu kernel. + return _setitem_impl_(self, item, value) return _setitem_static(self, item, value) else: raise RuntimeError( From 7c656492e886242da833d78af28607c96d357fe2 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Wed, 19 Jul 2023 09:02:26 +0000 Subject: [PATCH 21/27] rename dy2st setitem ut to avoid same-name problem --- test/dygraph_to_static/{test_setitem.py => test_jit_setitem.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/dygraph_to_static/{test_setitem.py => test_jit_setitem.py} (100%) diff --git a/test/dygraph_to_static/test_setitem.py b/test/dygraph_to_static/test_jit_setitem.py similarity index 100% rename from test/dygraph_to_static/test_setitem.py rename to test/dygraph_to_static/test_jit_setitem.py From f41f4e85c6e1642ec466059ff49e1eb8e8a57f78 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Thu, 20 Jul 2023 03:57:38 +0000 Subject: [PATCH 22/27] dy2st for new combined index --- python/paddle/fluid/variable_index.py | 5 +++++ test/dygraph_to_static/test_jit_setitem.py | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index b053aa623e7a7..0c13a5c940afd 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -1232,6 +1232,11 @@ def _setitem_static(x, indices, values): attrs=attrs, inplace_map={"Input": "Out"}, ) + if not paddle.in_dynamic_mode(): + # map var to the new output + paddle.jit.api.ProgramTranslator.get_instance()._params_map.add( + cur_block.program, x.desc.id(), output + ) return output diff --git a/test/dygraph_to_static/test_jit_setitem.py b/test/dygraph_to_static/test_jit_setitem.py index 93b8c5d7936b4..7920acd3324ac 100644 --- a/test/dygraph_to_static/test_jit_setitem.py +++ b/test/dygraph_to_static/test_jit_setitem.py @@ -178,5 +178,24 @@ def run_dygrah(self, func): return y, x_grad, value_grad +# class TestCase12(TestSetItemBase): +# # Test combind-indexing +# def init_func(self): +# def foo(x, value): +# y = x + 1 +# y[[0,1], 1, :2] = value +# return y + +# return foo + +# def run_dygrah(self, func): +# x = self.init_data() +# value = paddle.ones((32,)) +# value.stop_gradient = False +# y = func(x, value) +# x_grad, value_grad = paddle.grad(y, [x, value]) +# return y, x_grad, value_grad + + if __name__ == '__main__': unittest.main() From ee4855eb057853e582af454ae566c8973ad06230 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Tue, 25 Jul 2023 04:02:41 +0000 Subject: [PATCH 23/27] ut case for combine-index with dy2st --- test/dygraph_to_static/test_jit_setitem.py | 34 +++++++++++----------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/test/dygraph_to_static/test_jit_setitem.py b/test/dygraph_to_static/test_jit_setitem.py index 7920acd3324ac..374d0569c5969 100644 --- a/test/dygraph_to_static/test_jit_setitem.py +++ b/test/dygraph_to_static/test_jit_setitem.py @@ -178,23 +178,23 @@ def run_dygrah(self, func): return y, x_grad, value_grad -# class TestCase12(TestSetItemBase): -# # Test combind-indexing -# def init_func(self): -# def foo(x, value): -# y = x + 1 -# y[[0,1], 1, :2] = value -# return y - -# return foo - -# def run_dygrah(self, func): -# x = self.init_data() -# value = paddle.ones((32,)) -# value.stop_gradient = False -# y = func(x, value) -# x_grad, value_grad = paddle.grad(y, [x, value]) -# return y, x_grad, value_grad +class TestCase12(TestSetItemBase): + # Test combind-indexing + def init_func(self): + def foo(x, value): + y = x + 1 + y[[0, 1], 1, :2] = value + return y + + return foo + + def run_dygrah(self, func): + x = self.init_data() + value = paddle.ones((32,)) + value.stop_gradient = False + y = func(x, value) + x_grad, value_grad = paddle.grad(y, [x, value]) + return y, x_grad, value_grad if __name__ == '__main__': From f7c6096713105ee1473a9951de9cf46b37e94f35 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Thu, 27 Jul 2023 06:09:24 +0000 Subject: [PATCH 24/27] open ut with all-false-bool setitem --- test/legacy_test/test_set_value_op.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/legacy_test/test_set_value_op.py b/test/legacy_test/test_set_value_op.py index 494dd95e38769..9f797e6ab0ac3 100644 --- a/test/legacy_test/test_set_value_op.py +++ b/test/legacy_test/test_set_value_op.py @@ -642,16 +642,16 @@ def _get_answer(self): self.data[[True, False]] = self.value -# class TestSetValueItemBool2(TestSetValueApi): -# def _call_setitem(self, x): -# x[[False, False]] = self.value +class TestSetValueItemBool2(TestSetValueApi): + def _call_setitem(self, x): + x[[False, False]] = self.value -# def _call_setitem_static_api(self, x): -# x = paddle.static.setitem(x, [False, False], self.value) -# return x + def _call_setitem_static_api(self, x): + x = paddle.static.setitem(x, [False, False], self.value) + return x -# def _get_answer(self): -# self.data[[False, False]] = self.value + def _get_answer(self): + self.data[[False, False]] = self.value class TestSetValueItemBool3(TestSetValueApi): From b400aa466986d6f7c5fc40afa78612d41008b30e Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Fri, 28 Jul 2023 03:38:12 +0000 Subject: [PATCH 25/27] remove useless doc and _getitem_impl_ --- python/paddle/fluid/variable_index.py | 252 -------------------------- 1 file changed, 252 deletions(-) diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 0c13a5c940afd..785b367ec2e38 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -362,255 +362,6 @@ def idx_not_empty(var, item): ) -def _getitem_impl_(var, item): - """ - Slice the variable. - - Args: - item(int/slice/tuple) : the index. - - Returns: - Sliced variable - """ - from .framework import default_main_program, Variable - - if isinstance(item, list): - if not is_one_dim_list(item, int): - item = tuple(item) - - if not isinstance(item, tuple): - item = (item,) - - decrease_axes = [] - axes = [] - starts = [] - ends = [] - steps = [] - reverse_axes = [] - - use_strided_slice = False - item = replace_ndarray(item) - item = replace_ellipsis(var, item) - item, none_axes = replace_none(item) - slice_info = SliceInfo() - is_tensor_array = ( - hasattr(var, "desc") - and var.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY - ) - - for dim, slice_item in enumerate(item): - if is_integer_or_scalar_tensor(slice_item) and not is_bool_tensor( - slice_item - ): - if ( - not is_tensor_array - and isinstance(slice_item, int) - and var.shape[dim] is not None - and var.shape[dim] >= 0 - and slice_item >= var.shape[dim] - ): - # For python, if users write a, b = var, the __getitem__ - # method will iterate through 0, 1, 2 ... until __getitem__ - # throws an IndexError, then stop. The var[0], var[1] will - # be given to a, b respectively. If more values are given, - # the unpack size would cause error. - # We raises IndexError here to support grammar like `a, b = var` - raise IndexError( - "slice_item %d at dim %d should be >= 0 and < var.shape[%d]: %d" - % (slice_item, dim, dim, var.shape[dim]) - ) - decrease_axes.append(dim) - start = slice_item - step = 1 - end = slice_item + 1 if slice_item != -1 else MAX_INTEGER - - elif isinstance(slice_item, slice): - start = slice_item.start - end = slice_item.stop - step = slice_item.step - - if start is None and end is None and step is None: - continue - - step = 1 if step is None else step - - if start is None: - start = 0 if step > 0 else MAX_INTEGER - if end is None: - if ( - paddle.in_dynamic_mode() or not is_tensor_array - ) and var.shape[dim] != -1: - end = var.shape[dim] if step > 0 else -1 - else: - end = MAX_INTEGER if step > 0 else -1 - - elif isinstance(slice_item, list): - all_bool = True - - if is_list_tuple(slice_item, int): - slice_info.update(slice_item) - continue - - for i in slice_item: - if type(i) is int: - all_bool = False - elif not isinstance(i, bool): - raise TypeError("Only support int or bool in index list.") - - if len(item) != 1: - raise IndexError( - "When index contains a list, its length must be 1, but received {}.".format( - len(item) - ) - ) - new_slice_item = [] - if all_bool: - if len(slice_item) != var.shape[0]: - raise IndexError( - "The dimension of bool index doesn't match indexed array along " - "dimension 0, the target dimension is {}, but received {}.".format( - var.shape[0], len(slice_item) - ) - ) - for idx, ele in enumerate(slice_item): - if ele is True: - new_slice_item.append(idx) - slice_item = new_slice_item - else: - for idx, ele in enumerate(slice_item): - if type(ele) is int: - new_slice_item.append(ele) - elif ele is True: - new_slice_item.append(1) - else: - new_slice_item.append(0) - slice_item = new_slice_item - - from ..tensor import index_select - - idx = paddle.assign(np.array(slice_item).astype("int32")) - return index_select(var, index=idx, axis=0) - - elif isinstance(slice_item, (Variable, core.eager.Tensor)): - if len(item) == 1: - from ..tensor import index_select - - if slice_item.dtype == paddle.bool: - return get_value_for_bool_tensor(var, slice_item) - else: - if len(slice_item.shape) == 1: - return index_select(var, index=slice_item, axis=0) - else: - slice_info.update(slice_item) - continue - else: - slice_info.update(slice_item) - continue - - else: - raise IndexError( - "Valid index accept int or slice or ellipsis or list, but received {}.".format( - slice_item - ) - ) - - axes.append(dim) - starts.append(start) - ends.append(end) - steps.append(step) - use_strided_slice = True if step != 1 else use_strided_slice - - if slice_info.indexes: - if len(slice_info.indexes) != len(item): - raise IndexError( - "Valid index accept int or slice or ellipsis or list, but received {}.".format( - item - ) - ) - return slice_info.get_item(var) - - inputs = {'Input': [var]} - attrs = { - 'axes': axes, - 'starts': [], - 'ends': [], - 'decrease_axis': decrease_axes, - } - if use_strided_slice: - attrs['strides'] = [] - - infer_flags = [1] * len(axes) - deal_attrs(attrs, starts, "starts", "StartsTensorList", inputs, infer_flags) - deal_attrs(attrs, ends, "ends", "EndsTensorList", inputs, infer_flags) - deal_attrs( - attrs, steps, "strides", "StridesTensorList", inputs, infer_flags - ) - attrs['infer_flags'] = infer_flags - - out = var - if len(axes) > 0: - op_type = "strided_slice" if use_strided_slice else "slice" - if paddle.in_dynamic_mode() and op_type == "slice": - if "StartsTensorList" in inputs.keys(): - st = inputs['StartsTensorList'] - else: - st = attrs['starts'] - if "EndsTensorList" in inputs.keys(): - end = inputs['EndsTensorList'] - else: - end = attrs['ends'] - out = paddle._C_ops.slice( - var, axes, st, end, attrs['infer_flags'], attrs['decrease_axis'] - ) - else: - target_block = default_main_program().current_block() - - slice_out_var = target_block.create_var( - name=unique_name.generate_with_ignorable_key( - var.name + "_" + op_type - ), - dtype=var.dtype, - ) - target_block.append_op( - type=op_type, - inputs=inputs, - outputs={'Out': [slice_out_var]}, - attrs=attrs, - ) - out = slice_out_var - - if len(reverse_axes) > 0: - from .layers.tensor import reverse - - out = reverse(out, axis=reverse_axes) - - # NOTE(zoooo0820): When all axes are decreased, the output will be 1-D - # with FLAGS_set_to_1d=True. In this case, one `None` should be pop out, - # otherwise the output shape will be not correct. - set_to_1d = paddle.get_flags('FLAGS_set_to_1d')['FLAGS_set_to_1d'] - if set_to_1d and len(decrease_axes) == len(var.shape): - warnings.warn( - "Warning: In Tensor '__getitem__', if the number of scalar elements in the index is equal to the rank of the Tensor, the output should be 0-D. In order to be consistent with the behavior of previous versions, it will be processed to 1-D. But it is not correct and will be removed in release 2.6. If 1-D is still wanted, please modify the index element from scalar to slice (e.g. 'x[i]' => 'x[i:i+1]')." - ) - none_axes = none_axes[1:] - - if len(none_axes) > 0: - # Deal with cases that decrease_axes is not empty - # For example: - # # x.shape: (2,3,4) - # out = x[0, 0:2, None] # out.shape : (2, 1, 4) - for idx, axis in enumerate(none_axes): - l = len([i for i in decrease_axes if i < axis]) - new_axis = axis - l - none_axes[idx] = new_axis - - from ..tensor import unsqueeze - - out = unsqueeze(out, axis=none_axes) - - return out - - def _setitem_for_tensor_array(var, item, value): """branches for tensor array setitem operation. A item can be a: @@ -1075,7 +826,6 @@ def parse_index(x, indices): def _setitem_static(x, indices, values): """ - [WIP]: support __setitem__ by iteration strategy. combined indexing will be support by this. In dynamic mode, this function will modify the value at input tensor, returning same Tensor as input. But it will return a new Tensor with assigned value in static mode. @@ -1337,8 +1087,6 @@ def get_tensor_with_basic_indexing( def _getitem_static(x, indices): """ - [WIP]: support __getitem__ by iteration strategy. combined indexing will be support by this. - Args: x(Tensor): Tensor to be indexing. indices(int|slice|None|Tensor|List|Tuple...): Indices, used to indicate the position of the element to be fetched. From beaf4402fd75b1fd4286ea26865c0698e688eb43 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Fri, 28 Jul 2023 11:46:43 +0000 Subject: [PATCH 26/27] change static res --- python/paddle/static/input.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/static/input.py b/python/paddle/static/input.py index ab8f80c8879aa..32b37e7c9fa4e 100644 --- a/python/paddle/static/input.py +++ b/python/paddle/static/input.py @@ -18,7 +18,7 @@ from paddle.fluid.framework import convert_np_dtype_to_dtype_, static_only from paddle.fluid.layer_helper import LayerHelper -from ..fluid.variable_index import _setitem_impl_ +from ..fluid.variable_index import _setitem_static __all__ = [] @@ -368,4 +368,4 @@ def setitem(x, index, value): (2) a[1] = v -> setitem(a, (1,), v) """ - return _setitem_impl_(x, index, value) + return _setitem_static(x, index, value) From 1a4ae458f7ad922cf5dfb258bde314c801ef1aff Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Mon, 31 Jul 2023 02:22:30 +0000 Subject: [PATCH 27/27] fix static xpu --- python/paddle/static/input.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/paddle/static/input.py b/python/paddle/static/input.py index 32b37e7c9fa4e..4f856227bda71 100644 --- a/python/paddle/static/input.py +++ b/python/paddle/static/input.py @@ -18,7 +18,7 @@ from paddle.fluid.framework import convert_np_dtype_to_dtype_, static_only from paddle.fluid.layer_helper import LayerHelper -from ..fluid.variable_index import _setitem_static +from ..fluid.variable_index import _setitem_impl_, _setitem_static __all__ = [] @@ -367,5 +367,8 @@ def setitem(x, index, value): (1) a[Tensor([10,10])]=v -> setitem(a, (Tensor([10,10]),), v) (2) a[1] = v -> setitem(a, (1,), v) """ - - return _setitem_static(x, index, value) + if core.is_compiled_with_xpu(): + # (NOTE): Currently, there is no index_put_xpu kernel. + return _setitem_impl_(x, index, value) + else: + return _setitem_static(x, index, value)