Skip to content

Commit

Permalink
(feat): Support for pandas ExtensionArray (#8723)
Browse files Browse the repository at this point in the history
* (feat): first pass supporting extension arrays

* (feat): categorical tests + functionality

* (feat): use multiple dispatch for unimplemented ops

* (feat): implement (not really) broadcasting

* (chore): add more `groupby` tests

* (fix): fix more groupby incompatibility

* (bug): fix unused categories

* (chore): refactor dispatched methods + tests

* (fix): shared type should check for extension arrays first and then fall back to numpy

* (refactor): tests moved

* (chore): more higher level tests

* (feat): to/from dataframe

* (chore): check for plum import

* (fix): `__setitem__`/`__getitem__`

* (chore): disallow stacking

* (fix): `pyproject.toml`

* (fix): `as_shared_type` fix

* (chore): add variable tests

* (fix): dask + categoricals

* (chore): notes/docs

* (chore): remove old testing file

* (chore): remove ocmmented out code

* (fix): import plum dispatch

* (refactor): use `is_extension_array_dtype` as much as possible

* (refactor): `extension_array`->`array` + move to `indexing`

* (refactor): change order of classes

* (chore): add small pyarrow test

* (fix): fix some mypy issues

* (fix): don't register unregisterable method

* (fix): appease mypy

* (fix): more sensible default implemetations allow most use without `plum`

* (fix): handling `pyarrow` tests

* (fix): actually do import correctly

* (fix): `reduce` condition

* (fix): column ordering for dataframes

* (refactor): remove encoding business

* (refactor): raise error for dask + extension array

* (fix): only wrap `ExtensionDuckArray` that has a `.array` which is a pandas extension array

* (fix): use duck array equality method, not pandas

* (refactor): bye plum!

* (fix): `and` to `or` for casting to `ExtensionDuckArray`

* (fix): check for class, not type

* (fix): only support native endianness

* (refactor): no need for superfluous checks in `_maybe_wrap_data`

* (chore): clean up docs to no longer reference `plum`

* (fix): no longer allow `ExtensionDuckArray` to wrap `ExtensionDuckArray`

* (refactor): move `implements` logic to `indexing`

* (refactor): `indexing.py` -> `extension_array.py`

* (refactor): `ExtensionDuckArray` -> `PandasExtensionArray`

* (fix): add writeable property

* (fix): don't check writeable for `PandasExtensionArray`

* (fix): move check eariler

* (refactor): correct guard clause

* (chore): remove unnecessary `AttributeError`

* (feat): singleton wrapped as array

* (feat): remove shared dtype casting

* (feat): loop once over `dataframe.items`

* (feat): add `__len__` attribute

* (fix): ensure constructor recieves `pd.Categorical`

* Update xarray/core/extension_array.py

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>

* Update xarray/core/extension_array.py

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>

* (fix): drop condition for categorical corrected

* Apply suggestions from code review

* (chore): test `chunk` behavior

* Update xarray/core/variable.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (fix): bring  back error

* (chore): add test for dropping cat for mean

* Update whats-new.rst

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed Apr 18, 2024
1 parent 60f3e74 commit 9eb180b
Show file tree
Hide file tree
Showing 16 changed files with 434 additions and 43 deletions.
6 changes: 5 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ New Features
~~~~~~~~~~~~
- New "random" method for converting to and from 360_day calendars (:pull:`8603`).
By `Pascal Bourgault <https://github.com/aulemahal>`_.

- Xarray now makes a best attempt not to coerce :py:class:`pandas.api.extensions.ExtensionArray` to a numpy array
by supporting 1D `ExtensionArray` objects internally where possible. Thus, `Dataset`s initialized with a `pd.Catgeorical`,
for example, will retain the object. However, one cannot do operations that are not possible on the `ExtensionArray`
then, such as broadcasting.
By `Ilan Gold <https://github.com/ilan-gold>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
4 changes: 3 additions & 1 deletion properties/test_pandas_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
from hypothesis import given # isort:skip

numeric_dtypes = st.one_of(
npst.unsigned_integer_dtypes(), npst.integer_dtypes(), npst.floating_dtypes()
npst.unsigned_integer_dtypes(endianness="="),
npst.integer_dtypes(endianness="="),
npst.floating_dtypes(endianness="="),
)

numeric_series = numeric_dtypes.flatmap(lambda dt: pdst.series(dtype=dt))
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ module = [
"opt_einsum.*",
"pandas.*",
"pooch.*",
"pyarrow.*",
"pydap.*",
"pytest.*",
"scipy.*",
Expand Down
60 changes: 47 additions & 13 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import IO, TYPE_CHECKING, Any, Callable, Generic, Literal, cast, overload

import numpy as np
from pandas.api.types import is_extension_array_dtype

# remove once numpy 2.0 is the oldest supported version
try:
Expand Down Expand Up @@ -6852,10 +6853,13 @@ def reduce(
if (
# Some reduction functions (e.g. std, var) need to run on variables
# that don't have the reduce dims: PR5393
not reduce_dims
or not numeric_only
or np.issubdtype(var.dtype, np.number)
or (var.dtype == np.bool_)
not is_extension_array_dtype(var.dtype)
and (
not reduce_dims
or not numeric_only
or np.issubdtype(var.dtype, np.number)
or (var.dtype == np.bool_)
)
):
# prefer to aggregate over axis=None rather than
# axis=(0, 1) if they will be equivalent, because
Expand Down Expand Up @@ -7168,13 +7172,37 @@ def to_pandas(self) -> pd.Series | pd.DataFrame:
)

def _to_dataframe(self, ordered_dims: Mapping[Any, int]):
columns = [k for k in self.variables if k not in self.dims]
columns_in_order = [k for k in self.variables if k not in self.dims]
non_extension_array_columns = [
k
for k in columns_in_order
if not is_extension_array_dtype(self.variables[k].data)
]
extension_array_columns = [
k
for k in columns_in_order
if is_extension_array_dtype(self.variables[k].data)
]
data = [
self._variables[k].set_dims(ordered_dims).values.reshape(-1)
for k in columns
for k in non_extension_array_columns
]
index = self.coords.to_index([*ordered_dims])
return pd.DataFrame(dict(zip(columns, data)), index=index)
broadcasted_df = pd.DataFrame(
dict(zip(non_extension_array_columns, data)), index=index
)
for extension_array_column in extension_array_columns:
extension_array = self.variables[extension_array_column].data.array
index = self[self.variables[extension_array_column].dims[0]].data
extension_array_df = pd.DataFrame(
{extension_array_column: extension_array},
index=self[self.variables[extension_array_column].dims[0]].data,
)
extension_array_df.index.name = self.variables[extension_array_column].dims[
0
]
broadcasted_df = broadcasted_df.join(extension_array_df)
return broadcasted_df[columns_in_order]

def to_dataframe(self, dim_order: Sequence[Hashable] | None = None) -> pd.DataFrame:
"""Convert this dataset into a pandas.DataFrame.
Expand Down Expand Up @@ -7321,11 +7349,13 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
"cannot convert a DataFrame with a non-unique MultiIndex into xarray"
)

# Cast to a NumPy array first, in case the Series is a pandas Extension
# array (which doesn't have a valid NumPy dtype)
# TODO: allow users to control how this casting happens, e.g., by
# forwarding arguments to pandas.Series.to_numpy?
arrays = [(k, np.asarray(v)) for k, v in dataframe.items()]
arrays = []
extension_arrays = []
for k, v in dataframe.items():
if not is_extension_array_dtype(v):
arrays.append((k, np.asarray(v)))
else:
extension_arrays.append((k, v))

indexes: dict[Hashable, Index] = {}
index_vars: dict[Hashable, Variable] = {}
Expand All @@ -7339,6 +7369,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
xr_idx = PandasIndex(lev, dim)
indexes[dim] = xr_idx
index_vars.update(xr_idx.create_variables())
arrays += [(k, np.asarray(v)) for k, v in extension_arrays]
extension_arrays = []
else:
index_name = idx.name if idx.name is not None else "index"
dims = (index_name,)
Expand All @@ -7352,7 +7384,9 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
obj._set_sparse_data_from_dataframe(idx, arrays, dims)
else:
obj._set_numpy_data_from_dataframe(idx, arrays, dims)
return obj
for name, extension_array in extension_arrays:
obj[name] = (dims, extension_array)
return obj[dataframe.columns] if len(dataframe.columns) else obj

def to_dask_dataframe(
self, dim_order: Sequence[Hashable] | None = None, set_index: bool = False
Expand Down
19 changes: 15 additions & 4 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from numpy import concatenate as _concatenate
from numpy.lib.stride_tricks import sliding_window_view # noqa
from packaging.version import Version
from pandas.api.types import is_extension_array_dtype

from xarray.core import dask_array_ops, dtypes, nputils
from xarray.core.options import OPTIONS
Expand Down Expand Up @@ -156,7 +157,7 @@ def isnull(data):
return full_like(data, dtype=bool, fill_value=False)
else:
# at this point, array should have dtype=object
if isinstance(data, np.ndarray):
if isinstance(data, np.ndarray) or is_extension_array_dtype(data):
return pandas_isnull(data)
else:
# Not reachable yet, but intended for use with other duck array
Expand Down Expand Up @@ -221,9 +222,19 @@ def asarray(data, xp=np):

def as_shared_dtype(scalars_or_arrays, xp=np):
"""Cast a arrays to a shared dtype using xarray's type promotion rules."""
array_type_cupy = array_type("cupy")
if array_type_cupy and any(
isinstance(x, array_type_cupy) for x in scalars_or_arrays
if any(is_extension_array_dtype(x) for x in scalars_or_arrays):
extension_array_types = [
x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
]
if len(extension_array_types) == len(scalars_or_arrays) and all(
isinstance(x, type(extension_array_types[0])) for x in extension_array_types
):
return scalars_or_arrays
raise ValueError(
f"Cannot cast arrays to shared type, found array types {[x.dtype for x in scalars_or_arrays]}"
)
elif array_type_cupy := array_type("cupy") and any( # noqa: F841
isinstance(x, array_type_cupy) for x in scalars_or_arrays # noqa: F821
):
import cupy as cp

Expand Down
136 changes: 136 additions & 0 deletions xarray/core/extension_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import Callable, Generic

import numpy as np
import pandas as pd
from pandas.api.types import is_extension_array_dtype

from xarray.core.types import DTypeLikeSave, T_ExtensionArray

HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {}


def implements(numpy_function):
"""Register an __array_function__ implementation for MyArray objects."""

def decorator(func):
HANDLED_EXTENSION_ARRAY_FUNCTIONS[numpy_function] = func
return func

return decorator


@implements(np.issubdtype)
def __extension_duck_array__issubdtype(
extension_array_dtype: T_ExtensionArray, other_dtype: DTypeLikeSave
) -> bool:
return False # never want a function to think a pandas extension dtype is a subtype of numpy


@implements(np.broadcast_to)
def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple):
if shape[0] == len(arr) and len(shape) == 1:
return arr
raise NotImplementedError("Cannot broadcast 1d-only pandas categorical array.")


@implements(np.stack)
def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int):
raise NotImplementedError("Cannot stack 1d-only pandas categorical array.")


@implements(np.concatenate)
def __extension_duck_array__concatenate(
arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None
) -> T_ExtensionArray:
return type(arrays[0])._concat_same_type(arrays)


@implements(np.where)
def __extension_duck_array__where(
condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray
) -> T_ExtensionArray:
if (
isinstance(x, pd.Categorical)
and isinstance(y, pd.Categorical)
and x.dtype != y.dtype
):
x = x.add_categories(set(y.categories).difference(set(x.categories)))
y = y.add_categories(set(x.categories).difference(set(y.categories)))
return pd.Series(x).where(condition, pd.Series(y)).array


class PandasExtensionArray(Generic[T_ExtensionArray]):
array: T_ExtensionArray

def __init__(self, array: T_ExtensionArray):
"""NEP-18 compliant wrapper for pandas extension arrays.
Parameters
----------
array : T_ExtensionArray
The array to be wrapped upon e.g,. :py:class:`xarray.Variable` creation.
```
"""
if not isinstance(array, pd.api.extensions.ExtensionArray):
raise TypeError(f"{array} is not an pandas ExtensionArray.")
self.array = array

def __array_function__(self, func, types, args, kwargs):
def replace_duck_with_extension_array(args) -> list:
args_as_list = list(args)
for index, value in enumerate(args_as_list):
if isinstance(value, PandasExtensionArray):
args_as_list[index] = value.array
elif isinstance(
value, tuple
): # should handle more than just tuple? iterable?
args_as_list[index] = tuple(
replace_duck_with_extension_array(value)
)
elif isinstance(value, list):
args_as_list[index] = replace_duck_with_extension_array(value)
return args_as_list

args = tuple(replace_duck_with_extension_array(args))
if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS:
return func(*args, **kwargs)
res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs)
if is_extension_array_dtype(res):
return type(self)[type(res)](res)
return res

def __array_ufunc__(ufunc, method, *inputs, **kwargs):
return ufunc(*inputs, **kwargs)

def __repr__(self):
return f"{type(self)}(array={repr(self.array)})"

def __getattr__(self, attr: str) -> object:
return getattr(self.array, attr)

def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]:
item = self.array[key]
if is_extension_array_dtype(item):
return type(self)(item)
if np.isscalar(item):
return type(self)(type(self.array)([item]))
return item

def __setitem__(self, key, val):
self.array[key] = val

def __eq__(self, other):
if np.isscalar(other):
other = type(self)(type(self.array)([other]))
if isinstance(other, PandasExtensionArray):
return self.array == other.array
return self.array == other

def __ne__(self, other):
return ~(self == other)

def __len__(self):
return len(self.array)
3 changes: 3 additions & 0 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ def copy(
# hopefully in the future we can narrow this down more:
T_DuckArray = TypeVar("T_DuckArray", bound=Any, covariant=True)

# For typing pandas extension arrays.
T_ExtensionArray = TypeVar("T_ExtensionArray", bound=pd.api.extensions.ExtensionArray)


ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"]
VarCompatible = Union["Variable", "ScalarOrArray"]
Expand Down
10 changes: 10 additions & 0 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
import numpy as np
import pandas as pd
from numpy.typing import ArrayLike
from pandas.api.types import is_extension_array_dtype

import xarray as xr # only for Dataset and DataArray
from xarray.core import common, dtypes, duck_array_ops, indexing, nputils, ops, utils
from xarray.core.arithmetic import VariableArithmetic
from xarray.core.common import AbstractArray
from xarray.core.extension_array import PandasExtensionArray
from xarray.core.indexing import (
BasicIndexer,
OuterIndexer,
Expand Down Expand Up @@ -47,6 +49,7 @@
NON_NUMPY_SUPPORTED_ARRAY_TYPES = (
indexing.ExplicitlyIndexed,
pd.Index,
pd.api.extensions.ExtensionArray,
)
# https://github.com/python/mypy/issues/224
BASIC_INDEXING_TYPES = integer_types + (slice,)
Expand Down Expand Up @@ -184,6 +187,8 @@ def _maybe_wrap_data(data):
"""
if isinstance(data, pd.Index):
return PandasIndexingAdapter(data)
if isinstance(data, pd.api.extensions.ExtensionArray):
return PandasExtensionArray[type(data)](data)
return data


Expand Down Expand Up @@ -2570,6 +2575,11 @@ def chunk( # type: ignore[override]
dask.array.from_array
"""

if is_extension_array_dtype(self):
raise ValueError(
f"{self} was found to be a Pandas ExtensionArray. Please convert to numpy first."
)

if from_array_kwargs is None:
from_array_kwargs = {}

Expand Down
10 changes: 5 additions & 5 deletions xarray/testing/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:
Generates only those numpy dtypes which xarray can handle.
Use instead of hypothesis.extra.numpy.scalar_dtypes in order to exclude weirder dtypes such as unicode, byte_string, array, or nested dtypes.
Also excludes datetimes, which dodges bugs with pandas non-nanosecond datetime overflows.
Also excludes datetimes, which dodges bugs with pandas non-nanosecond datetime overflows. Checks only native endianness.
Requires the hypothesis package to be installed.
Expand All @@ -56,10 +56,10 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:
# TODO should this be exposed publicly?
# We should at least decide what the set of numpy dtypes that xarray officially supports is.
return (
npst.integer_dtypes()
| npst.unsigned_integer_dtypes()
| npst.floating_dtypes()
| npst.complex_number_dtypes()
npst.integer_dtypes(endianness="=")
| npst.unsigned_integer_dtypes(endianness="=")
| npst.floating_dtypes(endianness="=")
| npst.complex_number_dtypes(endianness="=")
# | npst.datetime64_dtypes()
# | npst.timedelta64_dtypes()
# | npst.unicode_string_dtypes()
Expand Down
Loading

0 comments on commit 9eb180b

Please sign in to comment.