Allow callables to .drop_vars #8511

Merged on Dec 3, 2023 (6 commits)
Changes from 2 commits
17 changes: 13 additions & 4 deletions xarray/core/dataarray.py
@@ -3041,16 +3041,17 @@ def T(self) -> Self:

     def drop_vars(
         self,
-        names: Hashable | Iterable[Hashable],
+        names: Hashable | Iterable[Hashable] | Callable,
         *,
         errors: ErrorOptions = "raise",
     ) -> Self:
         """Returns an array with dropped variables.

         Parameters
         ----------
-        names : Hashable or iterable of Hashable
-            Name(s) of variables to drop.
+        names : Hashable or iterable of Hashable or Callable
+            Name(s) of variables to drop. If a Callable, this object is passed as its
+            only argument and its result is used.
         errors : {"raise", "ignore"}, default: "raise"
             If 'raise', raises a ValueError error if any of the variable
             passed are not in the dataset. If 'ignore', any given names that are in the
@@ -3100,6 +3101,14 @@ def drop_vars(
                [ 6, 7, 8],
                [ 9, 10, 11]])
         Dimensions without coordinates: x, y
+
+        >>> da.drop_vars(lambda x: x.coords)
+        <xarray.DataArray (x: 4, y: 3)>
+        array([[ 0, 1, 2],
+               [ 3, 4, 5],
+               [ 6, 7, 8],
+               [ 9, 10, 11]])
+        Dimensions without coordinates: x, y
         """
         ds = self._to_temp_dataset().drop_vars(names, errors=errors)
         return self._from_temp_dataset(ds)
@@ -6328,7 +6337,7 @@ def curvefit(
         ... param="time_constant"
         ... ) # doctest: +NUMBER
         <xarray.DataArray 'curvefit_coefficients' (x: 3)>
-        array([1.0569203, 1.7354963, 2.9421577])
+        array([1.05692035, 1.73549638, 2.9421577 ])
         Coordinates:
           * x (x) int64 0 1 2
             param <U13 'time_constant'
48 changes: 31 additions & 17 deletions xarray/core/dataset.py
@@ -5723,16 +5723,17 @@ def _assert_all_in_dataset(

     def drop_vars(
         self,
-        names: Hashable | Iterable[Hashable],
+        names: Hashable | Iterable[Hashable] | Callable,
Collaborator

Suggested change
-        names: Hashable | Iterable[Hashable] | Callable,
+        names: Hashable | Iterable[Hashable] | Callable[[Self], Hashable | Iterable[Hashable]],

You could try this :)

And not to mention the old str | Iterable[Hashable] story. But the code below seems fine with the current definition, so we can leave it.

Collaborator Author

Good idea. I also did some str | Iterable[Hashable] tithing
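
For reference, a minimal sketch of what the suggested annotation expresses: the callable receives the object itself and returns either one name or an iterable of names. `ToyDataset` and its `variables` attribute are hypothetical stand-ins invented for illustration, not xarray's classes, and the `str`-only scalar check is a simplification of what the PR does with `is_scalar`.

```python
from __future__ import annotations

from collections.abc import Callable, Hashable, Iterable
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only needed by type checkers; annotations are not evaluated at runtime.
    from typing_extensions import Self


class ToyDataset:
    """Hypothetical stand-in for a Dataset; not xarray's class."""

    def __init__(self, variables: dict[Hashable, object]) -> None:
        self.variables = dict(variables)

    def drop_vars(
        self,
        names: (
            Hashable
            | Iterable[Hashable]
            | Callable[[Self], Hashable | Iterable[Hashable]]
        ),
    ) -> Self:
        # A callable receives the object itself and returns the names to drop.
        if callable(names):
            names = names(self)
        # Simplification: only str is treated as a scalar name here.
        if isinstance(names, str) or not isinstance(names, Iterable):
            names_set = {names}
        else:
            names_set = set(names)
        kept = {k: v for k, v in self.variables.items() if k not in names_set}
        return type(self)(kept)


toy = ToyDataset({"temperature": 1, "humidity": 2, "time": 3})
# Drop everything except "time" by computing the names from the object.
dropped = toy.drop_vars(lambda ds: [k for k in ds.variables if k != "time"])
print(sorted(dropped.variables))  # ['time']
```

With the narrower `Callable[[Self], ...]` form, a type checker can flag a lambda that takes no arguments or returns something that is not a name, which a bare `Callable` would accept.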

         *,
         errors: ErrorOptions = "raise",
     ) -> Self:
         """Drop variables from this dataset.

         Parameters
         ----------
-        names : hashable or iterable of hashable
-            Name(s) of variables to drop.
+        names : Hashable or iterable of Hashable or Callable
+            Name(s) of variables to drop. If a Callable, this object is passed as its
+            only argument and its result is used.
         errors : {"raise", "ignore"}, default: "raise"
             If 'raise', raises a ValueError error if any of the variable
             passed are not in the dataset. If 'ignore', any given names that are in the
@@ -5774,7 +5775,7 @@ def drop_vars(
             humidity (time, latitude, longitude) float64 65.0 63.8 58.2 59.6
             wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8

-        # Drop the 'humidity' variable
+        Drop the 'humidity' variable

         >>> dataset.drop_vars(["humidity"])
         <xarray.Dataset>
@@ -5787,7 +5788,7 @@ def drop_vars(
             temperature (time, latitude, longitude) float64 25.5 26.3 27.1 28.0
             wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8

-        # Drop the 'humidity', 'temperature' variables
+        Drop the 'humidity', 'temperature' variables

         >>> dataset.drop_vars(["humidity", "temperature"])
         <xarray.Dataset>
@@ -5799,7 +5800,18 @@ def drop_vars(
         Data variables:
             wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8

-        # Attempt to drop non-existent variable with errors="ignore"
+        Drop all indexes
+
+        >>> dataset.drop_vars(lambda x: x.indexes)
+        <xarray.Dataset>
+        Dimensions: (time: 1, latitude: 2, longitude: 2)
+        Dimensions without coordinates: time, latitude, longitude
+        Data variables:
+            temperature (time, latitude, longitude) float64 25.5 26.3 27.1 28.0
+            humidity (time, latitude, longitude) float64 65.0 63.8 58.2 59.6
+            wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8
+
+        Attempt to drop non-existent variable with errors="ignore"

         >>> dataset.drop_vars(["pressure"], errors="ignore")
         <xarray.Dataset>
@@ -5813,7 +5825,7 @@ def drop_vars(
             humidity (time, latitude, longitude) float64 65.0 63.8 58.2 59.6
             wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8

-        # Attempt to drop non-existent variable with errors="raise"
+        Attempt to drop non-existent variable with errors="raise"

         >>> dataset.drop_vars(["pressure"], errors="raise")
         Traceback (most recent call last):
@@ -5834,35 +5846,37 @@ def drop_vars(

         """
         # the Iterable check is required for mypy
+        if callable(names):
+            names = names(self)
         if is_scalar(names) or not isinstance(names, Iterable):
-            names = {names}
+            names_set = {names}
         else:
-            names = set(names)
+            names_set = set(names)
         if errors == "raise":
-            self._assert_all_in_dataset(names)
+            self._assert_all_in_dataset(names_set)

         # GH6505
         other_names = set()
-        for var in names:
+        for var in names_set:
             maybe_midx = self._indexes.get(var, None)
             if isinstance(maybe_midx, PandasMultiIndex):
-                idx_coord_names = set(maybe_midx.index.names + [maybe_midx.dim])
-                idx_other_names = idx_coord_names - set(names)
+                idx_coord_names = set(maybe_midx.index.names_set + [maybe_midx.dim])
+                idx_other_names = idx_coord_names - set(names_set)
                 other_names.update(idx_other_names)
         if other_names:
-            names |= set(other_names)
+            names_set |= set(other_names)
             warnings.warn(
                 f"Deleting a single level of a MultiIndex is deprecated. Previously, this deleted all levels of a MultiIndex. "
                 f"Please also drop the following variables: {other_names!r} to avoid an error in the future.",
                 DeprecationWarning,
                 stacklevel=2,
             )

-        assert_no_index_corrupted(self.xindexes, names)
+        assert_no_index_corrupted(self.xindexes, names_set)

-        variables = {k: v for k, v in self._variables.items() if k not in names}
+        variables = {k: v for k, v in self._variables.items() if k not in names_set}
         coord_names = {k for k in self._coord_names if k in variables}
-        indexes = {k: v for k, v in self._indexes.items() if k not in names}
+        indexes = {k: v for k, v in self._indexes.items() if k not in names_set}
         return self._replace_with_new_dims(
             variables, coord_names=coord_names, indexes=indexes
         )
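
The control flow in this hunk (resolve a callable, normalize to a set under a separate name, optionally check for missing names, then filter) can be sketched on a plain mapping. This is only an illustration under stated assumptions: `drop_keys` is a hypothetical helper, the `str`/`bytes` test stands in for xarray's `is_scalar`, the error message is invented, and the MultiIndex handling is left out entirely.

```python
from collections.abc import Hashable, Iterable, Mapping
from typing import Any


def drop_keys(
    variables: Mapping[Hashable, Any],
    names: Any,
    *,
    errors: str = "raise",
) -> dict:
    """Hypothetical helper mirroring the control flow of the hunk above."""
    # 1. Resolve a callable by passing it the container.
    if callable(names):
        names = names(variables)
    # 2. Normalize to a set under a new name, so `names` keeps its broad type
    #    and the set gets a single, narrow one.
    if isinstance(names, (str, bytes)) or not isinstance(names, Iterable):
        names_set = {names}
    else:
        names_set = set(names)
    # 3. Optionally reject names that are not present (message is invented).
    if errors == "raise":
        missing = names_set - set(variables)
        if missing:
            raise ValueError(f"cannot drop missing names: {sorted(map(str, missing))}")
    # 4. Keep everything that was not named.
    return {k: v for k, v in variables.items() if k not in names_set}


data = {"temperature": 25.5, "humidity": 65.0, "wind_speed": 10.2}
print(drop_keys(data, "humidity"))
print(drop_keys(data, lambda d: [k for k in d if k.startswith("w")]))
print(drop_keys(data, ["pressure"], errors="ignore"))
```

Binding the normalized set to `names_set` instead of reassigning `names` is what lets a type checker give each variable one type, which appears to be the motivation for the rename in the diff.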
8 changes: 8 additions & 0 deletions xarray/tests/test_dataarray.py
@@ -2652,16 +2652,24 @@
         actual = renamed.drop_vars("foo", errors="ignore")
         assert_identical(actual, renamed)

+    def test_drop_vars_callable(self) -> None:
+        A = DataArray(
+            np.random.randn(2, 3), dims=["x", "y"], coords={"x": [1, 2], "y": [3, 4, 5]}
+        )
+        expected = A.drop_vars(["x", "y"])
+        actual = A.drop_vars(lambda x: x.indexes)
+        assert_identical(expected, actual)
+
     def test_drop_multiindex_level(self) -> None:
         # GH6505
         expected = self.mda.drop_vars(["x", "level_1", "level_2"])

Check failure on line 2665 in xarray/tests/test_dataarray.py

GitHub Actions / ubuntu-latest py3.9 bare-minimum, ubuntu-latest py3.10 all-but-dask, ubuntu-latest py3.9 min-all-deps, ubuntu-latest py3.9, ubuntu-latest py3.10 flaky, ubuntu-latest py3.11, macos-latest py3.11, macos-latest py3.9

TestDataArray.test_drop_multiindex_level AttributeError: 'MultiIndex' object has no attribute 'names_set'

         with pytest.warns(DeprecationWarning):
             actual = self.mda.drop_vars("level_1")
         assert_identical(expected, actual)

     def test_drop_all_multiindex_levels(self) -> None:
         dim_levels = ["x", "level_1", "level_2"]
         actual = self.mda.drop_vars(dim_levels)

Check failure on line 2672 in xarray/tests/test_dataarray.py

GitHub Actions / ubuntu-latest py3.9 bare-minimum, ubuntu-latest py3.10 all-but-dask, ubuntu-latest py3.9 min-all-deps, ubuntu-latest py3.9, ubuntu-latest py3.10 flaky, ubuntu-latest py3.11, macos-latest py3.11, macos-latest py3.9

TestDataArray.test_drop_all_multiindex_levels AttributeError: 'MultiIndex' object has no attribute 'names_set'

         # no error, multi-index dropped
         for key in dim_levels:
             assert key not in actual.xindexes