From a56a4076f4d7d8eb981f6c38c3ed624a9df7b560 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 21 Aug 2024 09:07:33 -0600 Subject: [PATCH] refactor GroupBy internals (#9389) * More tests * Refactoring GroupBy 1. Simplify ResolvedGrouper by moving logic to EncodedGroups 2. Stack outside ResolvedGrouper in GroupBy.__init__ to prepare for multi-variable GroupBy * bail on pandas 2.0 --- xarray/core/coordinates.py | 2 +- xarray/core/groupby.py | 230 +++++++++++++++-------------------- xarray/groupers.py | 112 +++++++++++++---- xarray/tests/__init__.py | 1 + xarray/tests/test_groupby.py | 54 +++++++- 5 files changed, 238 insertions(+), 161 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 251edd1fc6f..3b852b962bf 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -352,7 +352,7 @@ def _construct_direct( return obj @classmethod - def from_pandas_multiindex(cls, midx: pd.MultiIndex, dim: str) -> Self: + def from_pandas_multiindex(cls, midx: pd.MultiIndex, dim: Hashable) -> Self: """Wrap a pandas multi-index as Xarray coordinates (dimension + levels). The returned coordinates can be directly assigned to a diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index faeb0c538c3..833466ffe9e 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -23,7 +23,7 @@ from xarray.core.formatting import format_array_flat from xarray.core.indexes import ( PandasIndex, - create_default_index_implicit, + PandasMultiIndex, filter_indexes_from_coords, ) from xarray.core.options import OPTIONS, _get_keep_attrs @@ -54,7 +54,7 @@ from xarray.core.dataset import Dataset from xarray.core.types import GroupIndex, GroupIndices, GroupKey from xarray.core.utils import Frozen - from xarray.groupers import Grouper + from xarray.groupers import EncodedGroups, Grouper def check_reduce_dims(reduce_dims, dimensions): @@ -273,16 +273,19 @@ class ResolvedGrouper(Generic[T_DataWithCoords]): obj: T_DataWithCoords # returned by factorize: - codes: DataArray = field(init=False, repr=False) - full_index: pd.Index = field(init=False, repr=False) - group_indices: GroupIndices = field(init=False, repr=False) - unique_coord: Variable | _DummyGroup = field(init=False, repr=False) + encoded: EncodedGroups = field(init=False, repr=False) - # _ensure_1d: - group1d: T_Group = field(init=False, repr=False) - stacked_obj: T_DataWithCoords = field(init=False, repr=False) - stacked_dim: Hashable | None = field(init=False, repr=False) - inserted_dims: list[Hashable] = field(init=False, repr=False) + @property + def full_index(self) -> pd.Index: + return self.encoded.full_index + + @property + def codes(self) -> DataArray: + return self.encoded.codes + + @property + def unique_coord(self) -> Variable | _DummyGroup: + return self.encoded.unique_coord def __post_init__(self) -> None: # This copy allows the BinGrouper.factorize() method @@ -294,20 +297,13 @@ def __post_init__(self) -> None: self.group = _resolve_group(self.obj, self.group) - ( - self.group1d, - self.stacked_obj, - self.stacked_dim, - self.inserted_dims, - ) = _ensure_1d(group=self.group, obj=self.obj) - - self.factorize() + self.encoded = self.grouper.factorize(self.group) @property def name(self) -> Hashable: """Name for the grouped coordinate after reduction.""" # the name has to come from unique_coord because we need `_bins` suffix for BinGrouper - (name,) = self.unique_coord.dims + (name,) = self.encoded.unique_coord.dims return name @property @@ -317,33 +313,7 @@ def size(self) -> int: def 
__len__(self) -> int: """Number of groups.""" - return len(self.full_index) - - @property - def dims(self): - return self.group1d.dims - - def factorize(self) -> None: - encoded = self.grouper.factorize(self.group1d) - - self.codes = encoded.codes - self.full_index = encoded.full_index - - if encoded.group_indices is not None: - self.group_indices = encoded.group_indices - else: - self.group_indices = tuple( - g - for g in _codes_to_group_indices(self.codes.data, len(self.full_index)) - if g - ) - if encoded.unique_coord is None: - unique_values = self.full_index[np.unique(encoded.codes)] - self.unique_coord = Variable( - dims=self.codes.name, data=unique_values, attrs=self.group.attrs - ) - else: - self.unique_coord = encoded.unique_coord + return len(self.encoded.full_index) def _validate_groupby_squeeze(squeeze: Literal[False]) -> None: @@ -428,31 +398,29 @@ class GroupBy(Generic[T_Xarray]): """ __slots__ = ( - "_full_index", - "_inserted_dims", - "_group", "_group_dim", - "_group_indices", - "_groups", "groupers", "_obj", "_restore_coord_dims", - "_stacked_dim", - "_unique_coord", + # cached properties + "_groups", "_dims", "_sizes", + "_len", # Save unstacked object for flox "_original_obj", - "_original_group", - "_bins", "_codes", + # stack nD vars + "group1d", + "_stacked_dim", + "_inserted_dims", + "encoded", ) _obj: T_Xarray groupers: tuple[ResolvedGrouper] _restore_coord_dims: bool _original_obj: T_Xarray - _original_group: T_Group _group_indices: GroupIndices _codes: DataArray _group_dim: Hashable @@ -460,6 +428,14 @@ class GroupBy(Generic[T_Xarray]): _groups: dict[GroupKey, GroupIndex] | None _dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None _sizes: Mapping[Hashable, int] | None + _len: int + + # _ensure_1d: + group1d: T_Group + _stacked_dim: Hashable | None + _inserted_dims: list[Hashable] + + encoded: EncodedGroups def __init__( self, @@ -479,26 +455,26 @@ def __init__( If True, also restore the dimension order of multi-dimensional coordinates. """ - self.groupers = groupers - self._original_obj = obj + self._restore_coord_dims = restore_coord_dims + self.groupers = groupers - (grouper,) = self.groupers - self._original_group = grouper.group + (grouper,) = groupers + self.encoded = grouper.encoded # specification for the groupby operation - self._obj = grouper.stacked_obj - self._restore_coord_dims = restore_coord_dims - - # These should generalize to multiple groupers - self._group_indices = grouper.group_indices - self._codes = self._maybe_unstack(grouper.codes) + # TODO: handle obj having variables that are not present on any of the groupers + # simple broadcasting fails for ExtensionArrays. 
+ (self.group1d, self._obj, self._stacked_dim, self._inserted_dims) = _ensure_1d( + group=self.encoded.codes, obj=obj + ) + (self._group_dim,) = self.group1d.dims - (self._group_dim,) = grouper.group1d.dims # cached attributes self._groups = None self._dims = None self._sizes = None + self._len = len(self.encoded.full_index) @property def sizes(self) -> Mapping[Hashable, int]: @@ -512,8 +488,7 @@ def sizes(self) -> Mapping[Hashable, int]: Dataset.sizes """ if self._sizes is None: - (grouper,) = self.groupers - index = self._group_indices[0] + index = self.encoded.group_indices[0] self._sizes = self._obj.isel({self._group_dim: index}).sizes return self._sizes @@ -546,24 +521,22 @@ def groups(self) -> dict[GroupKey, GroupIndex]: """ # provided to mimic pandas.groupby if self._groups is None: - (grouper,) = self.groupers - self._groups = dict(zip(grouper.unique_coord.values, self._group_indices)) + self._groups = dict( + zip(self.encoded.unique_coord.data, self.encoded.group_indices) + ) return self._groups def __getitem__(self, key: GroupKey) -> T_Xarray: """ Get DataArray or Dataset corresponding to a particular group label. """ - (grouper,) = self.groupers return self._obj.isel({self._group_dim: self.groups[key]}) def __len__(self) -> int: - (grouper,) = self.groupers - return grouper.size + return self._len def __iter__(self) -> Iterator[tuple[GroupKey, T_Xarray]]: - (grouper,) = self.groupers - return zip(grouper.unique_coord.data, self._iter_grouped()) + return zip(self.encoded.unique_coord.data, self._iter_grouped()) def __repr__(self) -> str: (grouper,) = self.groupers @@ -576,28 +549,20 @@ def __repr__(self) -> str: def _iter_grouped(self) -> Iterator[T_Xarray]: """Iterate over each element in this group""" - (grouper,) = self.groupers - for idx, indices in enumerate(self._group_indices): - yield self._obj.isel({self._group_dim: indices}) + for indices in self.encoded.group_indices: + if indices: + yield self._obj.isel({self._group_dim: indices}) def _infer_concat_args(self, applied_example): - from xarray.groupers import BinGrouper - (grouper,) = self.groupers if self._group_dim in applied_example.dims: - coord = grouper.group1d - positions = self._group_indices + coord = self.group1d + positions = self.encoded.group_indices else: - coord = grouper.unique_coord + coord = self.encoded.unique_coord positions = None (dim,) = coord.dims - if isinstance(grouper.group, _DummyGroup) and not isinstance( - grouper.grouper, BinGrouper - ): - # When binning we actually do set the index - coord = None - coord = getattr(coord, "variable", coord) - return coord, dim, positions + return dim, positions def _binary_op(self, other, f, reflexive=False): from xarray.core.dataarray import DataArray @@ -609,7 +574,7 @@ def _binary_op(self, other, f, reflexive=False): obj = self._original_obj name = grouper.name group = grouper.group - codes = self._codes + codes = self.encoded.codes dims = group.dims if isinstance(group, _DummyGroup): @@ -710,8 +675,8 @@ def _maybe_unstack(self, obj): """This gets called if we are applying on an array with a multidimensional group.""" (grouper,) = self.groupers - stacked_dim = grouper.stacked_dim - inserted_dims = grouper.inserted_dims + stacked_dim = self._stacked_dim + inserted_dims = self._inserted_dims if stacked_dim is not None and stacked_dim in obj.dims: obj = obj.unstack(stacked_dim) for dim in inserted_dims: @@ -797,7 +762,7 @@ def _flox_reduce( output_index = grouper.full_index result = xarray_reduce( obj.drop_vars(non_numeric.keys()), - self._codes, + 
self.encoded.codes, dim=parsed_dim, # pass RangeIndex as a hint to flox that `by` is already factorized expected_groups=(pd.RangeIndex(len(output_index)),), @@ -808,15 +773,27 @@ def _flox_reduce( # we did end up reducing over dimension(s) that are # in the grouped variable - group_dims = grouper.group.dims - if set(group_dims).issubset(set(parsed_dim)): - result = result.assign_coords( - Coordinates( - coords={name: (name, np.array(output_index))}, - indexes={name: PandasIndex(output_index, dim=name)}, + group_dims = set(grouper.group.dims) + new_coords = {} + if group_dims.issubset(set(parsed_dim)): + new_indexes = {} + for grouper in self.groupers: + output_index = grouper.full_index + if isinstance(output_index, pd.RangeIndex): + continue + name = grouper.name + new_coords[name] = IndexVariable( + dims=name, data=np.array(output_index), attrs=grouper.codes.attrs ) - ) - result = result.drop_vars(unindexed_dims) + index_cls = ( + PandasIndex + if not isinstance(output_index, pd.MultiIndex) + else PandasMultiIndex + ) + new_indexes[name] = index_cls(output_index, dim=name) + result = result.assign_coords( + Coordinates(new_coords, new_indexes) + ).drop_vars(unindexed_dims) # broadcast and restore non-numeric data variables (backcompat) for name, var in non_numeric.items(): @@ -986,7 +963,7 @@ def quantile( """ if dim is None: (grouper,) = self.groupers - dim = grouper.group1d.dims + dim = self.group1d.dims # Dataset.quantile does this, do it for flox to ensure same output. q = np.asarray(q, dtype=np.float64) @@ -1038,7 +1015,7 @@ def _first_or_last(self, op, skipna, keep_attrs): if all( isinstance(maybe_slice, slice) and (maybe_slice.stop == maybe_slice.start + 1) - for maybe_slice in self._group_indices + for maybe_slice in self.encoded.group_indices ): # NB. 
this is currently only used for reductions along an existing # dimension @@ -1087,8 +1064,7 @@ class DataArrayGroupByBase(GroupBy["DataArray"], DataArrayGroupbyArithmetic): @property def dims(self) -> tuple[Hashable, ...]: if self._dims is None: - (grouper,) = self.groupers - index = self._group_indices[0] + index = self.encoded.group_indices[0] self._dims = self._obj.isel({self._group_dim: index}).dims return self._dims @@ -1097,8 +1073,7 @@ def _iter_grouped_shortcut(self): metadata """ var = self._obj.variable - (grouper,) = self.groupers - for idx, indices in enumerate(self._group_indices): + for idx, indices in enumerate(self.encoded.group_indices): yield var[{self._group_dim: indices}] def _concat_shortcut(self, applied, dim, positions=None): @@ -1109,14 +1084,12 @@ def _concat_shortcut(self, applied, dim, positions=None): # TODO: benbovy - explicit indexes: this fast implementation doesn't # create an explicit index for the stacked dim coordinate stacked = Variable.concat(applied, dim, shortcut=True) - - (grouper,) = self.groupers - reordered = _maybe_reorder(stacked, dim, positions, N=grouper.group.size) + reordered = _maybe_reorder(stacked, dim, positions, N=self.group1d.size) return self._obj._replace_maybe_drop_dims(reordered) def _restore_dim_order(self, stacked: DataArray) -> DataArray: (grouper,) = self.groupers - group = grouper.group1d + group = self.group1d def lookup_order(dimension): if dimension == grouper.name: @@ -1200,24 +1173,21 @@ def apply(self, func, shortcut=False, args=(), **kwargs): def _combine(self, applied, shortcut=False): """Recombine the applied objects like the original.""" applied_example, applied = peek_at(applied) - coord, dim, positions = self._infer_concat_args(applied_example) + dim, positions = self._infer_concat_args(applied_example) if shortcut: combined = self._concat_shortcut(applied, dim, positions) else: combined = concat(applied, dim) - (grouper,) = self.groupers - combined = _maybe_reorder(combined, dim, positions, N=grouper.group.size) + combined = _maybe_reorder(combined, dim, positions, N=self.group1d.size) if isinstance(combined, type(self._obj)): # only restore dimension order for arrays combined = self._restore_dim_order(combined) # assign coord and index when the applied function does not return that coord - if coord is not None and dim not in applied_example.dims: - index, index_vars = create_default_index_implicit(coord) - indexes = {k: index for k in index_vars} - combined = combined._overwrite_indexes(indexes, index_vars) - combined = self._maybe_restore_empty_groups(combined) + if dim not in applied_example.dims: + combined = combined.assign_coords(self.encoded.coords) combined = self._maybe_unstack(combined) + combined = self._maybe_restore_empty_groups(combined) return combined def reduce( @@ -1297,8 +1267,7 @@ class DatasetGroupByBase(GroupBy["Dataset"], DatasetGroupbyArithmetic): @property def dims(self) -> Frozen[Hashable, int]: if self._dims is None: - (grouper,) = self.groupers - index = self._group_indices[0] + index = self.encoded.group_indices[0] self._dims = self._obj.isel({self._group_dim: index}).dims return FrozenMappingWarningOnValuesAccess(self._dims) @@ -1362,17 +1331,14 @@ def apply(self, func, args=(), shortcut=None, **kwargs): def _combine(self, applied): """Recombine the applied objects like the original.""" applied_example, applied = peek_at(applied) - coord, dim, positions = self._infer_concat_args(applied_example) + dim, positions = self._infer_concat_args(applied_example) combined = concat(applied, 
dim) - (grouper,) = self.groupers - combined = _maybe_reorder(combined, dim, positions, N=grouper.group.size) + combined = _maybe_reorder(combined, dim, positions, N=self.group1d.size) # assign coord when the applied function does not return that coord - if coord is not None and dim not in applied_example.dims: - index, index_vars = create_default_index_implicit(coord) - indexes = {k: index for k in index_vars} - combined = combined._overwrite_indexes(indexes, index_vars) - combined = self._maybe_restore_empty_groups(combined) + if dim not in applied_example.dims: + combined = combined.assign_coords(self.encoded.coords) combined = self._maybe_unstack(combined) + combined = self._maybe_restore_empty_groups(combined) return combined def reduce( diff --git a/xarray/groupers.py b/xarray/groupers.py index 98409dfe542..f70cad655e8 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -9,13 +9,14 @@ import datetime from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, cast import numpy as np import pandas as pd from xarray.coding.cftime_offsets import _new_to_legacy_freq from xarray.core import duck_array_ops +from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.groupby import T_Group, _DummyGroup from xarray.core.indexes import safe_cast_to_index @@ -35,7 +36,18 @@ RESAMPLE_DIM = "__resample_dim__" -@dataclass +def _coordinates_from_variable(variable: Variable) -> Coordinates: + from xarray.core.indexes import create_default_index_implicit + + (name,) = variable.dims + new_index, index_vars = create_default_index_implicit(variable) + indexes = {k: new_index for k in index_vars} + new_vars = new_index.create_variables() + new_vars[name].attrs = variable.attrs + return Coordinates(new_vars, indexes) + + +@dataclass(init=False) class EncodedGroups: """ Dataclass for storing intermediate values for GroupBy operation. 
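[Editor's note, between hunks: it is worth spelling out what the rewritten `EncodedGroups` buys. After this patch, `Grouper.factorize()` returns a fully populated bundle — `codes`, `full_index`, `group_indices`, `unique_coord`, and ready-made output `coords` built via `_coordinates_from_variable` — so `GroupBy` no longer reconstructs any of these after the fact. A minimal sketch of the post-patch call pattern, assuming `xarray.groupers.UniqueGrouper` as added by this series; the array and label values are illustrative, and `factorize()` is an internal entry point invoked here only for inspection:]

```python
import numpy as np
import xarray as xr
from xarray.groupers import UniqueGrouper

# A small labelled array: six points on dim "x", three distinct labels.
da = xr.DataArray(
    np.arange(6),
    dims="x",
    coords={"labels": ("x", ["a", "b", "a", "c", "b", "a"])},
)

# After this refactor, one factorize() call hands back everything
# GroupBy needs, instead of leaving half the fields to be filled in
# later by ResolvedGrouper.factorize().
encoded = UniqueGrouper().factorize(da["labels"])
print(encoded.codes.values)   # integer codes per element, e.g. [0 1 0 2 1 0]
print(encoded.full_index)     # pandas Index of the group labels 'a', 'b', 'c'
print(encoded.group_indices)  # one positional indexer per group
print(encoded.coords)         # output Coordinates, prebuilt from unique_coord
```

[Because `coords` is computed once here, `_combine` in the `groupby.py` hunks above can simply call `assign_coords(self.encoded.coords)` instead of going through `create_default_index_implicit`, which is exactly the simplification this patch makes.]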
@@ -57,18 +69,49 @@ class EncodedGroups: codes: DataArray full_index: pd.Index - group_indices: GroupIndices | None = field(default=None) - unique_coord: Variable | _DummyGroup | None = field(default=None) - - def __post_init__(self): - assert isinstance(self.codes, DataArray) - if self.codes.name is None: + group_indices: GroupIndices + unique_coord: Variable | _DummyGroup + coords: Coordinates + + def __init__( + self, + codes: DataArray, + full_index: pd.Index, + group_indices: GroupIndices | None = None, + unique_coord: Variable | _DummyGroup | None = None, + coords: Coordinates | None = None, + ): + from xarray.core.groupby import _codes_to_group_indices + + assert isinstance(codes, DataArray) + if codes.name is None: raise ValueError("Please set a name on the array you are grouping by.") - assert isinstance(self.full_index, pd.Index) - assert ( - isinstance(self.unique_coord, Variable | _DummyGroup) - or self.unique_coord is None - ) + self.codes = codes + assert isinstance(full_index, pd.Index) + self.full_index = full_index + + if group_indices is None: + self.group_indices = tuple( + g + for g in _codes_to_group_indices(codes.data.ravel(), len(full_index)) + if g + ) + else: + self.group_indices = group_indices + + if unique_coord is None: + unique_values = full_index[np.unique(codes)] + self.unique_coord = Variable( + dims=codes.name, data=unique_values, attrs=codes.attrs + ) + else: + self.unique_coord = unique_coord + + if coords is None: + assert not isinstance(self.unique_coord, _DummyGroup) + self.coords = _coordinates_from_variable(self.unique_coord) + else: + self.coords = coords class Grouper(ABC): @@ -111,11 +154,14 @@ class UniqueGrouper(Grouper): def group_as_index(self) -> pd.Index: """Caches the group DataArray as a pandas Index.""" if self._group_as_index is None: - self._group_as_index = self.group.to_index() + if self.group.ndim == 1: + self._group_as_index = self.group.to_index() + else: + self._group_as_index = pd.Index(np.array(self.group).ravel()) return self._group_as_index - def factorize(self, group1d: T_Group) -> EncodedGroups: - self.group = group1d + def factorize(self, group: T_Group) -> EncodedGroups: + self.group = group index = self.group_as_index is_unique_and_monotonic = isinstance(self.group, _DummyGroup) or ( @@ -138,14 +184,17 @@ def _factorize_unique(self) -> EncodedGroups: raise ValueError( "Failed to group data. Are you grouping by a variable that is all NaN?" 
) - codes = self.group.copy(data=codes_) + codes = self.group.copy(data=codes_.reshape(self.group.shape)) unique_coord = Variable( dims=codes.name, data=unique_values, attrs=self.group.attrs ) full_index = pd.Index(unique_values) return EncodedGroups( - codes=codes, full_index=full_index, unique_coord=unique_coord + codes=codes, + full_index=full_index, + unique_coord=unique_coord, + coords=_coordinates_from_variable(unique_coord), ) def _factorize_dummy(self) -> EncodedGroups: @@ -156,20 +205,31 @@ def _factorize_dummy(self) -> EncodedGroups: group_indices: GroupIndices = tuple(slice(i, i + 1) for i in range(size)) size_range = np.arange(size) full_index: pd.Index + unique_coord: _DummyGroup | Variable if isinstance(self.group, _DummyGroup): codes = self.group.to_dataarray().copy(data=size_range) unique_coord = self.group full_index = pd.RangeIndex(self.group.size) + coords = Coordinates() else: codes = self.group.copy(data=size_range) unique_coord = self.group.variable.to_base_variable() - full_index = pd.Index(unique_coord.data) + full_index = self.group_as_index + if isinstance(full_index, pd.MultiIndex): + coords = Coordinates.from_pandas_multiindex( + full_index, dim=self.group.name + ) + else: + if TYPE_CHECKING: + assert isinstance(unique_coord, Variable) + coords = _coordinates_from_variable(unique_coord) return EncodedGroups( codes=codes, group_indices=group_indices, full_index=full_index, unique_coord=unique_coord, + coords=coords, ) @@ -231,7 +291,7 @@ def factorize(self, group: T_Group) -> EncodedGroups: data = np.asarray(group.data) # Cast _DummyGroup data to array binned, self.bins = pd.cut( # type: ignore [call-overload] - data, + data.ravel(), bins=self.bins, right=self.right, labels=self.labels, @@ -254,13 +314,18 @@ def factorize(self, group: T_Group) -> EncodedGroups: unique_values = full_index[uniques[uniques != -1]] codes = DataArray( - binned_codes, getattr(group, "coords", None), name=new_dim_name + binned_codes.reshape(group.shape), + getattr(group, "coords", None), + name=new_dim_name, ) unique_coord = Variable( dims=new_dim_name, data=unique_values, attrs=group.attrs ) return EncodedGroups( - codes=codes, full_index=full_index, unique_coord=unique_coord + codes=codes, + full_index=full_index, + unique_coord=unique_coord, + coords=_coordinates_from_variable(unique_coord), ) @@ -373,13 +438,14 @@ def factorize(self, group: T_Group) -> EncodedGroups: unique_coord = Variable( dims=group.name, data=first_items.index, attrs=group.attrs ) - codes = group.copy(data=codes_) + codes = group.copy(data=codes_.reshape(group.shape)) return EncodedGroups( codes=codes, group_indices=group_indices, full_index=full_index, unique_coord=unique_coord, + coords=_coordinates_from_variable(unique_coord), ) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 0caab6e8247..b4d3871c229 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -134,6 +134,7 @@ def _importorskip( has_pint, requires_pint = _importorskip("pint") has_numexpr, requires_numexpr = _importorskip("numexpr") has_flox, requires_flox = _importorskip("flox") +has_pandas_ge_2_1, __ = _importorskip("pandas", "2.1") has_pandas_ge_2_2, __ = _importorskip("pandas", "2.2") has_pandas_3, requires_pandas_3 = _importorskip("pandas", "3.0.0.dev0") diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 6c9254966d9..7dbb0d5e59c 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -22,6 +22,7 @@ create_test_data, has_cftime, has_flox, + 
has_pandas_ge_2_1, requires_cftime, requires_dask, requires_flox, @@ -118,6 +119,13 @@ def test_multi_index_groupby_sum() -> None: actual = ds.stack(space=["x", "y"]).groupby("space").sum("z").unstack("space") assert_equal(expected, actual) + if not has_pandas_ge_2_1: + # the next line triggers a mysterious multiindex error on pandas 2.0 + return + + actual = ds.stack(space=["x", "y"]).groupby("space").sum(...).unstack("space") + assert_equal(expected, actual) + def test_groupby_da_datetime() -> None: # test groupby with a DataArray of dtype datetime for GH1132 @@ -806,6 +814,7 @@ def test_groupby_dataset_errors() -> None: data.groupby(data.coords["dim1"].to_index()) # type: ignore[arg-type] +@pytest.mark.parametrize("use_flox", [True, False]) @pytest.mark.parametrize( "by_func", [ @@ -813,7 +822,10 @@ def test_groupby_dataset_errors() -> None: pytest.param(lambda x: {x: UniqueGrouper()}, id="group-by-unique-grouper"), ], ) -def test_groupby_dataset_reduce_ellipsis(by_func) -> None: +@pytest.mark.parametrize("letters_as_coord", [True, False]) +def test_groupby_dataset_reduce_ellipsis( + by_func, use_flox: bool, letters_as_coord: bool +) -> None: data = Dataset( { "xy": (["x", "y"], np.random.randn(3, 4)), @@ -823,13 +835,18 @@ def test_groupby_dataset_reduce_ellipsis(by_func) -> None: } ) + if letters_as_coord: + data = data.set_coords("letters") + expected = data.mean("y") expected["yonly"] = expected["yonly"].variable.set_dims({"x": 3}) gb = data.groupby(by_func("x")) - actual = gb.mean(...) + with xr.set_options(use_flox=use_flox): + actual = gb.mean(...) assert_allclose(expected, actual) - actual = gb.mean("y") + with xr.set_options(use_flox=use_flox): + actual = gb.mean("y") assert_allclose(expected, actual) letters = data["letters"] @@ -841,7 +858,8 @@ def test_groupby_dataset_reduce_ellipsis(by_func) -> None: } ) gb = data.groupby(by_func("letters")) - actual = gb.mean(...) + with xr.set_options(use_flox=use_flox): + actual = gb.mean(...) assert_allclose(expected, actual) @@ -1729,7 +1747,7 @@ def test_groupby_fastpath_for_monotonic(self, use_flox: bool) -> None: rev = array_rev.groupby("idx", squeeze=False) for gb in [fwd, rev]: - assert all([isinstance(elem, slice) for elem in gb._group_indices]) + assert all([isinstance(elem, slice) for elem in gb.encoded.group_indices]) with xr.set_options(use_flox=use_flox): assert_identical(fwd.sum(), array) @@ -2561,3 +2579,29 @@ def factorize(self, group) -> EncodedGroups: obj.groupby("time.year", time=YearGrouper()) with pytest.raises(ValueError): obj.groupby() + + +@pytest.mark.parametrize("use_flox", [True, False]) +def test_weather_data_resample(use_flox): + # from the docs + times = pd.date_range("2000-01-01", "2001-12-31", name="time") + annual_cycle = np.sin(2 * np.pi * (times.dayofyear.values / 365.25 - 0.28)) + + base = 10 + 15 * annual_cycle.reshape(-1, 1) + tmin_values = base + 3 * np.random.randn(annual_cycle.size, 3) + tmax_values = base + 10 + 3 * np.random.randn(annual_cycle.size, 3) + + ds = xr.Dataset( + { + "tmin": (("time", "location"), tmin_values), + "tmax": (("time", "location"), tmax_values), + }, + { + "time": ("time", times, {"time_key": "time_values"}), + "location": ("location", ["IA", "IN", "IL"], {"loc_key": "loc_value"}), + }, + ) + + with xr.set_options(use_flox=use_flox): + actual = ds.resample(time="1MS").mean() + assert "location" in actual._indexes
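
[Editor's note: the two test additions above pin down end-to-end behaviour the refactor must not change — reducing with `...` over a stacked multi-index group (guarded to pandas >= 2.1) and a resample that keeps the unrelated `location` index alive. A condensed sketch of the multi-index round-trip; the `Dataset` construction is reconstructed from `test_multi_index_groupby_sum`, whose setup lies outside the hunk shown, and the data values are arbitrary:]

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"foo": (("x", "y", "z"), np.ones((3, 4, 2)))},
    coords={"x": ["a", "b", "c"], "y": [1, 2, 3, 4]},
)
expected = ds.sum("z")

# Stack (x, y) into a multi-index, group over it, and reduce every
# remaining dimension with `...`; unstacking should then recover the
# plain x/y grid, matching a direct sum over "z".
stacked = ds.stack(space=["x", "y"])
actual = stacked.groupby("space").sum(...).unstack("space")
xr.testing.assert_equal(expected, actual)
```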