From c2b942971e311d6bf589a3f5672ba89cefcf678b Mon Sep 17 00:00:00 2001 From: Mark Harfouche Date: Sun, 5 May 2024 14:05:08 -0400 Subject: [PATCH 01/17] Fix syntax error in test related to cupy (#9000) I suspect the CIs don't have cupy which meant that this line didn't get hit. Recreation: ``` mamba create --name xr_py10 python=3.10 --channel conda-forge --override-channels mamba activate xr_py10 pip install -e . -vv pip install pytest mamba install cupy ``` ``` pytest xarray/tests/test_array_api.py -x ``` Fails on my machine. Happy to provide more info --- xarray/core/duck_array_ops.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index d95dfa566cc..23be37618b0 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -233,9 +233,10 @@ def as_shared_dtype(scalars_or_arrays, xp=np): raise ValueError( f"Cannot cast arrays to shared type, found array types {[x.dtype for x in scalars_or_arrays]}" ) - elif array_type_cupy := array_type("cupy") and any( # noqa: F841 - isinstance(x, array_type_cupy) for x in scalars_or_arrays # noqa: F821 - ): + + # Avoid calling array_type("cupy") repeatidely in the any check + array_type_cupy = array_type("cupy") + if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays): import cupy as cp arrays = [asarray(x, xp=cp) for x in scalars_or_arrays] From 8d728bf253f63c557ab8993d47c1cd0c3113277d Mon Sep 17 00:00:00 2001 From: ignamv Date: Sun, 5 May 2024 20:06:57 +0200 Subject: [PATCH 02/17] Add argument check_dims to assert_allclose to allow transposed inputs (#5733) (#8991) * Add argument check_dims to assert_allclose to allow transposed inputs * Update whats-new.rst * Add `check_dims` argument to assert_equal and assert_identical + tests * Assert that dimensions match before transposing or comparing values * Add docstring for check_dims to assert_equal and assert_identical * Update doc/whats-new.rst Co-authored-by: Tom Nicholas * Undo fat finger Co-authored-by: Tom Nicholas * Add attribution to whats-new.rst * Replace check_dims with bool argument check_dim_order, rename align_dims to maybe_transpose_dims * Remove left-over half-made test * Remove check_dim_order argument from assert_identical * assert_allclose/equal: emit full diff if dimensions don't match * Rename check_dim_order test, test Dataset with different dim orders * Update whats-new.rst * Hide maybe_transpose_dims from Pytest traceback Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> * Ignore mypy error due to missing functools.partial.__name__ --------- Co-authored-by: Tom Nicholas Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> --- doc/whats-new.rst | 2 ++ xarray/testing/assertions.py | 29 +++++++++++++++++++++++++---- xarray/tests/test_assertions.py | 19 +++++++++++++++++++ 3 files changed, 46 insertions(+), 4 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 0f79b648187..9a4601a776f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -29,6 +29,8 @@ New Features for example, will retain the object. However, one cannot do operations that are not possible on the `ExtensionArray` then, such as broadcasting. By `Ilan Gold `_. +- :py:func:`testing.assert_allclose`/:py:func:`testing.assert_equal` now accept a new argument `check_dims="transpose"`, controlling whether a transposed array is considered equal. (:issue:`5733`, :pull:`8991`) + By `Ignacio Martinez Vazquez `_. 
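A note on the cupy fix in PATCH 01 above: ``:=`` binds more loosely than ``and``, so the removed ``elif`` was parsed as ``array_type_cupy := (array_type("cupy") and any(...))``, meaning ``array_type_cupy`` was looked up inside the ``any(...)`` before the assignment had taken effect. When cupy is importable the left operand is truthy, the ``any(...)`` runs, and a ``NameError`` is raised, which is why only machines with cupy installed ever hit it (the CI images apparently lack cupy). A minimal sketch of the pitfall, using hypothetical stand-in names rather than xarray's helpers:

```python
def get_types():
    # stand-in for array_type("cupy"): returns a tuple usable with isinstance()
    return (int,)


values = [1, 2.5, 3]

# Buggy form: parsed as `array_types := (get_types() and any(...))`, so the name
# is still undefined while the `any(...)` part is being evaluated.
try:
    if array_types := get_types() and any(isinstance(x, array_types) for x in values):
        pass
except NameError as err:
    print(err)  # name 'array_types' is not defined

# Fixed form, mirroring the patch: assign first, then test.
array_types = get_types()
if any(isinstance(x, array_types) for x in values):
    print("found a matching value")
```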
- Added the option to avoid automatically creating 1D pandas indexes in :py:meth:`Dataset.expand_dims()`, by passing the new kwarg `create_index=False`. (:pull:`8960`) By `Tom Nicholas `_. diff --git a/xarray/testing/assertions.py b/xarray/testing/assertions.py index 018874c169e..69885868f83 100644 --- a/xarray/testing/assertions.py +++ b/xarray/testing/assertions.py @@ -95,6 +95,18 @@ def assert_isomorphic(a: DataTree, b: DataTree, from_root: bool = False): raise TypeError(f"{type(a)} not of type DataTree") +def maybe_transpose_dims(a, b, check_dim_order: bool): + """Helper for assert_equal/allclose/identical""" + __tracebackhide__ = True + if not isinstance(a, (Variable, DataArray, Dataset)): + return b + if not check_dim_order and set(a.dims) == set(b.dims): + # Ensure transpose won't fail if a dimension is missing + # If this is the case, the difference will be caught by the caller + return b.transpose(*a.dims) + return b + + @overload def assert_equal(a, b): ... @@ -104,7 +116,7 @@ def assert_equal(a: DataTree, b: DataTree, from_root: bool = True): ... @ensure_warnings -def assert_equal(a, b, from_root=True): +def assert_equal(a, b, from_root=True, check_dim_order: bool = True): """Like :py:func:`numpy.testing.assert_array_equal`, but for xarray objects. @@ -127,6 +139,8 @@ def assert_equal(a, b, from_root=True): Only used when comparing DataTree objects. Indicates whether or not to first traverse to the root of the trees before checking for isomorphism. If a & b have no parents then this has no effect. + check_dim_order : bool, optional, default is True + Whether dimensions must be in the same order. See Also -------- @@ -137,6 +151,7 @@ def assert_equal(a, b, from_root=True): assert ( type(a) == type(b) or isinstance(a, Coordinates) and isinstance(b, Coordinates) ) + b = maybe_transpose_dims(a, b, check_dim_order) if isinstance(a, (Variable, DataArray)): assert a.equals(b), formatting.diff_array_repr(a, b, "equals") elif isinstance(a, Dataset): @@ -182,6 +197,8 @@ def assert_identical(a, b, from_root=True): Only used when comparing DataTree objects. Indicates whether or not to first traverse to the root of the trees before checking for isomorphism. If a & b have no parents then this has no effect. + check_dim_order : bool, optional, default is True + Whether dimensions must be in the same order. See Also -------- @@ -213,7 +230,9 @@ def assert_identical(a, b, from_root=True): @ensure_warnings -def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True): +def assert_allclose( + a, b, rtol=1e-05, atol=1e-08, decode_bytes=True, check_dim_order: bool = True +): """Like :py:func:`numpy.testing.assert_allclose`, but for xarray objects. Raises an AssertionError if two objects are not equal up to desired @@ -233,6 +252,8 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True): Whether byte dtypes should be decoded to strings as UTF-8 or not. This is useful for testing serialization methods on Python 3 that return saved strings as bytes. + check_dim_order : bool, optional, default is True + Whether dimensions must be in the same order. 
See Also -------- @@ -240,16 +261,16 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True): """ __tracebackhide__ = True assert type(a) == type(b) + b = maybe_transpose_dims(a, b, check_dim_order) equiv = functools.partial( _data_allclose_or_equiv, rtol=rtol, atol=atol, decode_bytes=decode_bytes ) - equiv.__name__ = "allclose" + equiv.__name__ = "allclose" # type: ignore[attr-defined] def compat_variable(a, b): a = getattr(a, "variable", a) b = getattr(b, "variable", b) - return a.dims == b.dims and (a._data is b._data or equiv(a.data, b.data)) if isinstance(a, Variable): diff --git a/xarray/tests/test_assertions.py b/xarray/tests/test_assertions.py index f7e49a0f3de..aa0ea46f7db 100644 --- a/xarray/tests/test_assertions.py +++ b/xarray/tests/test_assertions.py @@ -57,6 +57,25 @@ def test_allclose_regression() -> None: def test_assert_allclose(obj1, obj2) -> None: with pytest.raises(AssertionError): xr.testing.assert_allclose(obj1, obj2) + with pytest.raises(AssertionError): + xr.testing.assert_allclose(obj1, obj2, check_dim_order=False) + + +@pytest.mark.parametrize("func", ["assert_equal", "assert_allclose"]) +def test_assert_allclose_equal_transpose(func) -> None: + """Transposed DataArray raises assertion unless check_dim_order=False.""" + obj1 = xr.DataArray([[0, 1, 2], [2, 3, 4]], dims=["a", "b"]) + obj2 = xr.DataArray([[0, 2], [1, 3], [2, 4]], dims=["b", "a"]) + with pytest.raises(AssertionError): + getattr(xr.testing, func)(obj1, obj2) + getattr(xr.testing, func)(obj1, obj2, check_dim_order=False) + ds1 = obj1.to_dataset(name="varname") + ds1["var2"] = obj1 + ds2 = obj1.to_dataset(name="varname") + ds2["var2"] = obj1.transpose() + with pytest.raises(AssertionError): + getattr(xr.testing, func)(ds1, ds2) + getattr(xr.testing, func)(ds1, ds2, check_dim_order=False) @pytest.mark.filterwarnings("error") From 50f87269e4c7b51c4afc6f5030760d7f8e584a09 Mon Sep 17 00:00:00 2001 From: Mark Harfouche Date: Sun, 5 May 2024 18:57:36 -0400 Subject: [PATCH 03/17] Simplify fast path (#9001) --- xarray/core/variable.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 2bcee5590f8..f0685882595 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -269,9 +269,8 @@ def as_compatible_data( Finally, wrap it up with an adapter if necessary. 
""" - if fastpath and getattr(data, "ndim", 0) > 0: - # can't use fastpath (yet) for scalars - return cast("T_DuckArray", _maybe_wrap_data(data)) + if fastpath and getattr(data, "ndim", None) is not None: + return cast("T_DuckArray", data) from xarray.core.dataarray import DataArray From faa634579f3240e6fa36694e89cec3d674f784d3 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 6 May 2024 07:44:07 +0200 Subject: [PATCH 04/17] Speed up localize (#8536) Co-authored-by: Mathias Hauser --- xarray/core/missing.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 8aa2ff2f042..45abc70c0d3 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -553,11 +553,11 @@ def _localize(var, indexes_coords): """ indexes = {} for dim, [x, new_x] in indexes_coords.items(): - minval = np.nanmin(new_x.values) - maxval = np.nanmax(new_x.values) + new_x_loaded = new_x.values + minval = np.nanmin(new_x_loaded) + maxval = np.nanmax(new_x_loaded) index = x.to_index() - imin = index.get_indexer([minval], method="nearest").item() - imax = index.get_indexer([maxval], method="nearest").item() + imin, imax = index.get_indexer([minval, maxval], method="nearest") indexes[dim] = slice(max(imin - 2, 0), imax + 2) indexes_coords[dim] = (x[indexes[dim]], new_x) return var.isel(**indexes), indexes_coords From c4031cd67c6e56f398414c27aab306656d8af517 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 6 May 2024 10:02:37 -0700 Subject: [PATCH 05/17] Bump codecov/codecov-action from 4.3.0 to 4.3.1 in the actions group (#9004) Bumps the actions group with 1 update: [codecov/codecov-action](https://github.com/codecov/codecov-action). Updates `codecov/codecov-action` from 4.3.0 to 4.3.1 - [Release notes](https://github.com/codecov/codecov-action/releases) - [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md) - [Commits](https://github.com/codecov/codecov-action/compare/v4.3.0...v4.3.1) --- updated-dependencies: - dependency-name: codecov/codecov-action dependency-type: direct:production update-type: version-update:semver-patch dependency-group: actions ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/ci-additional.yaml | 8 ++++---- .github/workflows/ci.yaml | 2 +- .github/workflows/upstream-dev-ci.yaml | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index 7b248b14006..f904f259c51 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -127,7 +127,7 @@ jobs: python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report xarray/ - name: Upload mypy coverage to Codecov - uses: codecov/codecov-action@v4.3.0 + uses: codecov/codecov-action@v4.3.1 with: file: mypy_report/cobertura.xml flags: mypy @@ -181,7 +181,7 @@ jobs: python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report xarray/ - name: Upload mypy coverage to Codecov - uses: codecov/codecov-action@v4.3.0 + uses: codecov/codecov-action@v4.3.1 with: file: mypy_report/cobertura.xml flags: mypy39 @@ -242,7 +242,7 @@ jobs: python -m pyright xarray/ - name: Upload pyright coverage to Codecov - uses: codecov/codecov-action@v4.3.0 + uses: codecov/codecov-action@v4.3.1 with: file: pyright_report/cobertura.xml flags: pyright @@ -301,7 +301,7 @@ jobs: python -m pyright xarray/ - name: Upload pyright coverage to Codecov - uses: codecov/codecov-action@v4.3.0 + uses: codecov/codecov-action@v4.3.1 with: file: pyright_report/cobertura.xml flags: pyright39 diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index a577312a7cc..349aa626142 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -156,7 +156,7 @@ jobs: path: pytest.xml - name: Upload code coverage to Codecov - uses: codecov/codecov-action@v4.3.0 + uses: codecov/codecov-action@v4.3.1 with: file: ./coverage.xml flags: unittests diff --git a/.github/workflows/upstream-dev-ci.yaml b/.github/workflows/upstream-dev-ci.yaml index 4f3d199dd2d..a216dfd7428 100644 --- a/.github/workflows/upstream-dev-ci.yaml +++ b/.github/workflows/upstream-dev-ci.yaml @@ -143,7 +143,7 @@ jobs: run: | python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report - name: Upload mypy coverage to Codecov - uses: codecov/codecov-action@v4.3.0 + uses: codecov/codecov-action@v4.3.1 with: file: mypy_report/cobertura.xml flags: mypy From e0f2ceede29087854edfa498cfd23d36bcf4178a Mon Sep 17 00:00:00 2001 From: Spencer Clark Date: Mon, 6 May 2024 13:37:37 -0400 Subject: [PATCH 06/17] Port negative frequency fix for `pandas.date_range` to `cftime_range` (#8999) --- doc/whats-new.rst | 6 ++++++ xarray/coding/cftime_offsets.py | 7 ++++++- xarray/tests/__init__.py | 1 + xarray/tests/test_cftime_offsets.py | 17 ++++++++++++++--- 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 9a4601a776f..a846c1b8a01 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -60,6 +60,12 @@ Breaking changes Bug fixes ~~~~~~~~~ +- Following `an upstream bug fix + `_ to + :py:func:`pandas.date_range`, date ranges produced by + :py:func:`xarray.cftime_range` with negative frequencies will now fall fully + within the bounds of the provided start and end dates (:pull:`8999`). By + `Spencer Clark `_. 
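To make the whats-new entry above concrete, a small example of the corrected behaviour (assumes cftime is installed and this fix is applied; only the bounds are asserted since the exact dates depend on the calendar):

```python
import xarray as xr
from cftime import DatetimeNoLeap

# With a negative frequency the range runs backwards from ``start`` towards ``end``.
# The first value is now rolled *back* onto the frequency rather than forward, so
# every generated date stays within the requested bounds.
dates = xr.cftime_range(start="2001", end="2000", freq="-2ME", calendar="noleap")
assert all(DatetimeNoLeap(2000, 1, 1) <= d <= DatetimeNoLeap(2001, 1, 1) for d in dates)
```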
Internal Changes diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py index 2e594455874..0af75f404a2 100644 --- a/xarray/coding/cftime_offsets.py +++ b/xarray/coding/cftime_offsets.py @@ -845,7 +845,12 @@ def _generate_range(start, end, periods, offset): A generator object """ if start: - start = offset.rollforward(start) + # From pandas GH 56147 / 56832 to account for negative direction and + # range bounds + if offset.n >= 0: + start = offset.rollforward(start) + else: + start = offset.rollback(start) if periods is None and end < start and offset.n >= 0: end = None diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 94c44544fb5..59fcdca8ffa 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -130,6 +130,7 @@ def _importorskip( has_numexpr, requires_numexpr = _importorskip("numexpr") has_flox, requires_flox = _importorskip("flox") has_pandas_ge_2_2, __ = _importorskip("pandas", "2.2") +has_pandas_3, requires_pandas_3 = _importorskip("pandas", "3.0.0.dev0") # some special cases diff --git a/xarray/tests/test_cftime_offsets.py b/xarray/tests/test_cftime_offsets.py index 0110afe40ac..eabb7d2f4d6 100644 --- a/xarray/tests/test_cftime_offsets.py +++ b/xarray/tests/test_cftime_offsets.py @@ -42,6 +42,7 @@ has_cftime, has_pandas_ge_2_2, requires_cftime, + requires_pandas_3, ) cftime = pytest.importorskip("cftime") @@ -1354,7 +1355,7 @@ def test_calendar_specific_month_end_negative_freq( ) -> None: year = 2000 # Use a leap-year to highlight calendar differences result = cftime_range( - start="2000-12", + start="2001", end="2000", freq="-2ME", calendar=calendar, @@ -1464,7 +1465,7 @@ def test_date_range_errors() -> None: ("2020-02-01", "QE-DEC", "noleap", "gregorian", True, "2020-03-31", True), ("2020-02-01", "YS-FEB", "noleap", "gregorian", True, "2020-02-01", True), ("2020-02-01", "YE-FEB", "noleap", "gregorian", True, "2020-02-29", True), - ("2020-02-01", "-1YE-FEB", "noleap", "gregorian", True, "2020-02-29", True), + ("2020-02-01", "-1YE-FEB", "noleap", "gregorian", True, "2019-02-28", True), ("2020-02-28", "3h", "all_leap", "gregorian", False, "2020-02-28", True), ("2020-03-30", "ME", "360_day", "gregorian", False, "2020-03-31", True), ("2020-03-31", "ME", "gregorian", "360_day", None, "2020-03-30", False), @@ -1724,7 +1725,17 @@ def test_new_to_legacy_freq_pd_freq_passthrough(freq, expected): @pytest.mark.parametrize("start", ("2000", "2001")) @pytest.mark.parametrize("end", ("2000", "2001")) @pytest.mark.parametrize( - "freq", ("MS", "-1MS", "YS", "-1YS", "ME", "-1ME", "YE", "-1YE") + "freq", + ( + "MS", + pytest.param("-1MS", marks=requires_pandas_3), + "YS", + pytest.param("-1YS", marks=requires_pandas_3), + "ME", + pytest.param("-1ME", marks=requires_pandas_3), + "YE", + pytest.param("-1YE", marks=requires_pandas_3), + ), ) def test_cftime_range_same_as_pandas(start, end, freq): result = date_range(start, end, freq=freq, calendar="standard", use_cftime=True) From c01de3997cf886b91c8134f17fa086401dbb22a7 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Mon, 6 May 2024 12:17:22 -0700 Subject: [PATCH 07/17] Fix for ruff 0.4.3 (#9007) * Fix for ruff 0.4.3 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ci/min_deps_check.py | 2 +- xarray/datatree_/datatree/__init__.py | 1 - xarray/datatree_/datatree/common.py | 9 ++++----- 
xarray/tests/test_dataset.py | 7 ++++--- xarray/tests/test_groupby.py | 6 +++--- xarray/tests/test_rolling.py | 6 +++--- 6 files changed, 15 insertions(+), 16 deletions(-) diff --git a/ci/min_deps_check.py b/ci/min_deps_check.py index 48ea323ed81..5ec7bff0a30 100755 --- a/ci/min_deps_check.py +++ b/ci/min_deps_check.py @@ -133,7 +133,7 @@ def process_pkg( - publication date of version suggested by policy (YYYY-MM-DD) - status ("<", "=", "> (!)") """ - print("Analyzing %s..." % pkg) + print(f"Analyzing {pkg}...") versions = query_conda(pkg) try: diff --git a/xarray/datatree_/datatree/__init__.py b/xarray/datatree_/datatree/__init__.py index 3159d612913..51c5f1b3073 100644 --- a/xarray/datatree_/datatree/__init__.py +++ b/xarray/datatree_/datatree/__init__.py @@ -1,7 +1,6 @@ # import public API from xarray.core.treenode import InvalidTreeError, NotFoundInTreeError - __all__ = ( "InvalidTreeError", "NotFoundInTreeError", diff --git a/xarray/datatree_/datatree/common.py b/xarray/datatree_/datatree/common.py index e4d52925ede..f4f74337c50 100644 --- a/xarray/datatree_/datatree/common.py +++ b/xarray/datatree_/datatree/common.py @@ -6,8 +6,9 @@ """ import warnings +from collections.abc import Hashable, Iterable, Mapping from contextlib import suppress -from typing import Any, Hashable, Iterable, List, Mapping +from typing import Any class TreeAttrAccessMixin: @@ -83,16 +84,14 @@ def __setattr__(self, name: str, value: Any) -> None: except AttributeError as e: # Don't accidentally shadow custom AttributeErrors, e.g. # DataArray.dims.setter - if str(e) != "{!r} object has no attribute {!r}".format( - type(self).__name__, name - ): + if str(e) != f"{type(self).__name__!r} object has no attribute {name!r}": raise raise AttributeError( f"cannot set attribute {name!r} on a {type(self).__name__!r} object. Use __setitem__ style" "assignment (e.g., `ds['name'] = ...`) instead of assigning variables." ) from e - def __dir__(self) -> List[str]: + def __dir__(self) -> list[str]: """Provide method name lookup and completion. Only provide 'public' methods. """ diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 301596e032f..59b5b2b9b71 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -284,7 +284,7 @@ def test_repr(self) -> None: Dimensions: (dim2: 9, dim3: 10, time: 20, dim1: 8) Coordinates: * dim2 (dim2) float64 72B 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0 - * dim3 (dim3) %s 40B 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' + * dim3 (dim3) {} 40B 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' * time (time) datetime64[ns] 160B 2000-01-01 2000-01-02 ... 2000-01-20 numbers (dim3) int64 80B 0 1 2 0 0 1 1 2 2 3 Dimensions without coordinates: dim1 @@ -293,8 +293,9 @@ def test_repr(self) -> None: var2 (dim1, dim2) float64 576B 1.162 -1.097 -2.123 ... 1.267 0.3328 var3 (dim3, dim1) float64 640B 0.5565 -0.2121 0.4563 ... 
-0.2452 -0.3616 Attributes: - foo: bar""" - % data["dim3"].dtype + foo: bar""".format( + data["dim3"].dtype + ) ) actual = "\n".join(x.rstrip() for x in repr(data).split("\n")) print(actual) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index e9e4eb1364c..7134fe96d01 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -577,8 +577,8 @@ def test_da_groupby_assign_coords() -> None: def test_groupby_repr(obj, dim) -> None: actual = repr(obj.groupby(dim)) expected = f"{obj.__class__.__name__}GroupBy" - expected += ", grouped over %r" % dim - expected += "\n%r groups with labels " % (len(np.unique(obj[dim]))) + expected += f", grouped over {dim!r}" + expected += f"\n{len(np.unique(obj[dim]))!r} groups with labels " if dim == "x": expected += "1, 2, 3, 4, 5." elif dim == "y": @@ -595,7 +595,7 @@ def test_groupby_repr_datetime(obj) -> None: actual = repr(obj.groupby("t.month")) expected = f"{obj.__class__.__name__}GroupBy" expected += ", grouped over 'month'" - expected += "\n%r groups with labels " % (len(np.unique(obj.t.dt.month))) + expected += f"\n{len(np.unique(obj.t.dt.month))!r} groups with labels " expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12." assert actual == expected diff --git a/xarray/tests/test_rolling.py b/xarray/tests/test_rolling.py index 79a5ba0a667..89f6ebba2c3 100644 --- a/xarray/tests/test_rolling.py +++ b/xarray/tests/test_rolling.py @@ -254,7 +254,7 @@ def test_rolling_reduce( rolling_obj = da.rolling(time=window, center=center, min_periods=min_periods) # add nan prefix to numpy methods to get similar # behavior as bottleneck - actual = rolling_obj.reduce(getattr(np, "nan%s" % name)) + actual = rolling_obj.reduce(getattr(np, f"nan{name}")) expected = getattr(rolling_obj, name)() assert_allclose(actual, expected) assert actual.sizes == expected.sizes @@ -276,7 +276,7 @@ def test_rolling_reduce_nonnumeric( rolling_obj = da.rolling(time=window, center=center, min_periods=min_periods) # add nan prefix to numpy methods to get similar behavior as bottleneck - actual = rolling_obj.reduce(getattr(np, "nan%s" % name)) + actual = rolling_obj.reduce(getattr(np, f"nan{name}")) expected = getattr(rolling_obj, name)() assert_allclose(actual, expected) assert actual.sizes == expected.sizes @@ -741,7 +741,7 @@ def test_rolling_reduce(self, ds, center, min_periods, window, name) -> None: rolling_obj = ds.rolling(time=window, center=center, min_periods=min_periods) # add nan prefix to numpy methods to get similar behavior as bottleneck - actual = rolling_obj.reduce(getattr(np, "nan%s" % name)) + actual = rolling_obj.reduce(getattr(np, f"nan{name}")) expected = getattr(rolling_obj, name)() assert_allclose(actual, expected) assert ds.sizes == actual.sizes From 2ad98b132cf004d908a77c40a7dc7adbd792f668 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 6 May 2024 13:21:14 -0600 Subject: [PATCH 08/17] Trigger CI only if code files are modified. (#9006) * Trigger CI only if code files are modified. 
Fixes #8705 * Apply suggestions from code review Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> --------- Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> --- .github/workflows/ci-additional.yaml | 6 ++++++ .github/workflows/ci.yaml | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index f904f259c51..c0f978fb0d8 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -6,6 +6,12 @@ on: pull_request: branches: - "main" + paths: + - 'ci/**' + - '.github/**' + - '/*' # covers files such as `pyproject.toml` + - 'properties/**' + - 'xarray/**' workflow_dispatch: # allows you to trigger manually concurrency: diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 349aa626142..b9b15d867a7 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -6,6 +6,12 @@ on: pull_request: branches: - "main" + paths: + - 'ci/**' + - '.github/**' + - '/*' # covers files such as `pyproject.toml` + - 'properties/**' + - 'xarray/**' workflow_dispatch: # allows you to trigger manually concurrency: From dcf2ac4addb5a92723c6b064fb6546ff02ebd1cd Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 7 May 2024 09:29:26 -0600 Subject: [PATCH 09/17] Zarr: Optimize `region="auto"` detection (#8997) * Zarr: Optimize region detection * Fix for unindexed dimensions. * Better example * small cleanup --- doc/user-guide/io.rst | 4 +- xarray/backends/api.py | 115 ++++------------------------------------ xarray/backends/zarr.py | 96 ++++++++++++++++++++++++++++++--- 3 files changed, 101 insertions(+), 114 deletions(-) diff --git a/doc/user-guide/io.rst b/doc/user-guide/io.rst index 63bf8b80d81..b73d0fdcb51 100644 --- a/doc/user-guide/io.rst +++ b/doc/user-guide/io.rst @@ -874,7 +874,7 @@ and then calling ``to_zarr`` with ``compute=False`` to write only metadata # The values of this dask array are entirely irrelevant; only the dtype, # shape and chunks are used dummies = dask.array.zeros(30, chunks=10) - ds = xr.Dataset({"foo": ("x", dummies)}) + ds = xr.Dataset({"foo": ("x", dummies)}, coords={"x": np.arange(30)}) path = "path/to/directory.zarr" # Now we write the metadata without computing any array values ds.to_zarr(path, compute=False) @@ -890,7 +890,7 @@ where the data should be written (in index space, not label space), e.g., # For convenience, we'll slice a single dataset, but in the real use-case # we would create them separately possibly even from separate processes. 
- ds = xr.Dataset({"foo": ("x", np.arange(30))}) + ds = xr.Dataset({"foo": ("x", np.arange(30))}, coords={"x": np.arange(30)}) # Any of the following region specifications are valid ds.isel(x=slice(0, 10)).to_zarr(path, region="auto") ds.isel(x=slice(10, 20)).to_zarr(path, region={"x": "auto"}) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 62085fe5e2a..c9a8630a575 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -27,7 +27,6 @@ _normalize_path, ) from xarray.backends.locks import _get_scheduler -from xarray.backends.zarr import open_zarr from xarray.core import indexing from xarray.core.combine import ( _infer_concat_order_from_positions, @@ -1522,92 +1521,6 @@ def save_mfdataset( ) -def _auto_detect_region(ds_new, ds_orig, dim): - # Create a mapping array of coordinates to indices on the original array - coord = ds_orig[dim] - da_map = DataArray(np.arange(coord.size), coords={dim: coord}) - - try: - da_idxs = da_map.sel({dim: ds_new[dim]}) - except KeyError as e: - if "not all values found" in str(e): - raise KeyError( - f"Not all values of coordinate '{dim}' in the new array were" - " found in the original store. Writing to a zarr region slice" - " requires that no dimensions or metadata are changed by the write." - ) - else: - raise e - - if (da_idxs.diff(dim) != 1).any(): - raise ValueError( - f"The auto-detected region of coordinate '{dim}' for writing new data" - " to the original store had non-contiguous indices. Writing to a zarr" - " region slice requires that the new data constitute a contiguous subset" - " of the original store." - ) - - dim_slice = slice(da_idxs.values[0], da_idxs.values[-1] + 1) - - return dim_slice - - -def _auto_detect_regions(ds, region, open_kwargs): - ds_original = open_zarr(**open_kwargs) - for key, val in region.items(): - if val == "auto": - region[key] = _auto_detect_region(ds, ds_original, key) - return region - - -def _validate_and_autodetect_region(ds, region, mode, open_kwargs) -> dict[str, slice]: - if region == "auto": - region = {dim: "auto" for dim in ds.dims} - - if not isinstance(region, dict): - raise TypeError(f"``region`` must be a dict, got {type(region)}") - - if any(v == "auto" for v in region.values()): - if mode != "r+": - raise ValueError( - f"``mode`` must be 'r+' when using ``region='auto'``, got {mode}" - ) - region = _auto_detect_regions(ds, region, open_kwargs) - - for k, v in region.items(): - if k not in ds.dims: - raise ValueError( - f"all keys in ``region`` are not in Dataset dimensions, got " - f"{list(region)} and {list(ds.dims)}" - ) - if not isinstance(v, slice): - raise TypeError( - "all values in ``region`` must be slice objects, got " - f"region={region}" - ) - if v.step not in {1, None}: - raise ValueError( - "step on all slices in ``region`` must be 1 or None, got " - f"region={region}" - ) - - non_matching_vars = [ - k for k, v in ds.variables.items() if not set(region).intersection(v.dims) - ] - if non_matching_vars: - raise ValueError( - f"when setting `region` explicitly in to_zarr(), all " - f"variables in the dataset to write must have at least " - f"one dimension in common with the region's dimensions " - f"{list(region.keys())}, but that is not " - f"the case for some variables here. 
To drop these variables " - f"from this dataset before exporting to zarr, write: " - f".drop_vars({non_matching_vars!r})" - ) - - return region - - def _validate_datatypes_for_zarr_append(zstore, dataset): """If variable exists in the store, confirm dtype of the data to append is compatible with existing dtype. @@ -1768,24 +1681,6 @@ def to_zarr( # validate Dataset keys, DataArray names _validate_dataset_names(dataset) - if region is not None: - open_kwargs = dict( - store=store, - synchronizer=synchronizer, - group=group, - consolidated=consolidated, - storage_options=storage_options, - zarr_version=zarr_version, - ) - region = _validate_and_autodetect_region(dataset, region, mode, open_kwargs) - # can't modify indexed with region writes - dataset = dataset.drop_vars(dataset.indexes) - if append_dim is not None and append_dim in region: - raise ValueError( - f"cannot list the same dimension in both ``append_dim`` and " - f"``region`` with to_zarr(), got {append_dim} in both" - ) - if zarr_version is None: # default to 2 if store doesn't specify it's version (e.g. a path) zarr_version = int(getattr(store, "_store_version", 2)) @@ -1815,6 +1710,16 @@ def to_zarr( write_empty=write_empty_chunks, ) + if region is not None: + zstore._validate_and_autodetect_region(dataset) + # can't modify indexed with region writes + dataset = dataset.drop_vars(dataset.indexes) + if append_dim is not None and append_dim in region: + raise ValueError( + f"cannot list the same dimension in both ``append_dim`` and " + f"``region`` with to_zarr(), got {append_dim} in both" + ) + if mode in ["a", "a-", "r+"]: _validate_datatypes_for_zarr_append(zstore, dataset) if append_dim is not None: diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 3d6baeefe01..e4a684e945d 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any import numpy as np +import pandas as pd from xarray import coding, conventions from xarray.backends.common import ( @@ -509,7 +510,9 @@ def ds(self): # TODO: consider deprecating this in favor of zarr_group return self.zarr_group - def open_store_variable(self, name, zarr_array): + def open_store_variable(self, name, zarr_array=None): + if zarr_array is None: + zarr_array = self.zarr_group[name] data = indexing.LazilyIndexedArray(ZarrArrayWrapper(zarr_array)) try_nczarr = self._mode == "r" dimensions, attributes = _get_zarr_dims_and_attrs( @@ -623,11 +626,7 @@ def store( # avoid needing to load index variables into memory. # TODO: consider making loading indexes lazy again? existing_vars, _, _ = conventions.decode_cf_variables( - { - k: v - for k, v in self.get_variables().items() - if k in existing_variable_names - }, + {k: self.open_store_variable(name=k) for k in existing_variable_names}, self.get_attrs(), ) # Modified variables must use the same encoding as the store. 
@@ -796,10 +795,93 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No region = tuple(write_region[dim] for dim in dims) writer.add(v.data, zarr_array, region) - def close(self): + def close(self) -> None: if self._close_store_on_close: self.zarr_group.store.close() + def _auto_detect_regions(self, ds, region): + for dim, val in region.items(): + if val != "auto": + continue + + if dim not in ds._variables: + # unindexed dimension + region[dim] = slice(0, ds.sizes[dim]) + continue + + variable = conventions.decode_cf_variable( + dim, self.open_store_variable(dim).compute() + ) + assert variable.dims == (dim,) + index = pd.Index(variable.data) + idxs = index.get_indexer(ds[dim].data) + if any(idxs == -1): + raise KeyError( + f"Not all values of coordinate '{dim}' in the new array were" + " found in the original store. Writing to a zarr region slice" + " requires that no dimensions or metadata are changed by the write." + ) + + if (np.diff(idxs) != 1).any(): + raise ValueError( + f"The auto-detected region of coordinate '{dim}' for writing new data" + " to the original store had non-contiguous indices. Writing to a zarr" + " region slice requires that the new data constitute a contiguous subset" + " of the original store." + ) + region[dim] = slice(idxs[0], idxs[-1] + 1) + return region + + def _validate_and_autodetect_region(self, ds) -> None: + region = self._write_region + + if region == "auto": + region = {dim: "auto" for dim in ds.dims} + + if not isinstance(region, dict): + raise TypeError(f"``region`` must be a dict, got {type(region)}") + if any(v == "auto" for v in region.values()): + if self._mode != "r+": + raise ValueError( + f"``mode`` must be 'r+' when using ``region='auto'``, got {self._mode!r}" + ) + region = self._auto_detect_regions(ds, region) + + # validate before attempting to auto-detect since the auto-detection + # should always return a valid slice. + for k, v in region.items(): + if k not in ds.dims: + raise ValueError( + f"all keys in ``region`` are not in Dataset dimensions, got " + f"{list(region)} and {list(ds.dims)}" + ) + if not isinstance(v, slice): + raise TypeError( + "all values in ``region`` must be slice objects, got " + f"region={region}" + ) + if v.step not in {1, None}: + raise ValueError( + "step on all slices in ``region`` must be 1 or None, got " + f"region={region}" + ) + + non_matching_vars = [ + k for k, v in ds.variables.items() if not set(region).intersection(v.dims) + ] + if non_matching_vars: + raise ValueError( + f"when setting `region` explicitly in to_zarr(), all " + f"variables in the dataset to write must have at least " + f"one dimension in common with the region's dimensions " + f"{list(region.keys())}, but that is not " + f"the case for some variables here. 
To drop these variables " + f"from this dataset before exporting to zarr, write: " + f".drop_vars({non_matching_vars!r})" + ) + + self._write_region = region + def open_zarr( store, From 4e9d557d47bf6792937792426559747204fce5ed Mon Sep 17 00:00:00 2001 From: Mark Harfouche Date: Tue, 7 May 2024 11:59:21 -0400 Subject: [PATCH 10/17] Add a benchmark to monitor performance for large dataset indexing (#9012) --- asv_bench/benchmarks/indexing.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/asv_bench/benchmarks/indexing.py b/asv_bench/benchmarks/indexing.py index 169c8af06e9..892a6cb3758 100644 --- a/asv_bench/benchmarks/indexing.py +++ b/asv_bench/benchmarks/indexing.py @@ -12,6 +12,7 @@ nt = 500 basic_indexes = { + "1scalar": {"x": 0}, "1slice": {"x": slice(0, 3)}, "1slice-1scalar": {"x": 0, "y": slice(None, None, 3)}, "2slicess-1scalar": {"x": slice(3, -3, 3), "y": 1, "t": slice(None, -3, 3)}, @@ -74,6 +75,10 @@ def setup(self, key): "x_coords": ("x", np.linspace(1.1, 2.1, nx)), }, ) + # Benchmark how indexing is slowed down by adding many scalar variable + # to the dataset + # https://github.com/pydata/xarray/pull/9003 + self.ds_large = self.ds.merge({f"extra_var{i}": i for i in range(400)}) class Indexing(Base): @@ -89,6 +94,11 @@ def time_indexing_outer(self, key): def time_indexing_vectorized(self, key): self.ds.isel(**vectorized_indexes[key]).load() + @parameterized(["key"], [list(basic_indexes.keys())]) + def time_indexing_basic_ds_large(self, key): + # https://github.com/pydata/xarray/pull/9003 + self.ds_large.isel(**basic_indexes[key]).load() + class Assignment(Base): @parameterized(["key"], [list(basic_indexes.keys())]) From 322e6706c0d85ec4c0c0763806c4edf158baa58c Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 7 May 2024 13:28:50 -0600 Subject: [PATCH 11/17] Avoid extra read from disk when creating Pandas Index. (#8893) * Avoid extra read from disk when creating Pandas Index. * Update xarray/core/indexes.py Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> --------- Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> --- xarray/core/indexes.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index e71c4a6f073..a005e1ebfe2 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -632,7 +632,8 @@ def from_variables( # the checks below. 
# preserve wrapped pd.Index (if any) - data = getattr(var._data, "array", var.data) + # accessing `.data` can load data from disk, so we only access if needed + data = getattr(var._data, "array") if hasattr(var._data, "array") else var.data # multi-index level variable: get level index if isinstance(var._data, PandasMultiIndexingAdapter): level = var._data.level From 71661d5b17865be426b646136a0d5766290c8d3d Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 7 May 2024 16:53:26 -0600 Subject: [PATCH 12/17] Fix benchmark CI (#9013) * [skip-ci] Fix benchmark CI * [skip-ci] reduce warnings * Fix indexing benchmark --- .github/workflows/benchmarks.yml | 6 +++--- asv_bench/asv.conf.json | 12 ++++++++---- asv_bench/benchmarks/groupby.py | 17 +++++++++-------- asv_bench/benchmarks/indexing.py | 1 + 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index 7969847c61f..886bcfbd548 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -28,8 +28,11 @@ jobs: environment-name: xarray-tests cache-environment: true cache-environment-key: "${{runner.os}}-${{runner.arch}}-py${{env.PYTHON_VERSION}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}-benchmark" + # add "build" because of https://github.com/airspeed-velocity/asv/issues/1385 create-args: >- asv + build + mamba - name: Run benchmarks @@ -47,9 +50,6 @@ jobs: asv machine --yes echo "Baseline: ${{ github.event.pull_request.base.sha }} (${{ github.event.pull_request.base.label }})" echo "Contender: ${GITHUB_SHA} (${{ github.event.pull_request.head.label }})" - # Use mamba for env creation - # export CONDA_EXE=$(which mamba) - export CONDA_EXE=$(which conda) # Run benchmarks for current commit against base ASV_OPTIONS="--split --show-stderr --factor $ASV_FACTOR" asv continuous $ASV_OPTIONS ${{ github.event.pull_request.base.sha }} ${GITHUB_SHA} \ diff --git a/asv_bench/asv.conf.json b/asv_bench/asv.conf.json index a709d0a51a7..9dc86df712d 100644 --- a/asv_bench/asv.conf.json +++ b/asv_bench/asv.conf.json @@ -29,7 +29,7 @@ // If missing or the empty string, the tool will be automatically // determined by looking for tools on the PATH environment // variable. - "environment_type": "conda", + "environment_type": "mamba", "conda_channels": ["conda-forge"], // timeout in seconds for installing any dependencies in environment @@ -41,7 +41,7 @@ // The Pythons you'd like to test against. If not provided, defaults // to the current version of Python used to run `asv`. - "pythons": ["3.10"], + "pythons": ["3.11"], // The matrix of dependencies to test. Each key is the name of a // package (in PyPI) and the values are version numbers. An empty @@ -72,8 +72,12 @@ "sparse": [""], "cftime": [""] }, - - + // fix for bad builds + // https://github.com/airspeed-velocity/asv/issues/1389#issuecomment-2076131185 + "build_command": [ + "python -m build", + "python -mpip wheel --no-deps --no-build-isolation --no-index -w {build_cache_dir} {build_dir}" + ], // Combinations of libraries/python versions can be excluded/included // from the set to test. Each entry is a dictionary containing additional // key-value pairs to include/exclude. 
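The benchmark edits in the next file mostly swap frequency strings ("H" to "h", "48H" to "48h", "3M" to "3ME"): pandas 2.2 deprecated the upper-case hour alias and renamed month-end from "M" to "ME", and the old spellings flood benchmark runs with warnings. A quick check of the new aliases, assuming pandas >= 2.2:

```python
import pandas as pd

hourly = pd.date_range("2001-01-01", periods=365 * 24, freq="h")  # previously "H"
month_end = pd.date_range("2001-01-01", periods=12, freq="ME")    # previously "M"
print(hourly[:2], month_end[:2])
```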
diff --git a/asv_bench/benchmarks/groupby.py b/asv_bench/benchmarks/groupby.py index 1b3e55fa659..065c1b3b17f 100644 --- a/asv_bench/benchmarks/groupby.py +++ b/asv_bench/benchmarks/groupby.py @@ -68,6 +68,7 @@ def setup(self, *args, **kwargs): self.ds2d_mean = self.ds2d.groupby("b").mean().compute() +# TODO: These don't work now because we are calling `.compute` explicitly. class GroupByPandasDataFrame(GroupBy): """Run groupby tests using pandas DataFrame.""" @@ -111,11 +112,11 @@ def setup(self, *args, **kwargs): { "b": ("time", np.arange(365.0 * 24)), }, - coords={"time": pd.date_range("2001-01-01", freq="H", periods=365 * 24)}, + coords={"time": pd.date_range("2001-01-01", freq="h", periods=365 * 24)}, ) self.ds2d = self.ds1d.expand_dims(z=10) - self.ds1d_mean = self.ds1d.resample(time="48H").mean() - self.ds2d_mean = self.ds2d.resample(time="48H").mean() + self.ds1d_mean = self.ds1d.resample(time="48h").mean() + self.ds2d_mean = self.ds2d.resample(time="48h").mean() @parameterized(["ndim"], [(1, 2)]) def time_init(self, ndim): @@ -127,7 +128,7 @@ def time_init(self, ndim): def time_agg_small_num_groups(self, method, ndim, use_flox): ds = getattr(self, f"ds{ndim}d") with xr.set_options(use_flox=use_flox): - getattr(ds.resample(time="3M"), method)().compute() + getattr(ds.resample(time="3ME"), method)().compute() @parameterized( ["method", "ndim", "use_flox"], [("sum", "mean"), (1, 2), (True, False)] @@ -135,7 +136,7 @@ def time_agg_small_num_groups(self, method, ndim, use_flox): def time_agg_large_num_groups(self, method, ndim, use_flox): ds = getattr(self, f"ds{ndim}d") with xr.set_options(use_flox=use_flox): - getattr(ds.resample(time="48H"), method)().compute() + getattr(ds.resample(time="48h"), method)().compute() class ResampleDask(Resample): @@ -154,13 +155,13 @@ def setup(self, *args, **kwargs): }, coords={ "time": xr.date_range( - "2001-01-01", freq="H", periods=365 * 24, calendar="noleap" + "2001-01-01", freq="h", periods=365 * 24, calendar="noleap" ) }, ) self.ds2d = self.ds1d.expand_dims(z=10) - self.ds1d_mean = self.ds1d.resample(time="48H").mean() - self.ds2d_mean = self.ds2d.resample(time="48H").mean() + self.ds1d_mean = self.ds1d.resample(time="48h").mean() + self.ds2d_mean = self.ds2d.resample(time="48h").mean() @parameterized(["use_cftime", "use_flox"], [[True, False], [True, False]]) diff --git a/asv_bench/benchmarks/indexing.py b/asv_bench/benchmarks/indexing.py index 892a6cb3758..529d023daa8 100644 --- a/asv_bench/benchmarks/indexing.py +++ b/asv_bench/benchmarks/indexing.py @@ -19,6 +19,7 @@ } basic_assignment_values = { + "1scalar": 0, "1slice": xr.DataArray(randn((3, ny), frac_nan=0.1), dims=["x", "y"]), "1slice-1scalar": xr.DataArray(randn(int(ny / 3) + 1, frac_nan=0.1), dims=["y"]), "2slicess-1scalar": xr.DataArray( From 6057128b7779611c03a546927955862b1dcd2572 Mon Sep 17 00:00:00 2001 From: Tom Nicholas Date: Wed, 8 May 2024 13:26:43 -0600 Subject: [PATCH 13/17] Avoid auto creation of indexes in concat (#8872) * test not creating indexes on concatenation * construct result dataset using Coordinates object with indexes passed explicitly * remove unnecessary overwriting of indexes * ConcatenatableArray class * use ConcatenableArray in tests * add regression tests * fix by performing check * refactor assert_valid_explicit_coords and rename dims->sizes * Revert "add regression tests" This reverts commit beb665a7109bdb627aa66ee277fd87edc195356d. * Revert "fix by performing check" This reverts commit 22f361dc590a83b2b3660539175a8a7cb1cba051. 
* Revert "refactor assert_valid_explicit_coords and rename dims->sizes" This reverts commit 55166fc7e002fa07d7a84f8d7fc460ddaad9674f. * fix failing test * possible fix for failing groupby test * Revert "possible fix for failing groupby test" This reverts commit 6e9ead6603de73c5ea6bd8f76d973525bb70b417. * test expand_dims doesn't create Index * add option to not create 1D index in expand_dims * refactor tests to consider data variables and coordinate variables separately * test expand_dims doesn't create Index * add option to not create 1D index in expand_dims * refactor tests to consider data variables and coordinate variables separately * fix bug causing new test to fail * test index auto-creation when iterable passed as new coordinate values * make test for iterable pass * added kwarg to dataarray * whatsnew * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revert "refactor tests to consider data variables and coordinate variables separately" This reverts commit ba5627eebf7b580d0a0b9a171f1f94d7412662e3. * Revert "add option to not create 1D index in expand_dims" This reverts commit 95d453ccff1d2e2746c1970c0157f2de0b582105. * test that concat doesn't raise if create_1d_index=False * make test pass by passing create_1d_index down through concat * assert that an UnexpectedDataAccess error is raised when create_1d_index=True * eliminate possibility of xarray internals bypassing UnexpectedDataAccess error by accessing .array * update tests to use private versions of assertions * create_1d_index->create_index * Update doc/whats-new.rst Co-authored-by: Deepak Cherian * Rename create_1d_index -> create_index * fix ConcatenatableArray * formatting * whatsnew * add new create_index kwarg to overloads * split vars into data_vars and coord_vars in one loop * avoid mypy error by using new variable name * warn if create_index=True but no index created because dimension variable was a data var not a coord * add string marks in warning message * regression test for dtype changing in to_stacked_array * correct doctest * Remove outdated comment * test we can skip creation of indexes during shape promotion * make shape promotion test pass * point to issue in whatsnew * don't create dimension coordinates just to drop them at the end * Remove ToDo about not using Coordinates object to pass indexes Co-authored-by: Deepak Cherian * get rid of unlabeled_dims variable entirely * move ConcatenatableArray and similar to new file * formatting nit Co-authored-by: Justus Magin * renamed create_index -> create_index_for_new_dim in concat * renamed create_index -> create_index_for_new_dim in expand_dims * fix incorrect arg name * add example to docstring * add example of using new kwarg to docstring of expand_dims * add example of using new kwarg to docstring of concat * re-nit the nit Co-authored-by: Justus Magin * more instances of the nit * fix docstring doctest formatting nit --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deepak Cherian Co-authored-by: Justus Magin --- doc/whats-new.rst | 5 +- xarray/core/concat.py | 67 ++++++++++--- xarray/core/dataarray.py | 12 ++- xarray/core/dataset.py | 43 +++++++-- xarray/tests/__init__.py | 54 ++--------- xarray/tests/arrays.py | 179 +++++++++++++++++++++++++++++++++++ xarray/tests/test_concat.py | 88 +++++++++++++++++ xarray/tests/test_dataset.py | 41 ++++++-- 8 files changed, 405 insertions(+), 84 deletions(-) create mode 100644 xarray/tests/arrays.py 
diff --git a/doc/whats-new.rst b/doc/whats-new.rst index a846c1b8a01..378e6330352 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -32,7 +32,10 @@ New Features - :py:func:`testing.assert_allclose`/:py:func:`testing.assert_equal` now accept a new argument `check_dims="transpose"`, controlling whether a transposed array is considered equal. (:issue:`5733`, :pull:`8991`) By `Ignacio Martinez Vazquez `_. - Added the option to avoid automatically creating 1D pandas indexes in :py:meth:`Dataset.expand_dims()`, by passing the new kwarg - `create_index=False`. (:pull:`8960`) + `create_index_for_new_dim=False`. (:pull:`8960`) + By `Tom Nicholas `_. +- Avoid automatically re-creating 1D pandas indexes in :py:func:`concat()`. Also added option to avoid creating 1D indexes for + new dimension coordinates by passing the new kwarg `create_index_for_new_dim=False`. (:issue:`8871`, :pull:`8872`) By `Tom Nicholas `_. Breaking changes diff --git a/xarray/core/concat.py b/xarray/core/concat.py index d95cbccd36a..b1cca586992 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -8,6 +8,7 @@ from xarray.core import dtypes, utils from xarray.core.alignment import align, reindex_variables +from xarray.core.coordinates import Coordinates from xarray.core.duck_array_ops import lazy_array_equiv from xarray.core.indexes import Index, PandasIndex from xarray.core.merge import ( @@ -42,6 +43,7 @@ def concat( fill_value: object = dtypes.NA, join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "override", + create_index_for_new_dim: bool = True, ) -> T_Dataset: ... @@ -56,6 +58,7 @@ def concat( fill_value: object = dtypes.NA, join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "override", + create_index_for_new_dim: bool = True, ) -> T_DataArray: ... @@ -69,6 +72,7 @@ def concat( fill_value=dtypes.NA, join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "override", + create_index_for_new_dim: bool = True, ): """Concatenate xarray objects along a new or existing dimension. @@ -162,6 +166,8 @@ def concat( If a callable, it must expect a sequence of ``attrs`` dicts and a context object as its only parameters. + create_index_for_new_dim : bool, default: True + Whether to create a new ``PandasIndex`` object when the objects being concatenated contain scalar variables named ``dim``. 
Returns ------- @@ -217,6 +223,25 @@ def concat( x (new_dim) >> ds = xr.Dataset(coords={"x": 0}) + >>> xr.concat([ds, ds], dim="x") + Size: 16B + Dimensions: (x: 2) + Coordinates: + * x (x) int64 16B 0 0 + Data variables: + *empty* + + >>> xr.concat([ds, ds], dim="x").indexes + Indexes: + x Index([0, 0], dtype='int64', name='x') + + >>> xr.concat([ds, ds], dim="x", create_index_for_new_dim=False).indexes + Indexes: + *empty* """ # TODO: add ignore_index arguments copied from pandas.concat # TODO: support concatenating scalar coordinates even if the concatenated @@ -245,6 +270,7 @@ def concat( fill_value=fill_value, join=join, combine_attrs=combine_attrs, + create_index_for_new_dim=create_index_for_new_dim, ) elif isinstance(first_obj, Dataset): return _dataset_concat( @@ -257,6 +283,7 @@ def concat( fill_value=fill_value, join=join, combine_attrs=combine_attrs, + create_index_for_new_dim=create_index_for_new_dim, ) else: raise TypeError( @@ -439,7 +466,7 @@ def _parse_datasets( if dim in dims: continue - if dim not in dim_coords: + if dim in ds.coords and dim not in dim_coords: dim_coords[dim] = ds.coords[dim].variable dims = dims | set(ds.dims) @@ -456,6 +483,7 @@ def _dataset_concat( fill_value: Any = dtypes.NA, join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "override", + create_index_for_new_dim: bool = True, ) -> T_Dataset: """ Concatenate a sequence of datasets along a new or existing dimension @@ -489,7 +517,6 @@ def _dataset_concat( datasets ) dim_names = set(dim_coords) - unlabeled_dims = dim_names - coord_names both_data_and_coords = coord_names & data_names if both_data_and_coords: @@ -502,7 +529,10 @@ def _dataset_concat( # case where concat dimension is a coordinate or data_var but not a dimension if (dim in coord_names or dim in data_names) and dim not in dim_names: - datasets = [ds.expand_dims(dim) for ds in datasets] + datasets = [ + ds.expand_dims(dim, create_index_for_new_dim=create_index_for_new_dim) + for ds in datasets + ] # determine which variables to concatenate concat_over, equals, concat_dim_lengths = _calc_concat_over( @@ -510,7 +540,7 @@ def _dataset_concat( ) # determine which variables to merge, and then merge them according to compat - variables_to_merge = (coord_names | data_names) - concat_over - unlabeled_dims + variables_to_merge = (coord_names | data_names) - concat_over result_vars = {} result_indexes = {} @@ -567,7 +597,8 @@ def get_indexes(name): var = ds._variables[name] if not var.dims: data = var.set_dims(dim).values - yield PandasIndex(data, dim, coord_dtype=var.dtype) + if create_index_for_new_dim: + yield PandasIndex(data, dim, coord_dtype=var.dtype) # create concatenation index, needed for later reindexing file_start_indexes = np.append(0, np.cumsum(concat_dim_lengths)) @@ -646,29 +677,33 @@ def get_indexes(name): # preserves original variable order result_vars[name] = result_vars.pop(name) - result = type(datasets[0])(result_vars, attrs=result_attrs) - - absent_coord_names = coord_names - set(result.variables) + absent_coord_names = coord_names - set(result_vars) if absent_coord_names: raise ValueError( f"Variables {absent_coord_names!r} are coordinates in some datasets but not others." 
) - result = result.set_coords(coord_names) - result.encoding = result_encoding - result = result.drop_vars(unlabeled_dims, errors="ignore") + result_data_vars = {} + coord_vars = {} + for name, result_var in result_vars.items(): + if name in coord_names: + coord_vars[name] = result_var + else: + result_data_vars[name] = result_var if index is not None: - # add concat index / coordinate last to ensure that its in the final Dataset if dim_var is not None: index_vars = index.create_variables({dim: dim_var}) else: index_vars = index.create_variables() - result[dim] = index_vars[dim] + + coord_vars[dim] = index_vars[dim] result_indexes[dim] = index - # TODO: add indexes at Dataset creation (when it is supported) - result = result._overwrite_indexes(result_indexes) + coords_obj = Coordinates(coord_vars, indexes=result_indexes) + + result = type(datasets[0])(result_data_vars, coords=coords_obj, attrs=result_attrs) + result.encoding = result_encoding return result @@ -683,6 +718,7 @@ def _dataarray_concat( fill_value: object = dtypes.NA, join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "override", + create_index_for_new_dim: bool = True, ) -> T_DataArray: from xarray.core.dataarray import DataArray @@ -719,6 +755,7 @@ def _dataarray_concat( fill_value=fill_value, join=join, combine_attrs=combine_attrs, + create_index_for_new_dim=create_index_for_new_dim, ) merged_attrs = merge_attrs([da.attrs for da in arrays], combine_attrs) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index c89dedf1215..4dc897c1878 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2558,7 +2558,7 @@ def expand_dims( self, dim: None | Hashable | Sequence[Hashable] | Mapping[Any, Any] = None, axis: None | int | Sequence[int] = None, - create_index: bool = True, + create_index_for_new_dim: bool = True, **dim_kwargs: Any, ) -> Self: """Return a new object with an additional axis (or axes) inserted at @@ -2569,7 +2569,7 @@ def expand_dims( coordinate consisting of a single value. The automatic creation of indexes to back new 1D coordinate variables - controlled by the create_index kwarg. + controlled by the create_index_for_new_dim kwarg. Parameters ---------- @@ -2586,8 +2586,8 @@ def expand_dims( multiple axes are inserted. In this case, dim arguments should be same length list. If axis=None is passed, all the axes will be inserted to the start of the result array. - create_index : bool, default is True - Whether to create new PandasIndex objects for any new 1D coordinate variables. + create_index_for_new_dim : bool, default: True + Whether to create new ``PandasIndex`` objects when the object being expanded contains scalar variables with names in ``dim``. 
**dim_kwargs : int or sequence or ndarray The keywords are arbitrary dimensions being inserted and the values are either the lengths of the new dims (if int is given), or their @@ -2651,7 +2651,9 @@ def expand_dims( dim = {dim: 1} dim = either_dict_or_kwargs(dim, dim_kwargs, "expand_dims") - ds = self._to_temp_dataset().expand_dims(dim, axis, create_index=create_index) + ds = self._to_temp_dataset().expand_dims( + dim, axis, create_index_for_new_dim=create_index_for_new_dim + ) return self._from_temp_dataset(ds) def set_index( diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 2ddcacd2fa0..09597670573 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4513,7 +4513,7 @@ def expand_dims( self, dim: None | Hashable | Sequence[Hashable] | Mapping[Any, Any] = None, axis: None | int | Sequence[int] = None, - create_index: bool = True, + create_index_for_new_dim: bool = True, **dim_kwargs: Any, ) -> Self: """Return a new object with an additional axis (or axes) inserted at @@ -4524,7 +4524,7 @@ def expand_dims( coordinate consisting of a single value. The automatic creation of indexes to back new 1D coordinate variables - controlled by the create_index kwarg. + controlled by the create_index_for_new_dim kwarg. Parameters ---------- @@ -4541,8 +4541,8 @@ def expand_dims( multiple axes are inserted. In this case, dim arguments should be same length list. If axis=None is passed, all the axes will be inserted to the start of the result array. - create_index : bool, default is True - Whether to create new PandasIndex objects for any new 1D coordinate variables. + create_index_for_new_dim : bool, default: True + Whether to create new ``PandasIndex`` objects when the object being expanded contains scalar variables with names in ``dim``. **dim_kwargs : int or sequence or ndarray The keywords are arbitrary dimensions being inserted and the values are either the lengths of the new dims (if int is given), or their @@ -4612,6 +4612,33 @@ def expand_dims( Data variables: temperature (y, x, time) float64 96B 0.5488 0.7152 0.6028 ... 0.7917 0.5289 + # Expand a scalar variable along a new dimension of the same name with and without creating a new index + + >>> ds = xr.Dataset(coords={"x": 0}) + >>> ds + Size: 8B + Dimensions: () + Coordinates: + x int64 8B 0 + Data variables: + *empty* + + >>> ds.expand_dims("x") + Size: 8B + Dimensions: (x: 1) + Coordinates: + * x (x) int64 8B 0 + Data variables: + *empty* + + >>> ds.expand_dims("x").indexes + Indexes: + x Index([0], dtype='int64', name='x') + + >>> ds.expand_dims("x", create_index_for_new_dim=False).indexes + Indexes: + *empty* + See Also -------- DataArray.expand_dims @@ -4663,7 +4690,7 @@ def expand_dims( # value within the dim dict to the length of the iterable # for later use. - if create_index: + if create_index_for_new_dim: index = PandasIndex(v, k) indexes[k] = index name_and_new_1d_var = index.create_variables() @@ -4705,14 +4732,14 @@ def expand_dims( variables[k] = v.set_dims(dict(all_dims)) else: if k not in variables: - if k in coord_names and create_index: + if k in coord_names and create_index_for_new_dim: # If dims includes a label of a non-dimension coordinate, # it will be promoted to a 1D coordinate with a single value. index, index_vars = create_default_index_implicit(v.set_dims(k)) indexes[k] = index variables.update(index_vars) else: - if create_index: + if create_index_for_new_dim: warnings.warn( f"No index created for dimension {k} because variable {k} is not a coordinate. 
" f"To create an index for {k}, please first call `.set_coords('{k}')` on this object.", @@ -5400,7 +5427,7 @@ def to_stacked_array( [3, 4, 5, 7]]) Coordinates: * z (z) object 32B MultiIndex - * variable (z) object 32B 'a' 'a' 'a' 'b' + * variable (z) 1: - raise UnexpectedDataAccess("Tried accessing more than one element.") - return self.array[tuple_idxr] - - -class DuckArrayWrapper(utils.NDArrayMixin): - """Array-like that prevents casting to array. - Modeled after cupy.""" - - def __init__(self, array: np.ndarray): - self.array = array - - def __getitem__(self, key): - return type(self)(self.array[key]) - - def __array__(self, dtype: np.typing.DTypeLike = None): - raise UnexpectedDataAccess("Tried accessing data") - - def __array_namespace__(self): - """Present to satisfy is_duck_array test.""" - - class ReturnItem: def __getitem__(self, key): return key diff --git a/xarray/tests/arrays.py b/xarray/tests/arrays.py new file mode 100644 index 00000000000..983e620d1f0 --- /dev/null +++ b/xarray/tests/arrays.py @@ -0,0 +1,179 @@ +from collections.abc import Iterable +from typing import Any, Callable + +import numpy as np + +from xarray.core import utils +from xarray.core.indexing import ExplicitlyIndexed + +""" +This module contains various lazy array classes which can be wrapped and manipulated by xarray objects but will raise on data access. +""" + + +class UnexpectedDataAccess(Exception): + pass + + +class InaccessibleArray(utils.NDArrayMixin, ExplicitlyIndexed): + """Disallows any loading.""" + + def __init__(self, array): + self.array = array + + def get_duck_array(self): + raise UnexpectedDataAccess("Tried accessing data") + + def __array__(self, dtype: np.typing.DTypeLike = None): + raise UnexpectedDataAccess("Tried accessing data") + + def __getitem__(self, key): + raise UnexpectedDataAccess("Tried accessing data.") + + +class FirstElementAccessibleArray(InaccessibleArray): + def __getitem__(self, key): + tuple_idxr = key.tuple + if len(tuple_idxr) > 1: + raise UnexpectedDataAccess("Tried accessing more than one element.") + return self.array[tuple_idxr] + + +class DuckArrayWrapper(utils.NDArrayMixin): + """Array-like that prevents casting to array. 
+ Modeled after cupy.""" + + def __init__(self, array: np.ndarray): + self.array = array + + def __getitem__(self, key): + return type(self)(self.array[key]) + + def __array__(self, dtype: np.typing.DTypeLike = None): + raise UnexpectedDataAccess("Tried accessing data") + + def __array_namespace__(self): + """Present to satisfy is_duck_array test.""" + + +CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS: dict[str, Callable] = {} + + +def implements(numpy_function): + """Register an __array_function__ implementation for ConcatenatableArray objects.""" + + def decorator(func): + CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS[numpy_function] = func + return func + + return decorator + + +@implements(np.concatenate) +def concatenate( + arrays: Iterable["ConcatenatableArray"], /, *, axis=0 +) -> "ConcatenatableArray": + if any(not isinstance(arr, ConcatenatableArray) for arr in arrays): + raise TypeError + + result = np.concatenate([arr._array for arr in arrays], axis=axis) + return ConcatenatableArray(result) + + +@implements(np.stack) +def stack( + arrays: Iterable["ConcatenatableArray"], /, *, axis=0 +) -> "ConcatenatableArray": + if any(not isinstance(arr, ConcatenatableArray) for arr in arrays): + raise TypeError + + result = np.stack([arr._array for arr in arrays], axis=axis) + return ConcatenatableArray(result) + + +@implements(np.result_type) +def result_type(*arrays_and_dtypes) -> np.dtype: + """Called by xarray to ensure all arguments to concat have the same dtype.""" + first_dtype, *other_dtypes = (np.dtype(obj) for obj in arrays_and_dtypes) + for other_dtype in other_dtypes: + if other_dtype != first_dtype: + raise ValueError("dtypes not all consistent") + return first_dtype + + +@implements(np.broadcast_to) +def broadcast_to( + x: "ConcatenatableArray", /, shape: tuple[int, ...] +) -> "ConcatenatableArray": + """ + Broadcasts an array to a specified shape, by either manipulating chunk keys or copying chunk manifest entries. + """ + if not isinstance(x, ConcatenatableArray): + raise TypeError + + result = np.broadcast_to(x._array, shape=shape) + return ConcatenatableArray(result) + + +class ConcatenatableArray: + """Disallows loading or coercing to an index but does support concatenation / stacking.""" + + def __init__(self, array): + # use ._array instead of .array because we don't want this to be accessible even to xarray's internals (e.g. 
create_default_index_implicit) + self._array = array + + @property + def dtype(self: Any) -> np.dtype: + return self._array.dtype + + @property + def shape(self: Any) -> tuple[int, ...]: + return self._array.shape + + @property + def ndim(self: Any) -> int: + return self._array.ndim + + def __repr__(self: Any) -> str: + return f"{type(self).__name__}(array={self._array!r})" + + def get_duck_array(self): + raise UnexpectedDataAccess("Tried accessing data") + + def __array__(self, dtype: np.typing.DTypeLike = None): + raise UnexpectedDataAccess("Tried accessing data") + + def __getitem__(self, key) -> "ConcatenatableArray": + """Some cases of concat require supporting expanding dims by dimensions of size 1""" + # see https://data-apis.org/array-api/2022.12/API_specification/indexing.html#multi-axis-indexing + arr = self._array + for axis, indexer_1d in enumerate(key): + if indexer_1d is None: + arr = np.expand_dims(arr, axis) + elif indexer_1d is Ellipsis: + pass + else: + raise UnexpectedDataAccess("Tried accessing data.") + return ConcatenatableArray(arr) + + def __array_function__(self, func, types, args, kwargs) -> Any: + if func not in CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS: + return NotImplemented + + # Note: this allows subclasses that don't override + # __array_function__ to handle ManifestArray objects + if not all(issubclass(t, ConcatenatableArray) for t in types): + return NotImplemented + + return CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS[func](*args, **kwargs) + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> Any: + """We have to define this in order to convince xarray that this class is a duckarray, even though we will never support ufuncs.""" + return NotImplemented + + def astype(self, dtype: np.dtype, /, *, copy: bool = True) -> "ConcatenatableArray": + """Needed because xarray will call this even when it's a no-op""" + if dtype != self.dtype: + raise NotImplementedError() + else: + return self diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 1ddb5a569bd..0c570de3b52 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -12,7 +12,9 @@ from xarray.core.coordinates import Coordinates from xarray.core.indexes import PandasIndex from xarray.tests import ( + ConcatenatableArray, InaccessibleArray, + UnexpectedDataAccess, assert_array_equal, assert_equal, assert_identical, @@ -999,6 +1001,63 @@ def test_concat_str_dtype(self, dtype, dim) -> None: assert np.issubdtype(actual.x2.dtype, dtype) + def test_concat_avoids_index_auto_creation(self) -> None: + # TODO once passing indexes={} directly to Dataset constructor is allowed then no need to create coords first + coords = Coordinates( + {"x": ConcatenatableArray(np.array([1, 2, 3]))}, indexes={} + ) + datasets = [ + Dataset( + {"a": (["x", "y"], ConcatenatableArray(np.zeros((3, 3))))}, + coords=coords, + ) + for _ in range(2) + ] + # should not raise on concat + combined = concat(datasets, dim="x") + assert combined["a"].shape == (6, 3) + assert combined["a"].dims == ("x", "y") + + # nor have auto-created any indexes + assert combined.indexes == {} + + # should not raise on stack + combined = concat(datasets, dim="z") + assert combined["a"].shape == (2, 3, 3) + assert combined["a"].dims == ("z", "x", "y") + + # nor have auto-created any indexes + assert combined.indexes == {} + + def test_concat_avoids_index_auto_creation_new_1d_coord(self) -> None: + # create 0D coordinates (without indexes) + datasets = [ + Dataset( + coords={"x": 
ConcatenatableArray(np.array(10))}, + ) + for _ in range(2) + ] + + with pytest.raises(UnexpectedDataAccess): + concat(datasets, dim="x", create_index_for_new_dim=True) + + # should not raise on concat iff create_index_for_new_dim=False + combined = concat(datasets, dim="x", create_index_for_new_dim=False) + assert combined["x"].shape == (2,) + assert combined["x"].dims == ("x",) + + # nor have auto-created any indexes + assert combined.indexes == {} + + def test_concat_promote_shape_without_creating_new_index(self) -> None: + # different shapes but neither have indexes + ds1 = Dataset(coords={"x": 0}) + ds2 = Dataset(data_vars={"x": [1]}).drop_indexes("x") + actual = concat([ds1, ds2], dim="x", create_index_for_new_dim=False) + expected = Dataset(data_vars={"x": [0, 1]}).drop_indexes("x") + assert_identical(actual, expected, check_default_indexes=False) + assert actual.indexes == {} + class TestConcatDataArray: def test_concat(self) -> None: @@ -1072,6 +1131,35 @@ def test_concat_lazy(self) -> None: assert combined.shape == (2, 3, 3) assert combined.dims == ("z", "x", "y") + def test_concat_avoids_index_auto_creation(self) -> None: + # TODO once passing indexes={} directly to DataArray constructor is allowed then no need to create coords first + coords = Coordinates( + {"x": ConcatenatableArray(np.array([1, 2, 3]))}, indexes={} + ) + arrays = [ + DataArray( + ConcatenatableArray(np.zeros((3, 3))), + dims=["x", "y"], + coords=coords, + ) + for _ in range(2) + ] + # should not raise on concat + combined = concat(arrays, dim="x") + assert combined.shape == (6, 3) + assert combined.dims == ("x", "y") + + # nor have auto-created any indexes + assert combined.indexes == {} + + # should not raise on stack + combined = concat(arrays, dim="z") + assert combined.shape == (2, 3, 3) + assert combined.dims == ("z", "x", "y") + + # nor have auto-created any indexes + assert combined.indexes == {} + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) def test_concat_fill_value(self, fill_value) -> None: foo = DataArray([1, 2], coords=[("x", [1, 2])]) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 59b5b2b9b71..584776197e3 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -3431,16 +3431,22 @@ def test_expand_dims_kwargs_python36plus(self) -> None: ) assert_identical(other_way_expected, other_way) - @pytest.mark.parametrize("create_index_flag", [True, False]) - def test_expand_dims_create_index_data_variable(self, create_index_flag): + @pytest.mark.parametrize("create_index_for_new_dim_flag", [True, False]) + def test_expand_dims_create_index_data_variable( + self, create_index_for_new_dim_flag + ): # data variables should not gain an index ever ds = Dataset({"x": 0}) - if create_index_flag: + if create_index_for_new_dim_flag: with pytest.warns(UserWarning, match="No index created"): - expanded = ds.expand_dims("x", create_index=create_index_flag) + expanded = ds.expand_dims( + "x", create_index_for_new_dim=create_index_for_new_dim_flag + ) else: - expanded = ds.expand_dims("x", create_index=create_index_flag) + expanded = ds.expand_dims( + "x", create_index_for_new_dim=create_index_for_new_dim_flag + ) # TODO Can't just create the expected dataset directly using constructor because of GH issue 8959 expected = Dataset({"x": ("x", [0])}).drop_indexes("x").reset_coords("x") @@ -3449,13 +3455,13 @@ def test_expand_dims_create_index_data_variable(self, create_index_flag): assert expanded.indexes == {} def 
test_expand_dims_create_index_coordinate_variable(self): - # coordinate variables should gain an index only if create_index is True (the default) + # coordinate variables should gain an index only if create_index_for_new_dim is True (the default) ds = Dataset(coords={"x": 0}) expanded = ds.expand_dims("x") expected = Dataset({"x": ("x", [0])}) assert_identical(expanded, expected) - expanded_no_index = ds.expand_dims("x", create_index=False) + expanded_no_index = ds.expand_dims("x", create_index_for_new_dim=False) # TODO Can't just create the expected dataset directly using constructor because of GH issue 8959 expected = Dataset(coords={"x": ("x", [0])}).drop_indexes("x") @@ -3469,7 +3475,7 @@ def test_expand_dims_create_index_from_iterable(self): expected = Dataset({"x": ("x", [0, 1])}) assert_identical(expanded, expected) - expanded_no_index = ds.expand_dims(x=[0, 1], create_index=False) + expanded_no_index = ds.expand_dims(x=[0, 1], create_index_for_new_dim=False) # TODO Can't just create the expected dataset directly using constructor because of GH issue 8959 expected = Dataset(coords={"x": ("x", [0, 1])}).drop_indexes("x") @@ -3971,6 +3977,25 @@ def test_to_stacked_array_to_unstacked_dataset_different_dimension(self) -> None x = y.to_unstacked_dataset("features") assert_identical(D, x) + def test_to_stacked_array_preserves_dtype(self) -> None: + # regression test for bug found in https://github.com/pydata/xarray/pull/8872#issuecomment-2081218616 + ds = xr.Dataset( + data_vars={ + "a": (("x", "y"), [[0, 1, 2], [3, 4, 5]]), + "b": ("x", [6, 7]), + }, + coords={"y": ["u", "v", "w"]}, + ) + stacked = ds.to_stacked_array("z", sample_dims=["x"]) + + # coordinate created from variables names should be of string dtype + data = np.array(["a", "a", "a", "b"], dtype=" None: data = create_test_data(seed=0) expected = data.copy() From 30299439cae3f1102f60b75620abeced430fcd89 Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com> Date: Tue, 9 Apr 2024 17:43:28 -0700 Subject: [PATCH 14/17] temporary enable CI triggers on feature branch --- .github/workflows/ci-additional.yaml | 2 ++ .github/workflows/ci.yaml | 2 ++ 2 files changed, 4 insertions(+) diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index c0f978fb0d8..bc2eb8d2cac 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -3,6 +3,7 @@ on: push: branches: - "main" + - "backend-indexing" pull_request: branches: - "main" @@ -12,6 +13,7 @@ on: - '/*' # covers files such as `pyproject.toml` - 'properties/**' - 'xarray/**' + - "backend-indexing" workflow_dispatch: # allows you to trigger manually concurrency: diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index b9b15d867a7..ca9ef397962 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -3,6 +3,7 @@ on: push: branches: - "main" + - "backend-indexing" pull_request: branches: - "main" @@ -12,6 +13,7 @@ on: - '/*' # covers files such as `pyproject.toml` - 'properties/**' - 'xarray/**' + - "backend-indexing" workflow_dispatch: # allows you to trigger manually concurrency: From ddd4cdb59a5793b9a15a28b6b0475eed95739916 Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com> Date: Tue, 16 Apr 2024 18:53:22 -0700 Subject: [PATCH 15/17] add `.oindex` and `.vindex` to `BackendArray` (#8885) * add .oindex and .vindex to BackendArray * Add support for .oindex and .vindex in H5NetCDFArrayWrapper * Add support 
for .oindex and .vindex in NetCDF4ArrayWrapper, PydapArrayWrapper, NioArrayWrapper, and ZarrArrayWrapper * add deprecation warning * Fix deprecation warning message formatting * add tests * Update xarray/core/indexing.py Co-authored-by: Deepak Cherian * Update ZarrArrayWrapper class in xarray/backends/zarr.py Co-authored-by: Deepak Cherian --------- Co-authored-by: Deepak Cherian --- xarray/backends/common.py | 18 +++++++++++++ xarray/backends/h5netcdf_.py | 12 ++++++++- xarray/backends/netCDF4_.py | 12 ++++++++- xarray/backends/pydap_.py | 12 ++++++++- xarray/backends/scipy_.py | 33 ++++++++++++++++------- xarray/backends/zarr.py | 49 ++++++++++++++++++++++------------- xarray/core/indexing.py | 36 ++++++++++++++++++++----- xarray/tests/test_backends.py | 46 ++++++++++++++++++++++++++++++++ 8 files changed, 182 insertions(+), 36 deletions(-) diff --git a/xarray/backends/common.py b/xarray/backends/common.py index f318b4dd42f..f8f073f86a1 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -210,6 +210,24 @@ def get_duck_array(self, dtype: np.typing.DTypeLike = None): key = indexing.BasicIndexer((slice(None),) * self.ndim) return self[key] # type: ignore [index] + def _oindex_get(self, key: indexing.OuterIndexer): + raise NotImplementedError( + f"{self.__class__.__name__}._oindex_get method should be overridden" + ) + + def _vindex_get(self, key: indexing.VectorizedIndexer): + raise NotImplementedError( + f"{self.__class__.__name__}._vindex_get method should be overridden" + ) + + @property + def oindex(self) -> indexing.IndexCallable: + return indexing.IndexCallable(self._oindex_get) + + @property + def vindex(self) -> indexing.IndexCallable: + return indexing.IndexCallable(self._vindex_get) + class AbstractDataStore: __slots__ = () diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 71463193939..07973c3cbd9 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -48,7 +48,17 @@ def get_array(self, needs_lock=True): ds = self.datastore._acquire(needs_lock) return ds.variables[self.variable_name] - def __getitem__(self, key): + def _oindex_get(self, key: indexing.OuterIndexer): + return indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem + ) + + def _vindex_get(self, key: indexing.VectorizedIndexer): + return indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem + ) + + def __getitem__(self, key: indexing.BasicIndexer): return indexing.explicit_indexing_adapter( key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem ) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index ae86c4ce384..33d636b59cf 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -97,7 +97,17 @@ def get_array(self, needs_lock=True): variable.set_auto_chartostring(False) return variable - def __getitem__(self, key): + def _oindex_get(self, key: indexing.OuterIndexer): + return indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.OUTER, self._getitem + ) + + def _vindex_get(self, key: indexing.VectorizedIndexer): + return indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.OUTER, self._getitem + ) + + def __getitem__(self, key: indexing.BasicIndexer): return indexing.explicit_indexing_adapter( key, self.shape, indexing.IndexingSupport.OUTER, self._getitem ) diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py 
index 5a475a7c3be..2ce3a579b2d 100644 --- a/xarray/backends/pydap_.py +++ b/xarray/backends/pydap_.py @@ -43,7 +43,17 @@ def shape(self) -> tuple[int, ...]: def dtype(self): return self.array.dtype - def __getitem__(self, key): + def _oindex_get(self, key: indexing.OuterIndexer): + return indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.BASIC, self._getitem + ) + + def _vindex_get(self, key: indexing.VectorizedIndexer): + return indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.BASIC, self._getitem + ) + + def __getitem__(self, key: indexing.BasicIndexer): return indexing.explicit_indexing_adapter( key, self.shape, indexing.IndexingSupport.BASIC, self._getitem ) diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index f8c486e512c..cd2217c567f 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -67,15 +67,7 @@ def get_variable(self, needs_lock=True): ds = self.datastore._manager.acquire(needs_lock) return ds.variables[self.variable_name] - def _getitem(self, key): - with self.datastore.lock: - data = self.get_variable(needs_lock=False).data - return data[key] - - def __getitem__(self, key): - data = indexing.explicit_indexing_adapter( - key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem - ) + def _finalize_result(self, data): # Copy data if the source file is mmapped. This makes things consistent # with the netCDF4 library by ensuring we can safely read arrays even # after closing associated files. @@ -88,6 +80,29 @@ def __getitem__(self, key): return np.array(data, dtype=self.dtype, copy=copy) + def _getitem(self, key): + with self.datastore.lock: + data = self.get_variable(needs_lock=False).data + return data[key] + + def _vindex_get(self, key: indexing.VectorizedIndexer): + data = indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem + ) + return self._finalize_result(data) + + def _oindex_get(self, key: indexing.OuterIndexer): + data = indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem + ) + return self._finalize_result(data) + + def __getitem__(self, key): + data = indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem + ) + return self._finalize_result(data) + def __setitem__(self, key, value): with self.datastore.lock: data = self.get_variable(needs_lock=False) diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index e4a684e945d..4c2e8be0c16 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -85,25 +85,38 @@ def __init__(self, zarr_array): def get_array(self): return self._array - def _oindex(self, key): - return self._array.oindex[key] - - def _vindex(self, key): - return self._array.vindex[key] - - def _getitem(self, key): - return self._array[key] - - def __getitem__(self, key): - array = self._array - if isinstance(key, indexing.BasicIndexer): - method = self._getitem - elif isinstance(key, indexing.VectorizedIndexer): - method = self._vindex - elif isinstance(key, indexing.OuterIndexer): - method = self._oindex + def _oindex_get(self, key: indexing.OuterIndexer): + def raw_indexing_method(key): + return self._array.oindex[key] + + return indexing.explicit_indexing_adapter( + key, + self._array.shape, + indexing.IndexingSupport.VECTORIZED, + raw_indexing_method, + ) + + def _vindex_get(self, key: indexing.VectorizedIndexer): + + def raw_indexing_method(key): + return 
self._array.vindex[key] + + return indexing.explicit_indexing_adapter( + key, + self._array.shape, + indexing.IndexingSupport.VECTORIZED, + raw_indexing_method, + ) + + def __getitem__(self, key: indexing.BasicIndexer): + def raw_indexing_method(key): + return self._array[key] + return indexing.explicit_indexing_adapter( - key, array.shape, indexing.IndexingSupport.VECTORIZED, method + key, + self._array.shape, + indexing.IndexingSupport.VECTORIZED, + raw_indexing_method, ) # if self.ndim == 0: diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 0926da6fd80..7d6191883e1 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -3,6 +3,7 @@ import enum import functools import operator +import warnings from collections import Counter, defaultdict from collections.abc import Hashable, Iterable, Mapping from contextlib import suppress @@ -588,6 +589,14 @@ def __getitem__(self, key: Any): return result +BackendArray_fallback_warning_message = ( + "The array `{0}` does not support indexing using the .vindex and .oindex properties. " + "The __getitem__ method is being used instead. This fallback behavior will be " + "removed in a future version. Please ensure that the backend array `{1}` implements " + "support for the .vindex and .oindex properties to avoid potential issues." +) + + class LazilyIndexedArray(ExplicitlyIndexedNDArrayMixin): """Wrap an array to make basic and outer indexing lazy.""" @@ -639,11 +648,18 @@ def shape(self) -> _Shape: return tuple(shape) def get_duck_array(self): - if isinstance(self.array, ExplicitlyIndexedNDArrayMixin): + try: array = apply_indexer(self.array, self.key) - else: + except NotImplementedError as _: # If the array is not an ExplicitlyIndexedNDArrayMixin, - # it may wrap a BackendArray so use its __getitem__ + # it may wrap a BackendArray subclass that doesn't implement .oindex and .vindex. so use its __getitem__ + warnings.warn( + BackendArray_fallback_warning_message.format( + self.array.__class__.__name__, self.array.__class__.__name__ + ), + category=DeprecationWarning, + stacklevel=2, + ) array = self.array[self.key] # self.array[self.key] is now a numpy array when @@ -715,12 +731,20 @@ def shape(self) -> _Shape: return np.broadcast(*self.key.tuple).shape def get_duck_array(self): - if isinstance(self.array, ExplicitlyIndexedNDArrayMixin): + try: array = apply_indexer(self.array, self.key) - else: + except NotImplementedError as _: # If the array is not an ExplicitlyIndexedNDArrayMixin, - # it may wrap a BackendArray so use its __getitem__ + # it may wrap a BackendArray subclass that doesn't implement .oindex and .vindex. so use its __getitem__ + warnings.warn( + BackendArray_fallback_warning_message.format( + self.array.__class__.__name__, self.array.__class__.__name__ + ), + category=PendingDeprecationWarning, + stacklevel=2, + ) array = self.array[self.key] + # self.array[self.key] is now a numpy array when # self.array is a BackendArray subclass # and self.key is BasicIndexer((slice(None, None, None),)) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 0126b130e7c..d7471ecbaf9 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -5787,3 +5787,49 @@ def test_zarr_region_chunk_partial_offset(tmp_path): # This write is unsafe, and should raise an error, but does not. 
# with pytest.raises(ValueError): # da.isel(x=slice(5, 25)).chunk(x=(10, 10)).to_zarr(store, region="auto") + + +def test_backend_array_deprecation_warning(capsys): + class CustomBackendArray(xr.backends.common.BackendArray): + def __init__(self): + array = self.get_array() + self.shape = array.shape + self.dtype = array.dtype + + def get_array(self): + return np.arange(10) + + def __getitem__(self, key): + return xr.core.indexing.explicit_indexing_adapter( + key, self.shape, xr.core.indexing.IndexingSupport.BASIC, self._getitem + ) + + def _getitem(self, key): + array = self.get_array() + return array[key] + + cba = CustomBackendArray() + indexer = xr.core.indexing.VectorizedIndexer(key=(np.array([0]),)) + + la = xr.core.indexing.LazilyIndexedArray(cba, indexer) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + la.vindex[indexer].get_duck_array() + + captured = capsys.readouterr() + assert len(w) == 1 + assert issubclass(w[-1].category, PendingDeprecationWarning) + assert ( + "The array `CustomBackendArray` does not support indexing using the .vindex and .oindex properties." + in str(w[-1].message) + ) + assert "The __getitem__ method is being used instead." in str(w[-1].message) + assert "This fallback behavior will be removed in a future version." in str( + w[-1].message + ) + assert ( + "Please ensure that the backend array `CustomBackendArray` implements support for the .vindex and .oindex properties to avoid potential issues." + in str(w[-1].message) + ) + assert captured.out == "" From 96ac4b7f2879268fe03e012114a96f3e680e44c6 Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com> Date: Fri, 3 May 2024 08:27:22 -0700 Subject: [PATCH 16/17] Enable explicit use of key tuples (instead of *Indexer objects) in indexing adapters and explicitly indexed arrays (#8870) * pass key tuple to indexing adapters and explicitly indexed arrays * update indexing in StackedBytesArray * Update indexing in StackedBytesArray * Add _IndexerKey type to _typing.py * Update indexing in StackedBytesArray * use tuple indexing in test_backend_array_deprecation_warning * Add support for CompatIndexedTuple in explicit indexing adapter This commit updates the `explicit_indexing_adapter` function to accept both `ExplicitIndexer` and the new `CompatIndexedTuple`. The `CompatIndexedTuple` is designed to facilitate the transition towards using raw tuples by carrying additional metadata about the indexing type (basic, vectorized, or outer). 
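
To make that idea concrete, here is a minimal, self-contained sketch (an editor's illustration, not part of the patch; `TaggedKey` and `indexing_flavour` are hypothetical names) of how a tuple subclass can carry the indexing flavour while still behaving as a plain key tuple for adapters that only understand raw tuples:

```python
from typing import Literal

import numpy as np


class TaggedKey(tuple):
    """Hypothetical stand-in for CompatIndexedTuple: a raw key tuple plus metadata."""

    def __new__(cls, iterable, indexer_type: Literal["basic", "outer", "vectorized"]):
        obj = super().__new__(cls, iterable)
        # tuple subclasses without __slots__ accept instance attributes,
        # so the key stays a normal tuple while carrying extra context
        obj.indexer_type = indexer_type
        return obj


def indexing_flavour(key: tuple) -> str:
    # Legacy adapters can branch on the metadata; migrated adapters just see a tuple.
    return getattr(key, "indexer_type", "basic")


key = TaggedKey((slice(None), np.array([0, 2])), indexer_type="outer")
assert isinstance(key, tuple)                          # behaves like an ordinary key tuple
assert indexing_flavour(key) == "outer"                # but still announces its flavour
assert indexing_flavour((slice(None), 0)) == "basic"   # plain tuples default to basic
```

The remaining commit-message notes for this patch follow below.
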
* remove unused code * type hint fixes * fix docstrings * fix tests * fix docstrings * Apply suggestions from code review Co-authored-by: Deepak Cherian * update docstrings and pass tuples directly * Some test cleanup * update docstring * use `BasicIndexer` instead of `CompatIndexedTuple` * support explicit indexing with tuples * fix mypy errors * remove unused IndexerMaker * Update LazilyIndexedArray._updated_key to support explicit indexing with tuples --------- Co-authored-by: Deepak Cherian Co-authored-by: Deepak Cherian --- xarray/coding/strings.py | 20 +- xarray/coding/variables.py | 6 +- xarray/core/indexing.py | 280 ++++++++++++++++------------ xarray/namedarray/_typing.py | 1 + xarray/tests/__init__.py | 10 - xarray/tests/test_backends.py | 2 +- xarray/tests/test_coding_strings.py | 15 +- xarray/tests/test_dataset.py | 32 ++-- xarray/tests/test_indexing.py | 45 +++-- 9 files changed, 212 insertions(+), 199 deletions(-) diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py index db95286f6aa..6df92c256b9 100644 --- a/xarray/coding/strings.py +++ b/xarray/coding/strings.py @@ -17,6 +17,7 @@ from xarray.core import indexing from xarray.core.utils import module_available from xarray.core.variable import Variable +from xarray.namedarray._typing import _IndexerKey from xarray.namedarray.parallelcompat import get_chunked_array_type from xarray.namedarray.pycompat import is_chunked_array @@ -220,8 +221,7 @@ class StackedBytesArray(indexing.ExplicitlyIndexedNDArrayMixin): """Wrapper around array-like objects to create a new indexable object where values, when accessed, are automatically stacked along the last dimension. - >>> indexer = indexing.BasicIndexer((slice(None),)) - >>> StackedBytesArray(np.array(["a", "b", "c"], dtype="S1"))[indexer] + >>> StackedBytesArray(np.array(["a", "b", "c"], dtype="S1"))[(slice(None),)] array(b'abc', dtype='|S3') """ @@ -240,7 +240,7 @@ def __init__(self, array): @property def dtype(self): - return np.dtype("S" + str(self.array.shape[-1])) + return np.dtype(f"S{str(self.array.shape[-1])}") @property def shape(self) -> tuple[int, ...]: @@ -249,15 +249,17 @@ def shape(self) -> tuple[int, ...]: def __repr__(self): return f"{type(self).__name__}({self.array!r})" - def _vindex_get(self, key): + def _vindex_get(self, key: _IndexerKey): return _numpy_char_to_bytes(self.array.vindex[key]) - def _oindex_get(self, key): + def _oindex_get(self, key: _IndexerKey): return _numpy_char_to_bytes(self.array.oindex[key]) - def __getitem__(self, key): + def __getitem__(self, key: _IndexerKey): + from xarray.core.indexing import BasicIndexer + # require slicing the last dimension completely - key = type(key)(indexing.expanded_indexer(key.tuple, self.array.ndim)) - if key.tuple[-1] != slice(None): + indexer = indexing.expanded_indexer(key, self.array.ndim) + if indexer[-1] != slice(None): raise IndexError("too many indices") - return _numpy_char_to_bytes(self.array[key]) + return _numpy_char_to_bytes(self.array[BasicIndexer(indexer)]) diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index d31cb6e626a..98bbbbaeb2c 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -99,8 +99,7 @@ class NativeEndiannessArray(indexing.ExplicitlyIndexedNDArrayMixin): >>> NativeEndiannessArray(x).dtype dtype('int16') - >>> indexer = indexing.BasicIndexer((slice(None),)) - >>> NativeEndiannessArray(x)[indexer].dtype + >>> NativeEndiannessArray(x)[(slice(None),)].dtype dtype('int16') """ @@ -137,8 +136,7 @@ class 
BoolTypeArray(indexing.ExplicitlyIndexedNDArrayMixin): >>> BoolTypeArray(x).dtype dtype('bool') - >>> indexer = indexing.BasicIndexer((slice(None),)) - >>> BoolTypeArray(x)[indexer].dtype + >>> BoolTypeArray(x)[(slice(None),)].dtype dtype('bool') """ diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 7d6191883e1..2b8cd202e4e 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -10,7 +10,7 @@ from dataclasses import dataclass, field from datetime import timedelta from html import escape -from typing import TYPE_CHECKING, Any, Callable, overload +from typing import TYPE_CHECKING, Any, Callable, Literal, overload import numpy as np import pandas as pd @@ -36,7 +36,7 @@ from xarray.core.indexes import Index from xarray.core.variable import Variable - from xarray.namedarray._typing import _Shape, duckarray + from xarray.namedarray._typing import _IndexerKey, _Shape, duckarray from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint @@ -324,13 +324,13 @@ class ExplicitIndexer: __slots__ = ("_key",) - def __init__(self, key: tuple[Any, ...]): + def __init__(self, key: _IndexerKey): if type(self) is ExplicitIndexer: raise TypeError("cannot instantiate base ExplicitIndexer objects") self._key = tuple(key) @property - def tuple(self) -> tuple[Any, ...]: + def tuple(self) -> _IndexerKey: return self._key def __repr__(self) -> str: @@ -516,30 +516,29 @@ class ExplicitlyIndexedNDArrayMixin(NDArrayMixin, ExplicitlyIndexed): __slots__ = () def get_duck_array(self): - key = BasicIndexer((slice(None),) * self.ndim) - return self[key] + return self[(slice(None),) * self.ndim] def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: # This is necessary because we apply the indexing key in self.get_duck_array() # Note this is the base class for all lazy indexing classes return np.asarray(self.get_duck_array(), dtype=dtype) - def _oindex_get(self, indexer: OuterIndexer): + def _oindex_get(self, indexer: _IndexerKey): raise NotImplementedError( f"{self.__class__.__name__}._oindex_get method should be overridden" ) - def _vindex_get(self, indexer: VectorizedIndexer): + def _vindex_get(self, indexer: _IndexerKey): raise NotImplementedError( f"{self.__class__.__name__}._vindex_get method should be overridden" ) - def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + def _oindex_set(self, indexer: _IndexerKey, value: Any) -> None: raise NotImplementedError( f"{self.__class__.__name__}._oindex_set method should be overridden" ) - def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + def _vindex_set(self, indexer: _IndexerKey, value: Any) -> None: raise NotImplementedError( f"{self.__class__.__name__}._vindex_set method should be overridden" ) @@ -575,9 +574,9 @@ def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: def get_duck_array(self): return self.array.get_duck_array() - def __getitem__(self, key: Any): - key = expanded_indexer(key, self.ndim) - indexer = self.indexer_cls(key) + def __getitem__(self, key: _IndexerKey | slice): + _key = expanded_indexer(key, self.ndim) + indexer = self.indexer_cls(_key) result = apply_indexer(self.array, indexer) @@ -623,8 +622,13 @@ def __init__(self, array: Any, key: ExplicitIndexer | None = None): self.array = as_indexable(array) self.key = key - def _updated_key(self, new_key: ExplicitIndexer) -> BasicIndexer | OuterIndexer: - iter_new_key = iter(expanded_indexer(new_key.tuple, self.ndim)) + def _updated_key( + self, new_key: ExplicitIndexer | _IndexerKey + ) -> 
BasicIndexer | OuterIndexer: + _new_key_tuple = ( + new_key.tuple if isinstance(new_key, ExplicitIndexer) else new_key + ) + iter_new_key = iter(expanded_indexer(_new_key_tuple, self.ndim)) full_key = [] for size, k in zip(self.array.shape, self.key.tuple): if isinstance(k, integer_types): @@ -673,31 +677,29 @@ def get_duck_array(self): def transpose(self, order): return LazilyVectorizedIndexedArray(self.array, self.key).transpose(order) - def _oindex_get(self, indexer: OuterIndexer): + def _oindex_get(self, indexer: _IndexerKey): return type(self)(self.array, self._updated_key(indexer)) - def _vindex_get(self, indexer: VectorizedIndexer): + def _vindex_get(self, indexer: _IndexerKey): array = LazilyVectorizedIndexedArray(self.array, self.key) return array.vindex[indexer] - def __getitem__(self, indexer: ExplicitIndexer): - self._check_and_raise_if_non_basic_indexer(indexer) + def __getitem__(self, indexer: _IndexerKey): return type(self)(self.array, self._updated_key(indexer)) - def _vindex_set(self, key: VectorizedIndexer, value: Any) -> None: + def _vindex_set(self, key: _IndexerKey, value: Any) -> None: raise NotImplementedError( "Lazy item assignment with the vectorized indexer is not yet " "implemented. Load your data first by .load() or compute()." ) - def _oindex_set(self, key: OuterIndexer, value: Any) -> None: - full_key = self._updated_key(key) - self.array.oindex[full_key] = value + def _oindex_set(self, key: _IndexerKey, value: Any) -> None: + full_key = self._updated_key(OuterIndexer(key)) + self.array.oindex[full_key.tuple] = value - def __setitem__(self, key: BasicIndexer, value: Any) -> None: - self._check_and_raise_if_non_basic_indexer(key) - full_key = self._updated_key(key) - self.array[full_key] = value + def __setitem__(self, key: _IndexerKey, value: Any) -> None: + full_key = self._updated_key(BasicIndexer(key)) + self.array[full_key.tuple] = value def __repr__(self) -> str: return f"{type(self).__name__}(array={self.array!r}, key={self.key!r})" @@ -756,25 +758,25 @@ def get_duck_array(self): def _updated_key(self, new_key: ExplicitIndexer): return _combine_indexers(self.key, self.shape, new_key) - def _oindex_get(self, indexer: OuterIndexer): - return type(self)(self.array, self._updated_key(indexer)) + def _oindex_get(self, indexer: _IndexerKey): + return type(self)(self.array, self._updated_key(OuterIndexer(indexer))) - def _vindex_get(self, indexer: VectorizedIndexer): - return type(self)(self.array, self._updated_key(indexer)) + def _vindex_get(self, indexer: _IndexerKey): + return type(self)(self.array, self._updated_key(VectorizedIndexer(indexer))) + + def __getitem__(self, indexer: _IndexerKey): - def __getitem__(self, indexer: ExplicitIndexer): - self._check_and_raise_if_non_basic_indexer(indexer) # If the indexed array becomes a scalar, return LazilyIndexedArray - if all(isinstance(ind, integer_types) for ind in indexer.tuple): - key = BasicIndexer(tuple(k[indexer.tuple] for k in self.key.tuple)) + if all(isinstance(ind, integer_types) for ind in indexer): + key = BasicIndexer(tuple(k[indexer] for k in self.key.tuple)) return LazilyIndexedArray(self.array, key) - return type(self)(self.array, self._updated_key(indexer)) + return type(self)(self.array, self._updated_key(BasicIndexer(indexer))) def transpose(self, order): key = VectorizedIndexer(tuple(k.transpose(order) for k in self.key.tuple)) return type(self)(self.array, key) - def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: + def __setitem__(self, indexer: _IndexerKey, value: Any) -> 
None: raise NotImplementedError( "Lazy item assignment with the vectorized indexer is not yet " "implemented. Load your data first by .load() or compute()." @@ -807,29 +809,27 @@ def _ensure_copied(self): def get_duck_array(self): return self.array.get_duck_array() - def _oindex_get(self, indexer: OuterIndexer): + def _oindex_get(self, indexer: _IndexerKey): return type(self)(_wrap_numpy_scalars(self.array.oindex[indexer])) - def _vindex_get(self, indexer: VectorizedIndexer): + def _vindex_get(self, indexer: _IndexerKey): return type(self)(_wrap_numpy_scalars(self.array.vindex[indexer])) - def __getitem__(self, indexer: ExplicitIndexer): - self._check_and_raise_if_non_basic_indexer(indexer) + def __getitem__(self, indexer: _IndexerKey): return type(self)(_wrap_numpy_scalars(self.array[indexer])) def transpose(self, order): return self.array.transpose(order) - def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + def _vindex_set(self, indexer: _IndexerKey, value: Any) -> None: self._ensure_copied() self.array.vindex[indexer] = value - def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + def _oindex_set(self, indexer: _IndexerKey, value: Any) -> None: self._ensure_copied() self.array.oindex[indexer] = value - def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: - self._check_and_raise_if_non_basic_indexer(indexer) + def __setitem__(self, indexer: _IndexerKey, value: Any) -> None: self._ensure_copied() self.array[indexer] = value @@ -857,27 +857,25 @@ def get_duck_array(self): self._ensure_cached() return self.array.get_duck_array() - def _oindex_get(self, indexer: OuterIndexer): + def _oindex_get(self, indexer: _IndexerKey): return type(self)(_wrap_numpy_scalars(self.array.oindex[indexer])) - def _vindex_get(self, indexer: VectorizedIndexer): + def _vindex_get(self, indexer: _IndexerKey): return type(self)(_wrap_numpy_scalars(self.array.vindex[indexer])) - def __getitem__(self, indexer: ExplicitIndexer): - self._check_and_raise_if_non_basic_indexer(indexer) + def __getitem__(self, indexer: _IndexerKey): return type(self)(_wrap_numpy_scalars(self.array[indexer])) def transpose(self, order): return self.array.transpose(order) - def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + def _vindex_set(self, indexer: _IndexerKey, value: Any) -> None: self.array.vindex[indexer] = value - def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + def _oindex_set(self, indexer: _IndexerKey, value: Any) -> None: self.array.oindex[indexer] = value - def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: - self._check_and_raise_if_non_basic_indexer(indexer) + def __setitem__(self, indexer: _IndexerKey, value: Any) -> None: self.array[indexer] = value @@ -1040,29 +1038,63 @@ def explicit_indexing_adapter( return result +class CompatIndexedTuple(tuple): + """ + A tuple subclass used to transition existing backend implementations towards the use of raw tuples + for indexing by carrying additional metadata about the type of indexing being + performed ('basic', 'vectorized', or 'outer'). This class serves as a bridge, allowing + backend arrays that currently expect this metadata to function correctly while + maintaining the outward behavior of a regular tuple. + + This class is particularly useful during the phase where the backend implementations are + not yet capable of directly accepting raw tuples without additional context about + the indexing type. 
It ensures that these backends can still correctly interpret and + process indexing operations by providing them with the necessary contextual information. + """ + + def __new__(cls, iterable, indexer_type: Literal["basic", "vectorized", "outer"]): + obj = super().__new__(cls, iterable) + obj.indexer_type = indexer_type # type: ignore[attr-defined] + return obj + + def __repr__(self): + return f"CompatIndexedTuple({super().__repr__()}, indexer_type='{self.indexer_type}')" + + def apply_indexer(indexable, indexer: ExplicitIndexer): """Apply an indexer to an indexable object.""" if isinstance(indexer, VectorizedIndexer): - return indexable.vindex[indexer] + return indexable.vindex[CompatIndexedTuple(indexer.tuple, "vectorized")] elif isinstance(indexer, OuterIndexer): - return indexable.oindex[indexer] + return indexable.oindex[CompatIndexedTuple(indexer.tuple, "outer")] else: - return indexable[indexer] + return indexable[CompatIndexedTuple(indexer.tuple, "basic")] def set_with_indexer(indexable, indexer: ExplicitIndexer, value: Any) -> None: """Set values in an indexable object using an indexer.""" if isinstance(indexer, VectorizedIndexer): - indexable.vindex[indexer] = value + indexable.vindex[indexer.tuple] = value elif isinstance(indexer, OuterIndexer): - indexable.oindex[indexer] = value + indexable.oindex[indexer.tuple] = value else: - indexable[indexer] = value + indexable[indexer.tuple] = value def decompose_indexer( - indexer: ExplicitIndexer, shape: _Shape, indexing_support: IndexingSupport + indexer: ExplicitIndexer | CompatIndexedTuple, + shape: _Shape, + indexing_support: IndexingSupport, ) -> tuple[ExplicitIndexer, ExplicitIndexer]: + if isinstance(indexer, CompatIndexedTuple): + # recreate the indexer object from the tuple and the type of indexing. + # This is necessary to ensure that the backend array can correctly interpret the indexing operation. + if indexer.indexer_type == "vectorized": # type: ignore[attr-defined] + indexer = VectorizedIndexer(indexer) + elif indexer.indexer_type == "outer": # type: ignore[attr-defined] + indexer = OuterIndexer(indexer) + else: + indexer = BasicIndexer(indexer) if isinstance(indexer, VectorizedIndexer): return _decompose_vectorized_indexer(indexer, shape, indexing_support) if isinstance(indexer, (BasicIndexer, OuterIndexer)): @@ -1131,10 +1163,10 @@ def _decompose_vectorized_indexer( >>> array = np.arange(36).reshape(6, 6) >>> backend_indexer = OuterIndexer((np.array([0, 1, 3]), np.array([2, 3]))) >>> # load subslice of the array - ... array = NumpyIndexingAdapter(array).oindex[backend_indexer] + ... array = NumpyIndexingAdapter(array).oindex[backend_indexer.tuple] >>> np_indexer = VectorizedIndexer((np.array([0, 2, 1]), np.array([0, 1, 0]))) >>> # vectorized indexing for on-memory np.ndarray. - ... NumpyIndexingAdapter(array).vindex[np_indexer] + ... NumpyIndexingAdapter(array).vindex[np_indexer.tuple] array([ 2, 21, 8]) """ assert isinstance(indexer, VectorizedIndexer) @@ -1213,10 +1245,10 @@ def _decompose_outer_indexer( >>> array = np.arange(36).reshape(6, 6) >>> backend_indexer = BasicIndexer((slice(0, 3), slice(2, 4))) >>> # load subslice of the array - ... array = NumpyIndexingAdapter(array)[backend_indexer] + ... array = NumpyIndexingAdapter(array)[backend_indexer.tuple] >>> np_indexer = OuterIndexer((np.array([0, 2, 1]), np.array([0, 1, 0]))) >>> # outer indexing for on-memory np.ndarray. - ... NumpyIndexingAdapter(array).oindex[np_indexer] + ... 
NumpyIndexingAdapter(array).oindex[np_indexer.tuple] array([[ 2, 3, 2], [14, 15, 14], [ 8, 9, 8]]) @@ -1520,25 +1552,28 @@ def __init__(self, array): def transpose(self, order): return self.array.transpose(order) - def _oindex_get(self, indexer: OuterIndexer): - key = _outer_to_numpy_indexer(indexer, self.array.shape) + def _oindex_get(self, indexer: _IndexerKey): + key = _outer_to_numpy_indexer(OuterIndexer(indexer), self.array.shape) return self.array[key] - def _vindex_get(self, indexer: VectorizedIndexer): + def _vindex_get(self, indexer: _IndexerKey): array = NumpyVIndexAdapter(self.array) - return array[indexer.tuple] + return array[indexer] - def __getitem__(self, indexer: ExplicitIndexer): - self._check_and_raise_if_non_basic_indexer(indexer) + def __getitem__(self, indexer: _IndexerKey | ExplicitIndexer): array = self.array # We want 0d slices rather than scalars. This is achieved by # appending an ellipsis (see # https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes). - key = indexer.tuple + (Ellipsis,) + key = ( + indexer.tuple + if isinstance(indexer, ExplicitIndexer) + else indexer + (Ellipsis,) + ) return array[key] - def _safe_setitem(self, array, key: tuple[Any, ...], value: Any) -> None: + def _safe_setitem(self, array, key: _IndexerKey, value: Any) -> None: try: array[key] = value except ValueError as exc: @@ -1551,21 +1586,24 @@ def _safe_setitem(self, array, key: tuple[Any, ...], value: Any) -> None: else: raise exc - def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: - key = _outer_to_numpy_indexer(indexer, self.array.shape) + def _oindex_set(self, indexer: _IndexerKey, value: Any) -> None: + key = _outer_to_numpy_indexer(OuterIndexer(indexer), self.array.shape) self._safe_setitem(self.array, key, value) - def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + def _vindex_set(self, indexer: _IndexerKey, value: Any) -> None: array = NumpyVIndexAdapter(self.array) - self._safe_setitem(array, indexer.tuple, value) + self._safe_setitem(array, indexer, value) - def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: - self._check_and_raise_if_non_basic_indexer(indexer) + def __setitem__(self, indexer: _IndexerKey | ExplicitIndexer, value: Any) -> None: array = self.array # We want 0d slices rather than scalars. This is achieved by # appending an ellipsis (see # https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes). 
- key = indexer.tuple + (Ellipsis,) + key = ( + indexer.tuple + if isinstance(indexer, ExplicitIndexer) + else indexer + (Ellipsis,) + ) self._safe_setitem(array, key, value) @@ -1594,30 +1632,28 @@ def __init__(self, array): ) self.array = array - def _oindex_get(self, indexer: OuterIndexer): + def _oindex_get(self, indexer: _IndexerKey): # manual orthogonal indexing (implemented like DaskIndexingAdapter) - key = indexer.tuple + value = self.array - for axis, subkey in reversed(list(enumerate(key))): + for axis, subkey in reversed(list(enumerate(indexer))): value = value[(slice(None),) * axis + (subkey, Ellipsis)] return value - def _vindex_get(self, indexer: VectorizedIndexer): + def _vindex_get(self, indexer: _IndexerKey): raise TypeError("Vectorized indexing is not supported") - def __getitem__(self, indexer: ExplicitIndexer): - self._check_and_raise_if_non_basic_indexer(indexer) - return self.array[indexer.tuple] + def __getitem__(self, indexer: _IndexerKey): + return self.array[indexer] - def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: - self.array[indexer.tuple] = value + def _oindex_set(self, indexer: _IndexerKey, value: Any) -> None: + self.array[indexer] = value - def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + def _vindex_set(self, indexer: _IndexerKey, value: Any) -> None: raise TypeError("Vectorized indexing is not supported") - def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: - self._check_and_raise_if_non_basic_indexer(indexer) - self.array[indexer.tuple] = value + def __setitem__(self, indexer: _IndexerKey, value: Any) -> None: + self.array[indexer] = value def transpose(self, order): xp = self.array.__array_namespace__() @@ -1635,38 +1671,35 @@ def __init__(self, array): """ self.array = array - def _oindex_get(self, indexer: OuterIndexer): - key = indexer.tuple + def _oindex_get(self, indexer: _IndexerKey): try: - return self.array[key] + return self.array[indexer] except NotImplementedError: # manual orthogonal indexing value = self.array - for axis, subkey in reversed(list(enumerate(key))): + for axis, subkey in reversed(list(enumerate(indexer))): value = value[(slice(None),) * axis + (subkey,)] return value - def _vindex_get(self, indexer: VectorizedIndexer): - return self.array.vindex[indexer.tuple] + def _vindex_get(self, indexer: _IndexerKey): + return self.array.vindex[indexer] - def __getitem__(self, indexer: ExplicitIndexer): - self._check_and_raise_if_non_basic_indexer(indexer) - return self.array[indexer.tuple] + def __getitem__(self, indexer: _IndexerKey): + return self.array[indexer] - def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: - num_non_slices = sum(0 if isinstance(k, slice) else 1 for k in indexer.tuple) + def _oindex_set(self, indexer: _IndexerKey, value: Any) -> None: + num_non_slices = sum(0 if isinstance(k, slice) else 1 for k in indexer) if num_non_slices > 1: raise NotImplementedError( "xarray can't set arrays with multiple " "array indices to dask yet." 
) - self.array[indexer.tuple] = value + self.array[indexer] = value - def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: - self.array.vindex[indexer.tuple] = value + def _vindex_set(self, indexer: _IndexerKey, value: Any) -> None: + self.array.vindex[indexer] = value - def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: - self._check_and_raise_if_non_basic_indexer(indexer) - self.array[indexer.tuple] = value + def __setitem__(self, indexer: _IndexerKey, value: Any) -> None: + self.array[indexer] = value def transpose(self, order): return self.array.transpose(order) @@ -1728,13 +1761,14 @@ def _convert_scalar(self, item): # a NumPy array. return to_0d_array(item) - def _prepare_key(self, key: tuple[Any, ...]) -> tuple[Any, ...]: - if isinstance(key, tuple) and len(key) == 1: + def _prepare_key(self, key: ExplicitIndexer | _IndexerKey) -> _IndexerKey: + _key = key.tuple if isinstance(key, ExplicitIndexer) else key + if isinstance(_key, tuple) and len(_key) == 1: # unpack key so it can index a pandas.Index object (pandas.Index # objects don't like tuples) - (key,) = key + (_key,) = _key - return key + return _key def _handle_result( self, result: Any @@ -1751,7 +1785,7 @@ def _handle_result( return self._convert_scalar(result) def _oindex_get( - self, indexer: OuterIndexer + self, indexer: _IndexerKey ) -> ( PandasIndexingAdapter | NumpyIndexingAdapter @@ -1759,7 +1793,7 @@ def _oindex_get( | np.datetime64 | np.timedelta64 ): - key = self._prepare_key(indexer.tuple) + key = self._prepare_key(indexer) if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional indexable = NumpyIndexingAdapter(np.asarray(self)) @@ -1770,7 +1804,7 @@ def _oindex_get( return self._handle_result(result) def _vindex_get( - self, indexer: VectorizedIndexer + self, indexer: _IndexerKey ) -> ( PandasIndexingAdapter | NumpyIndexingAdapter @@ -1778,7 +1812,7 @@ def _vindex_get( | np.datetime64 | np.timedelta64 ): - key = self._prepare_key(indexer.tuple) + key = self._prepare_key(indexer) if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional indexable = NumpyIndexingAdapter(np.asarray(self)) @@ -1789,7 +1823,7 @@ def _vindex_get( return self._handle_result(result) def __getitem__( - self, indexer: ExplicitIndexer + self, indexer: _IndexerKey ) -> ( PandasIndexingAdapter | NumpyIndexingAdapter @@ -1797,7 +1831,7 @@ def __getitem__( | np.datetime64 | np.timedelta64 ): - key = self._prepare_key(indexer.tuple) + key = self._prepare_key(indexer) if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional indexable = NumpyIndexingAdapter(np.asarray(self)) @@ -1862,7 +1896,7 @@ def _convert_scalar(self, item): return super()._convert_scalar(item) def _oindex_get( - self, indexer: OuterIndexer + self, indexer: _IndexerKey ) -> ( PandasIndexingAdapter | NumpyIndexingAdapter @@ -1876,7 +1910,7 @@ def _oindex_get( return result def _vindex_get( - self, indexer: VectorizedIndexer + self, indexer: _IndexerKey ) -> ( PandasIndexingAdapter | NumpyIndexingAdapter @@ -1889,7 +1923,7 @@ def _vindex_get( result.level = self.level return result - def __getitem__(self, indexer: ExplicitIndexer): + def __getitem__(self, indexer: _IndexerKey): result = super().__getitem__(indexer) if isinstance(result, type(self)): result.level = self.level @@ -1911,7 +1945,7 @@ def _get_array_subset(self) -> np.ndarray: if self.size > threshold: pos = threshold // 2 indices = np.concatenate([np.arange(0, pos), np.arange(-pos, 0)]) - subset = self[OuterIndexer((indices,))] + subset = 
self[(indices,)] else: subset = self diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index b715973814f..243c2382472 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -95,6 +95,7 @@ def dtype(self) -> _DType_co: ... _IndexKey = Union[int, slice, "ellipsis"] _IndexKeys = tuple[Union[_IndexKey], ...] # tuple[Union[_IndexKey, None], ...] _IndexKeyLike = Union[_IndexKey, _IndexKeys] +_IndexerKey = tuple[Any, ...] _AttrsLike = Union[Mapping[Any, Any], None] diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 23fd590f4dc..64a879369f8 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -226,16 +226,6 @@ def __getitem__(self, key): return key -class IndexerMaker: - def __init__(self, indexer_cls): - self._indexer_cls = indexer_cls - - def __getitem__(self, key): - if not isinstance(key, tuple): - key = (key,) - return self._indexer_cls(key) - - def source_ndarray(array): """Given an ndarray, return the base object which holds its memory, or the object itself. diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index d7471ecbaf9..eb5e2ef6cf0 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -5815,7 +5815,7 @@ def _getitem(self, key): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - la.vindex[indexer].get_duck_array() + la.vindex[indexer.tuple].get_duck_array() captured = capsys.readouterr() assert len(w) == 1 diff --git a/xarray/tests/test_coding_strings.py b/xarray/tests/test_coding_strings.py index 51f63ea72dd..0feac5b15eb 100644 --- a/xarray/tests/test_coding_strings.py +++ b/xarray/tests/test_coding_strings.py @@ -7,9 +7,7 @@ from xarray import Variable from xarray.coding import strings -from xarray.core import indexing from xarray.tests import ( - IndexerMaker, assert_array_equal, assert_identical, requires_dask, @@ -150,10 +148,9 @@ def test_StackedBytesArray() -> None: assert len(actual) == len(expected) assert_array_equal(expected, actual) - B = IndexerMaker(indexing.BasicIndexer) - assert_array_equal(expected[:1], actual[B[:1]]) + assert_array_equal(expected[:1], actual[(slice(1),)]) with pytest.raises(IndexError): - actual[B[:, :2]] + actual[slice(None), slice(2)] def test_StackedBytesArray_scalar() -> None: @@ -168,10 +165,8 @@ def test_StackedBytesArray_scalar() -> None: with pytest.raises(TypeError): len(actual) np.testing.assert_array_equal(expected, actual) - - B = IndexerMaker(indexing.BasicIndexer) with pytest.raises(IndexError): - actual[B[:2]] + actual[(slice(2),)] def test_StackedBytesArray_vectorized_indexing() -> None: @@ -179,9 +174,7 @@ def test_StackedBytesArray_vectorized_indexing() -> None: stacked = strings.StackedBytesArray(array) expected = np.array([[b"abc", b"def"], [b"def", b"abc"]]) - V = IndexerMaker(indexing.VectorizedIndexer) - indexer = V[np.array([[0, 1], [1, 0]])] - actual = stacked.vindex[indexer] + actual = stacked.vindex[(np.array([[0, 1], [1, 0]]),)] assert_array_equal(actual, expected) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 584776197e3..ecca8c0c79e 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -244,7 +244,7 @@ def get_array(self): return self.array def __getitem__(self, key): - return self.array[key.tuple] + return self.array[(key if isinstance(key, tuple) else key.tuple)] class AccessibleAsDuckArrayDataStore(backends.InMemoryDataStore): @@ -5096,28 +5096,26 @@ def test_lazy_load(self) -> None: 
ds.isel(time=10) ds.isel(time=slice(10), dim1=[0]).isel(dim1=0, dim2=-1) - def test_lazy_load_duck_array(self) -> None: + @pytest.mark.parametrize("decode_cf", [True, False]) + def test_lazy_load_duck_array(self, decode_cf) -> None: store = AccessibleAsDuckArrayDataStore() create_test_data().dump_to_store(store) - for decode_cf in [True, False]: - ds = open_dataset(store, decode_cf=decode_cf) - with pytest.raises(UnexpectedDataAccess): - ds["var1"].values + ds = open_dataset(store, decode_cf=decode_cf) + with pytest.raises(UnexpectedDataAccess): + ds["var1"].values - # these should not raise UnexpectedDataAccess: - ds.var1.data - ds.isel(time=10) - ds.isel(time=slice(10), dim1=[0]).isel(dim1=0, dim2=-1) - repr(ds) + # these should not raise UnexpectedDataAccess: + ds.var1.data + ds.isel(time=10) + ds.isel(time=slice(10), dim1=[0]).isel(dim1=0, dim2=-1) + repr(ds) - # preserve the duck array type and don't cast to array - assert isinstance(ds["var1"].load().data, DuckArrayWrapper) - assert isinstance( - ds["var1"].isel(dim2=0, dim1=0).load().data, DuckArrayWrapper - ) + # preserve the duck array type and don't cast to array + assert isinstance(ds["var1"].load().data, DuckArrayWrapper) + assert isinstance(ds["var1"].isel(dim2=0, dim1=0).load().data, DuckArrayWrapper) - ds.close() + ds.close() def test_dropna(self) -> None: x = np.random.randn(4, 4) diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index f019d3c789c..b5da4a75439 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -12,7 +12,6 @@ from xarray.core.indexes import PandasIndex, PandasMultiIndex from xarray.core.types import T_Xarray from xarray.tests import ( - IndexerMaker, ReturnItem, assert_array_equal, assert_identical, @@ -20,8 +19,6 @@ requires_dask, ) -B = IndexerMaker(indexing.BasicIndexer) - class TestIndexCallable: def test_getitem(self): @@ -433,7 +430,7 @@ def test_lazily_indexed_array_vindex_setitem(self) -> None: NotImplementedError, match=r"Lazy item assignment with the vectorized indexer is not yet", ): - lazy.vindex[indexer] = 0 + lazy.vindex[indexer.tuple] = 0 @pytest.mark.parametrize( "indexer_class, key, value", @@ -449,10 +446,10 @@ def test_lazily_indexed_array_setitem(self, indexer_class, key, value) -> None: if indexer_class is indexing.BasicIndexer: indexer = indexer_class(key) - lazy[indexer] = value + lazy[indexer.tuple] = value elif indexer_class is indexing.OuterIndexer: indexer = indexer_class(key) - lazy.oindex[indexer] = value + lazy.oindex[indexer.tuple] = value assert_array_equal(original[key], value) @@ -461,16 +458,16 @@ class TestCopyOnWriteArray: def test_setitem(self) -> None: original = np.arange(10) wrapped = indexing.CopyOnWriteArray(original) - wrapped[B[:]] = 0 + wrapped[(slice(None),)] = 0 assert_array_equal(original, np.arange(10)) assert_array_equal(wrapped, np.zeros(10)) def test_sub_array(self) -> None: original = np.arange(10) wrapped = indexing.CopyOnWriteArray(original) - child = wrapped[B[:5]] + child = wrapped[(slice(5),)] assert isinstance(child, indexing.CopyOnWriteArray) - child[B[:]] = 0 + child[(slice(None),)] = 0 assert_array_equal(original, np.arange(10)) assert_array_equal(wrapped, np.arange(10)) assert_array_equal(child, np.zeros(5)) @@ -478,7 +475,7 @@ def test_sub_array(self) -> None: def test_index_scalar(self) -> None: # regression test for GH1374 x = indexing.CopyOnWriteArray(np.array(["foo", "bar"])) - assert np.array(x[B[0]][B[()]]) == "foo" + assert np.array(x[(0,)][()]) == "foo" class TestMemoryCachedArray: 
@@ -491,7 +488,7 @@ def test_wrapper(self) -> None: def test_sub_array(self) -> None: original = indexing.LazilyIndexedArray(np.arange(10)) wrapped = indexing.MemoryCachedArray(original) - child = wrapped[B[:5]] + child = wrapped[(slice(5),)] assert isinstance(child, indexing.MemoryCachedArray) assert_array_equal(child, np.arange(5)) assert isinstance(child.array, indexing.NumpyIndexingAdapter) @@ -500,13 +497,13 @@ def test_sub_array(self) -> None: def test_setitem(self) -> None: original = np.arange(10) wrapped = indexing.MemoryCachedArray(original) - wrapped[B[:]] = 0 + wrapped[(slice(None),)] = 0 assert_array_equal(original, np.zeros(10)) def test_index_scalar(self) -> None: # regression test for GH1374 x = indexing.MemoryCachedArray(np.array(["foo", "bar"])) - assert np.array(x[B[0]][B[()]]) == "foo" + assert np.array(x[(0,)][()]) == "foo" def test_base_explicit_indexer() -> None: @@ -615,7 +612,7 @@ def test_arrayize_vectorized_indexer(self) -> None: vindex, self.data.shape ) np.testing.assert_array_equal( - self.data.vindex[vindex], self.data.vindex[vindex_array] + self.data.vindex[vindex.tuple], self.data.vindex[vindex_array.tuple] ) actual = indexing._arrayize_vectorized_indexer( @@ -731,35 +728,35 @@ def test_decompose_indexers(shape, indexer_mode, indexing_support) -> None: # Dispatch to appropriate indexing method if indexer_mode.startswith("vectorized"): - expected = indexing_adapter.vindex[indexer] + expected = indexing_adapter.vindex[indexer.tuple] elif indexer_mode.startswith("outer"): - expected = indexing_adapter.oindex[indexer] + expected = indexing_adapter.oindex[indexer.tuple] else: - expected = indexing_adapter[indexer] # Basic indexing + expected = indexing_adapter[indexer.tuple] # Basic indexing if isinstance(backend_ind, indexing.VectorizedIndexer): - array = indexing_adapter.vindex[backend_ind] + array = indexing_adapter.vindex[backend_ind.tuple] elif isinstance(backend_ind, indexing.OuterIndexer): - array = indexing_adapter.oindex[backend_ind] + array = indexing_adapter.oindex[backend_ind.tuple] else: - array = indexing_adapter[backend_ind] + array = indexing_adapter[backend_ind.tuple] if len(np_ind.tuple) > 0: array_indexing_adapter = indexing.NumpyIndexingAdapter(array) if isinstance(np_ind, indexing.VectorizedIndexer): - array = array_indexing_adapter.vindex[np_ind] + array = array_indexing_adapter.vindex[np_ind.tuple] elif isinstance(np_ind, indexing.OuterIndexer): - array = array_indexing_adapter.oindex[np_ind] + array = array_indexing_adapter.oindex[np_ind.tuple] else: - array = array_indexing_adapter[np_ind] + array = array_indexing_adapter[np_ind.tuple] np.testing.assert_array_equal(expected, array) if not all(isinstance(k, indexing.integer_types) for k in np_ind.tuple): combined_ind = indexing._combine_indexers(backend_ind, shape, np_ind) assert isinstance(combined_ind, indexing.VectorizedIndexer) - array = indexing_adapter.vindex[combined_ind] + array = indexing_adapter.vindex[combined_ind.tuple] np.testing.assert_array_equal(expected, array) From 18c5c70c7c08414695f1f3abda86264f15fb88a5 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 6 May 2024 13:21:14 -0600 Subject: [PATCH 17/17] Trigger CI only if code files are modified. (#9006) * Trigger CI only if code files are modified. 
Fixes #8705 * Apply suggestions from code review Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> --------- Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> --- .github/workflows/ci-additional.yaml | 6 ++++++ .github/workflows/ci.yaml | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index bc2eb8d2cac..49a9272e4f0 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -14,6 +14,12 @@ on: - 'properties/**' - 'xarray/**' - "backend-indexing" + paths: + - 'ci/**' + - '.github/**' + - '/*' # covers files such as `pyproject.toml` + - 'properties/**' + - 'xarray/**' workflow_dispatch: # allows you to trigger manually concurrency: diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ca9ef397962..a4b165db06c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -14,6 +14,12 @@ on: - 'properties/**' - 'xarray/**' - "backend-indexing" + paths: + - 'ci/**' + - '.github/**' + - '/*' # covers files such as `pyproject.toml` + - 'properties/**' + - 'xarray/**' workflow_dispatch: # allows you to trigger manually concurrency:
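
Note: the indexing hunks in this series converge on one theme — the indexing adapters and their `.oindex`/`.vindex` properties now take plain tuple keys (`_IndexerKey`) instead of `ExplicitIndexer` wrapper objects, which is why the tests drop `IndexerMaker` and pass `indexer.tuple` through. The snippet below is an illustrative sketch only, not part of any patch above; it assumes the refactored `backend-indexing` branch is installed, and the adapter choice and sample keys are purely for demonstration.

```python
# Sketch of the tuple-key indexing style these patches move toward (assumed API).
import numpy as np
from xarray.core import indexing

adapter = indexing.NumpyIndexingAdapter(np.arange(12).reshape(3, 4))

# Basic indexing: a plain tuple key instead of indexing.BasicIndexer(...)
sub = adapter[(slice(None), 0)]

# Outer and vectorized indexing go through the .oindex / .vindex properties,
# again with plain tuple keys rather than OuterIndexer / VectorizedIndexer objects
outer = adapter.oindex[(np.array([0, 2]), slice(None))]
pointwise = adapter.vindex[(np.array([0, 2]), np.array([1, 3]))]
```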