Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Combined indexing for __getitem__ and __setitem__ #55211

Merged
merged 35 commits into from
Aug 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
b663071
WIP: start writing combined indexing get
zoooo0820 May 25, 2023
1bcf667
list/tuple/Variable
zoooo0820 Jun 5, 2023
bee90bd
Merge branch 'develop' into combined_indexing
zoooo0820 Jun 6, 2023
2d00290
getitem 80%
zoooo0820 Jun 14, 2023
6dfd5be
add setitem
zoooo0820 Jun 15, 2023
1aeb04e
add some unittest for setitem
zoooo0820 Jun 15, 2023
7e03261
lazy import
zoooo0820 Jun 16, 2023
7ddd876
Merge branch 'develop' into combined_indexing
zoooo0820 Jun 16, 2023
e47ed10
fix some setitem error
zoooo0820 Jun 16, 2023
d1e3fc2
fix advance indexing with decreasing axes; fix strided_slice input name
zoooo0820 Jun 25, 2023
bb022b2
combine int-tensor getitem is ok (without boolean support & broadcast…
zoooo0820 Jun 26, 2023
e25a4bc
add broadcast & parse bool tensor for __getitem
zoooo0820 Jun 28, 2023
edf3b0c
[change getitem] _getitem_impl_ to _getitem_static, not deleting the …
zoooo0820 Jul 3, 2023
5b9e48a
refine new getitem; fix ut in variable/var_base
zoooo0820 Jul 4, 2023
87cf34f
add __getitem__ ut in dygraph
zoooo0820 Jul 4, 2023
a206dbf
re-dispatch getitem for Py/CPP; fix strided_slice decrease axes error…
zoooo0820 Jul 5, 2023
ed8f20c
fix ut; support tensor in slice
zoooo0820 Jul 5, 2023
2136c67
[change setitem] _setitem_impl_ to _setitem_static, not deleting the …
zoooo0820 Jul 5, 2023
8f29880
remove some UT (for some, temporarily)
zoooo0820 Jul 6, 2023
f1aba67
Merge branch 'develop' into combined_indexing
zoooo0820 Jul 6, 2023
d6f9a2c
add IndexError to solve timeout problem in static-mode
zoooo0820 Jul 11, 2023
c03bd95
merge dev & solve conflict
zoooo0820 Jul 19, 2023
aec3380
1.temply forbideen all-False bool-indexput; 2.setitem_static will ret…
zoooo0820 Jul 19, 2023
3b62bc7
xpu uses old stratege
zoooo0820 Jul 19, 2023
7c65649
rename dy2st setitem ut to avoid same-name problem
zoooo0820 Jul 19, 2023
f41f4e8
dy2st for new combined index
zoooo0820 Jul 20, 2023
ff3bd14
Merge branch 'develop' into combined_indexing
zoooo0820 Jul 25, 2023
ee4855e
ut case for combine-index with dy2st
zoooo0820 Jul 25, 2023
5e7f629
Merge branch 'develop' into combined_indexing
zoooo0820 Jul 26, 2023
f7c6096
open ut with all-false-bool setitem
zoooo0820 Jul 27, 2023
afa26c5
Merge branch 'develop' into combined_indexing
zoooo0820 Jul 27, 2023
b400aa4
remove useless doc and _getitem_impl_
zoooo0820 Jul 28, 2023
beaf440
change static res
zoooo0820 Jul 28, 2023
1a4ae45
fix static xpu
zoooo0820 Jul 31, 2023
84386c4
merge develop with stride
zoooo0820 Jul 31, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions paddle/fluid/pybind/eager_method.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,9 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
eager_gil_scoped_release guard;
out = strided_slice_ad_func(
self->tensor, slice_axes, slice_starts, slice_ends, slice_strides);
if (!decrease_axis_tmp.empty()) {
out = squeeze_ad_func(out, decrease_axis_tmp);
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Slice is only support slice and strided_slice, but we got %s which "
Expand Down
59 changes: 25 additions & 34 deletions python/paddle/fluid/dygraph/tensor_patch_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from ..framework import (
Variable,
Parameter,
_getitem_impl_,
_getitem_static,
_setitem_static,
_setitem_impl_,
EagerParamBase,
in_dygraph_mode,
Expand Down Expand Up @@ -726,47 +727,34 @@ def contain_tensor(item):
return True
return False

def __getitem__(self, item):
def is_list_tuple(index, contain_type):
def _is_list_tuple(item):
if isinstance(item, (tuple, list)):
for s in item:
if not _is_list_tuple(s):
return False
else:
if type(item) != contain_type:
return False
def contain_tensor_or_list(item):
if not isinstance(item, tuple):
item = (item,)

for slice_item in item:
if isinstance(slice_item, (list, np.ndarray, Variable)):
return True
elif isinstance(slice_item, slice):
if (
isinstance(slice_item.start, Variable)
or isinstance(slice_item.stop, Variable)
or isinstance(slice_item.step, Variable)
):
return True

if not isinstance(index, (tuple, list)):
return False
for s in index:
if not _is_list_tuple(s):
return False
return True
return False

if contain_tensor(item) or is_list_tuple(item, int):
def __getitem__(self, item):
if contain_tensor_or_list(item):
# 1. Call _getitem_impl_ when item contains tensor.
# Why not call a c++ function ? Because item can't be parsed when it contains tensor.
return _getitem_impl_(self, item)
return _getitem_static(self, item)

else:
# 2. Call c++ func getitem_index_not_tensor to speedup.
return self._getitem_index_not_tensor(item)

def __setitem__(self, item, value):
def contain_tensor_or_list(item):
if not isinstance(item, tuple):
item = [item]

for slice_item in item:
if isinstance(slice_item, list):
return True
elif isinstance(slice_item, Variable):
return True

return False

def is_combine_index(item):
var_type = None
item_type = None
Expand All @@ -788,10 +776,13 @@ def is_combine_index(item):

return False

if contain_tensor_or_list(item) and not is_combine_index(item):
if contain_tensor_or_list(item):
if core.is_compiled_with_xpu() and not is_combine_index(item):
# (NOTE): Currently, there is no index_put_xpu kernel.
return _setitem_impl_(self, item, value)
# To reuse code with static graph,
# Call _setitem_impl_ when item contains tensor or list.
return _setitem_impl_(self, item, value)
# Call _setitem_static when item contains tensor or list.
return _setitem_static(self, item, value)

else:
return self.__setitem_eager_tensor__(item, value)
Expand Down
9 changes: 6 additions & 3 deletions python/paddle/fluid/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import paddle.version as fluid_version
import warnings
import functools
from .variable_index import _getitem_impl_, _setitem_impl_
from .variable_index import _getitem_static, _setitem_static, _setitem_impl_
import threading

__all__ = [
Expand Down Expand Up @@ -2293,13 +2293,16 @@ def _sliceAndConcatVar(self, item, axis):
raise IndexError("Valid index accept int or slice or tuple")

def __getitem__(self, item):
return _getitem_impl_(self, item)
return _getitem_static(self, item)

def __setitem__(self, item, value):
from .dygraph.base import in_declarative_mode

if in_declarative_mode():
return _setitem_impl_(self, item, value)
if is_compiled_with_xpu():
# (NOTE): Currently, there is no index_put_xpu kernel.
return _setitem_impl_(self, item, value)
return _setitem_static(self, item, value)
else:
raise RuntimeError(
"In static mode, the __setitem__ (looks like: x[indices] = values) should not be used. Please use x = paddle.static.setitem(x, indices, values)"
Expand Down
Loading