added argsort function and tests #1224

Open · wants to merge 8 commits into base: main
41 changes: 41 additions & 0 deletions heat/core/manipulations.py
@@ -23,6 +23,7 @@
from . import _operations

__all__ = [
"argsort",
"balance",
"broadcast_arrays",
"broadcast_to",
@@ -63,6 +64,46 @@
]


def argsort(a: DNDarray, axis: int = -1, descending: bool = False) -> DNDarray:
"""
Returns the indices that would sort an array. This is the distributed equivalent of `np.argsort`.
The sort is not stable, which means that equal elements may appear in a different order in the
result than in the original array.
Sorting along the split axis (`axis == a.split`) requires substantial communication between the MPI processes.

Parameters
----------
a : DNDarray
Input array to be sorted.
axis : int, optional
The dimension to sort along.
Default is the last axis.
descending : bool, optional
If set to `True`, the returned indices sort the array in descending order.

Raises
------
ValueError
If `axis` is not consistent with the available dimensions.

Examples
--------
>>> x = ht.array([[4, 1], [2, 3]], split=0)
>>> x.lshape  # local shape, printed once per process on two processes
(1, 2)
(1, 2)
>>> y = ht.argsort(x, axis=0)
>>> y
(array([[1, 0],
[0, 1]]))
>>> ht.argsort(x, descending=True)
(array([[0, 1],
[1, 0]]))
"""
_, indices = sort(a=a, axis=axis, descending=descending, out=None)
return indices


def balance(array: DNDarray, copy=False) -> DNDarray:
"""
Out-of-place balance function. More information on the meaning of balance can be found in
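
A minimal usage sketch for reviewers (an illustration, not part of the diff; the two-process layout described in the comments is an assumption):

import heat as ht

# assumes an MPI launch with two processes, e.g.: mpirun -n 2 python demo.py
x = ht.array([[4, 1], [2, 3]], split=0)  # rows are distributed across processes

# axis == x.split: the sort crosses process boundaries and triggers MPI communication
idx_split = ht.argsort(x, axis=0)

# axis != x.split: each process sorts its local rows independently, no communication
idx_local = ht.argsort(x, axis=1, descending=True)

print(idx_split.larray)  # each process prints its local chunk of the index array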
112 changes: 112 additions & 0 deletions heat/core/tests/test_manipulations.py
@@ -6,6 +6,118 @@


class TestManipulations(TestCase):
def test_argsort(self):
size = ht.MPI_WORLD.size
rank = ht.MPI_WORLD.rank
tensor = (
torch.arange(size, device=self.device.torch_device).repeat(size).reshape(size, size)
)

data = ht.array(tensor, split=None)
result_indices = ht.argsort(data, axis=0, descending=True)
exp_indices = torch.argsort(tensor, dim=0, descending=True)
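# torch.argsort returns int64 indices; the .int() cast matches heat's int32 index dtype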
self.assertTrue(torch.equal(result_indices.larray, exp_indices.int()))

result_indices = ht.argsort(data, axis=1, descending=True)
exp_indices = torch.argsort(tensor, dim=1, descending=True)
self.assertTrue(torch.equal(result_indices.larray, exp_indices.int()))

data = ht.array(tensor, split=0)

exp_indices = torch.tensor([[rank] * size], device=self.device.torch_device)
result_indices = ht.argsort(data, descending=True, axis=0)
self.assertTrue(torch.equal(result_indices.larray, exp_indices.int()))

exp_indices = (
torch.arange(size, device=self.device.torch_device)
.reshape(1, size)
.argsort(dim=1, descending=True)
)
result_indices = ht.argsort(data, descending=True, axis=1)
self.assertTrue(torch.equal(result_indices.larray, exp_indices.int()))

indices1 = ht.argsort(data, axis=1, descending=True)
indices2 = ht.argsort(data, descending=True)
self.assertTrue(ht.equal(indices1, indices2))
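
# Hedged sanity check (added for illustration, not part of the original tests):
# along a non-split axis the indices are purely local, so applying them with
# torch.take_along_dim must reproduce the values returned by ht.sort.
sorted_values, _ = ht.sort(data, axis=1, descending=True)
gathered = torch.take_along_dim(data.larray, indices1.larray.long(), dim=1)
self.assertTrue(torch.equal(gathered, sorted_values.larray))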

data = ht.array(tensor, split=1)

indices_axis_zero = torch.arange(
size, dtype=torch.int64, device=self.device.torch_device
).reshape(size, 1)
result_indices = ht.argsort(data, axis=0, descending=True)
# the expected indices are only guaranteed on CPU (ties may be broken differently on GPU)
if not result_indices.larray.is_cuda:
self.assertTrue(torch.equal(result_indices.larray, indices_axis_zero.int()))

exp_axis_one = (
torch.tensor(size - rank - 1, device=self.device.torch_device)
.repeat(size)
.reshape(size, 1)
)
result_indices = ht.argsort(data, descending=True, axis=1)
self.assertTrue(torch.equal(result_indices.larray, exp_axis_one.int()))

tensor = torch.tensor(
[
[[2, 8, 5], [7, 2, 3]],
[[6, 5, 2], [1, 8, 7]],
[[9, 3, 0], [1, 2, 4]],
[[8, 4, 7], [0, 8, 9]],
],
dtype=torch.int32,
device=self.device.torch_device,
)

data = ht.array(tensor, split=0)
# GPU sorts may break ties differently from CPU, so the expected indices depend on the device
if torch.cuda.is_available() and data.device == ht.gpu and size < 4:
indices_axis_zero = torch.tensor(
[[0, 2, 2], [3, 2, 0]], dtype=torch.int32, device=self.device.torch_device
)
else:
indices_axis_zero = torch.tensor(
[[0, 2, 2], [3, 0, 0]], dtype=torch.int32, device=self.device.torch_device
)
result_indices = ht.argsort(data, axis=0)
first_indices = result_indices[0].larray
if rank == 0:
self.assertTrue(torch.equal(first_indices, indices_axis_zero))

data = ht.array(tensor, split=1)
indices_axis_one = torch.tensor(
[[0, 1, 1]], dtype=torch.int32, device=self.device.torch_device
)
result_indices = ht.argsort(data, axis=1)
first_indices = result_indices[0].larray[:1]
if rank == 0:
self.assertTrue(torch.equal(first_indices, indices_axis_one))

data = ht.array(tensor, split=2)
indices_axis_two = torch.tensor(
[[0], [1]], dtype=torch.int32, device=self.device.torch_device
)
result_indices = ht.argsort(data, axis=2)
first_indices = result_indices[0].larray[:, :1]
if rank == 0:
self.assertTrue(torch.equal(first_indices, indices_axis_two))

# test exceptions
with self.assertRaises(ValueError):
ht.argsort(data, axis=3)
with self.assertRaises(TypeError):
ht.argsort(data, axis="1")

rank = ht.MPI_WORLD.rank
ht.random.seed(1)
data = ht.random.randn(100, 1, split=0)
indices = ht.argsort(data, axis=0)
result = data[indices.larray.tolist()]
counts, _, _ = ht.get_comm().counts_displs_shape(data.gshape, axis=0)
for i, c in enumerate(counts):
for idx in range(c - 1):
if rank == i:
self.assertTrue(torch.lt(result.larray[idx], result.larray[idx + 1]).all())

def test_broadcast_arrays(self):
a = ht.array([[1], [2]])
b = ht.array([[0, 1]])