diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py
index 7ae8d9db86..c05b4ddf0c 100644
--- a/heat/core/manipulations.py
+++ b/heat/core/manipulations.py
@@ -23,6 +23,7 @@ from . import _operations
 
 
 __all__ = [
+    "argsort",
     "balance",
     "broadcast_arrays",
     "broadcast_to",
@@ -63,6 +64,52 @@
 ]
 
 
+def argsort(a: DNDarray, axis: int = -1, descending: bool = False) -> DNDarray:
+    """
+    Returns the indices that would sort an array. This is the distributed equivalent of `np.argsort`.
+    The sort is not stable, i.e. equal elements may appear in a different order in the result than
+    in the original array.
+    Sorting along the split axis (`axis == a.split`) requires substantial communication between the MPI processes.
+
+    Parameters
+    ----------
+    a : DNDarray
+        Input array to be sorted.
+    axis : int, optional
+        The dimension to sort along.
+        Default is the last axis.
+    descending : bool, optional
+        If set to `True`, indices are sorted in descending order.
+
+    Returns
+    -------
+    indices : DNDarray
+        The indices that sort `a` along the given axis.
+
+    Raises
+    ------
+    TypeError
+        If `axis` is not an integer.
+    ValueError
+        If `axis` is not consistent with the available dimensions.
+
+    Examples
+    --------
+    >>> x = ht.array([[4, 1], [2, 3]], split=0)
+    >>> x.shape
+    (2, 2)
+    >>> y = ht.argsort(x, axis=0)
+    >>> y
+    array([[1, 0],
+           [0, 1]])
+    >>> ht.argsort(x, descending=True)
+    array([[0, 1],
+           [1, 0]])
+    """
+    _, indices = sort(a=a, axis=axis, descending=descending, out=None)
+    return indices
+
+
 def balance(array: DNDarray, copy=False) -> DNDarray:
     """
     Out of place balance function. More information on the meaning of balance can be found in
diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py
index 9825d333e9..b1f9470c5e 100644
--- a/heat/core/tests/test_manipulations.py
+++ b/heat/core/tests/test_manipulations.py
@@ -6,6 +6,124 @@
 
 
 class TestManipulations(TestCase):
+    def test_argsort(self):
+        size = ht.MPI_WORLD.size
+        rank = ht.MPI_WORLD.rank
+        tensor = (
+            torch.arange(size, device=self.device.torch_device).repeat(size).reshape(size, size)
+        )
+
+        # unsplit data: results must match torch.argsort
+        data = ht.array(tensor, split=None)
+        result_indices = ht.argsort(data, axis=0, descending=True)
+        exp_indices = torch.argsort(tensor, dim=0, descending=True)
+        self.assertTrue(torch.equal(result_indices.larray, exp_indices.int()))
+
+        result_indices = ht.argsort(data, axis=1, descending=True)
+        exp_indices = torch.argsort(tensor, dim=1, descending=True)
+        self.assertTrue(torch.equal(result_indices.larray, exp_indices.int()))
+
+        # data split along axis 0
+        data = ht.array(tensor, split=0)
+
+        exp_indices = torch.tensor([[rank] * size], device=self.device.torch_device)
+        result_indices = ht.argsort(data, descending=True, axis=0)
+        self.assertTrue(torch.equal(result_indices.larray, exp_indices.int()))
+
+        exp_indices = (
+            torch.arange(size, device=self.device.torch_device)
+            .reshape(1, size)
+            .argsort(dim=1, descending=True)
+        )
+        result_indices = ht.argsort(data, descending=True, axis=1)
+        self.assertTrue(torch.equal(result_indices.larray, exp_indices.int()))
+
+        # omitting the axis defaults to the last one
+        indices1 = ht.argsort(data, axis=1, descending=True)
+        indices2 = ht.argsort(data, descending=True)
+        self.assertTrue(ht.equal(indices1, indices2))
+
+        # data split along axis 1
+        data = ht.array(tensor, split=1)
+
+        indices_axis_zero = torch.arange(
+            size, dtype=torch.int64, device=self.device.torch_device
+        ).reshape(size, 1)
+        result_indices = ht.argsort(data, axis=0, descending=True)
+        # the expected values only hold on CPU
+        if not result_indices.larray.is_cuda:
+            self.assertTrue(torch.equal(result_indices.larray, indices_axis_zero.int()))
+
+        exp_axis_one = (
+            torch.tensor(size - rank - 1, device=self.device.torch_device)
+            .repeat(size)
+            .reshape(size, 1)
+        )
+        result_indices = ht.argsort(data, descending=True, axis=1)
+        self.assertTrue(torch.equal(result_indices.larray, exp_axis_one.int()))
+
+        # 3-dimensional data, tested with every possible split axis
+        tensor = torch.tensor(
+            [
+                [[2, 8, 5], [7, 2, 3]],
+                [[6, 5, 2], [1, 8, 7]],
+                [[9, 3, 0], [1, 2, 4]],
+                [[8, 4, 7], [0, 8, 9]],
+            ],
+            dtype=torch.int32,
+            device=self.device.torch_device,
+        )
+
+        data = ht.array(tensor, split=0)
+        if torch.cuda.is_available() and data.device == ht.gpu and size < 4:
+            indices_axis_zero = torch.tensor(
+                [[0, 2, 2], [3, 2, 0]], dtype=torch.int32, device=self.device.torch_device
+            )
+        else:
+            indices_axis_zero = torch.tensor(
+                [[0, 2, 2], [3, 0, 0]], dtype=torch.int32, device=self.device.torch_device
+            )
+        result_indices = ht.argsort(data, axis=0)
+        first_indices = result_indices[0].larray
+        if rank == 0:
+            self.assertTrue(torch.equal(first_indices, indices_axis_zero))
+
+        data = ht.array(tensor, split=1)
+        indices_axis_one = torch.tensor(
+            [[0, 1, 1]], dtype=torch.int32, device=self.device.torch_device
+        )
+        result_indices = ht.argsort(data, axis=1)
+        first_indices = result_indices[0].larray[:1]
+        if rank == 0:
+            self.assertTrue(torch.equal(first_indices, indices_axis_one))
+
+        data = ht.array(tensor, split=2)
+        indices_axis_two = torch.tensor(
+            [[0], [1]], dtype=torch.int32, device=self.device.torch_device
+        )
+        result_indices = ht.argsort(data, axis=2)
+        first_indices = result_indices[0].larray[:, :1]
+        if rank == 0:
+            self.assertTrue(torch.equal(first_indices, indices_axis_two))
+
+        # test exceptions
+        with self.assertRaises(ValueError):
+            ht.argsort(data, axis=3)
+        with self.assertRaises(TypeError):
+            ht.argsort(data, axis="1")
+
+        # applying the returned indices must yield locally sorted chunks
+        rank = ht.MPI_WORLD.rank
+        ht.random.seed(1)
+        data = ht.random.randn(100, 1, split=0)
+        indices = ht.argsort(data, axis=0)
+        result = data[indices.larray.tolist()]
+        counts, _, _ = ht.get_comm().counts_displs_shape(data.gshape, axis=0)
+        for i, c in enumerate(counts):
+            for idx in range(c - 1):
+                if rank == i:
+                    self.assertTrue(torch.lt(result.larray[idx], result.larray[idx + 1]).all())
+
     def test_broadcast_arrays(self):
         a = ht.array([[1], [2]])
         b = ht.array([[0, 1]])
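Taken together, the two hunks expose `ht.argsort` as a thin wrapper that discards the values returned by the existing distributed `sort` and keeps only the indices, so it inherits `sort`'s communication pattern and non-stable ordering. Below is a minimal usage sketch mirroring the docstring example; the script name and the process count are illustrative only and not part of the patch.

```python
# demo_argsort.py -- illustrative usage of the ht.argsort added above;
# run with e.g.: mpirun -np 2 python demo_argsort.py
import heat as ht

# 2 x 2 array, distributed row-wise (split=0) across the MPI processes
x = ht.array([[4, 1], [2, 3]], split=0)

# indices that sort each column in ascending order -> [[1, 0], [0, 1]]
print(ht.argsort(x, axis=0))

# indices of a descending sort along the default (last) axis -> [[0, 1], [1, 0]]
print(ht.argsort(x, descending=True))
```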