Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not attempt to broadcast when global option arithmetic_broadcast=False #8784

Merged
merged 16 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ v2024.03.0 (unreleased)
New Features
~~~~~~~~~~~~

- Do not attempt to broadcast when global option ``arithmetic_broadcast=False``
dcherian marked this conversation as resolved.
Show resolved Hide resolved
(:issue:`6806`, :pull:`8784`).
By `Etienne Schalk <https://github.com/etienneschalk>`_.
- Add the ``.oindex`` property to Explicitly Indexed Arrays for orthogonal indexing functionality. (:issue:`8238`, :pull:`8750`)
By `Anderson Banihirwe <https://github.com/andersy005>`_.

Expand Down
3 changes: 3 additions & 0 deletions xarray/core/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
]

class T_Options(TypedDict):
arithmetic_broadcast: bool
arithmetic_join: Literal["inner", "outer", "left", "right", "exact"]
cmap_divergent: str | Colormap
cmap_sequential: str | Colormap
Expand All @@ -59,6 +60,7 @@ class T_Options(TypedDict):


OPTIONS: T_Options = {
"arithmetic_broadcast": True,
"arithmetic_join": "inner",
"cmap_divergent": "RdBu_r",
"cmap_sequential": "viridis",
Expand Down Expand Up @@ -92,6 +94,7 @@ def _positive_integer(value: int) -> bool:


_VALIDATORS = {
"arithmetic_broadcast": lambda value: isinstance(value, bool),
"arithmetic_join": _JOIN_OPTIONS.__contains__,
"display_max_rows": _positive_integer,
"display_values_threshold": _positive_integer,
Expand Down
6 changes: 6 additions & 0 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2864,6 +2864,12 @@ def broadcast_variables(*variables: Variable) -> tuple[Variable, ...]:


def _broadcast_compat_data(self, other):
if not OPTIONS["arithmetic_broadcast"]:
if (isinstance(other, Variable) and self.dims != other.dims) or (
isinstance(other, np.ndarray) and self.ndim != other.ndim
etienneschalk marked this conversation as resolved.
Show resolved Hide resolved
):
raise ValueError("arithmetic broadcast is disabled via global option")
etienneschalk marked this conversation as resolved.
Show resolved Hide resolved

if all(hasattr(other, attr) for attr in ["dims", "data", "shape", "encoding"]):
# `other` satisfies the necessary Variable API for broadcast_variables
new_self, new_other = _broadcast_compat_variables(self, other)
Expand Down
7 changes: 7 additions & 0 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,13 @@ def _importorskip(
has_pynio, requires_pynio = _importorskip("Nio")
has_cftime, requires_cftime = _importorskip("cftime")
has_dask, requires_dask = _importorskip("dask")
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="The current Dask DataFrame implementation is deprecated.",
category=DeprecationWarning,
)
has_dask_expr, requires_dask_expr = _importorskip("dask_expr")
has_bottleneck, requires_bottleneck = _importorskip("bottleneck")
has_rasterio, requires_rasterio = _importorskip("rasterio")
has_zarr, requires_zarr = _importorskip("zarr")
Expand Down
39 changes: 39 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
requires_bottleneck,
requires_cupy,
requires_dask,
requires_dask_expr,
requires_iris,
requires_numexpr,
requires_pint,
Expand Down Expand Up @@ -3203,6 +3204,43 @@ def test_align_str_dtype(self) -> None:
assert_identical(expected_b, actual_b)
assert expected_b.x.dtype == actual_b.x.dtype

def test_broadcast_on_vs_off_global_option_different_dims(self) -> None:
xda_1 = xr.DataArray([1], dims="x1")
xda_2 = xr.DataArray([1], dims="x2")

with xr.set_options(arithmetic_broadcast=True):
expected_xda = xr.DataArray([[1.0]], dims=("x1", "x2"))
actual_xda = xda_1 / xda_2
assert_identical(actual_xda, expected_xda)

with xr.set_options(arithmetic_broadcast=False):
with pytest.raises(
ValueError,
match=re.escape("arithmetic broadcast is disabled via global option"),
):
xda_1 / xda_2

@pytest.mark.parametrize("arithmetic_broadcast", [True, False])
def test_broadcast_on_vs_off_global_option_same_dims(
self, arithmetic_broadcast: bool
) -> None:
# Ensure that no error is raised when arithmetic broadcasting is disabled,
# when broadcasting is not needed. The two DataArrays have the same
# dimensions of the same size.
xda_1 = xr.DataArray([1], dims="x")
xda_2 = xr.DataArray([1], dims="x")
expected_xda = xr.DataArray([2.0], dims=("x",))

with xr.set_options(arithmetic_broadcast=arithmetic_broadcast):
actual_xda = xda_1 + xda_2
assert_identical(actual_xda, expected_xda)

actual_xda = xda_1 + np.array([1.0])
assert_identical(actual_xda, expected_xda)

actual_xda = np.array([1.0]) + xda_1
dcherian marked this conversation as resolved.
Show resolved Hide resolved
assert_identical(actual_xda, expected_xda)

def test_broadcast_arrays(self) -> None:
x = DataArray([1, 2], coords=[("a", [-1, -2])], name="x")
y = DataArray([1, 2], coords=[("b", [3, 4])], name="y")
Expand Down Expand Up @@ -3381,6 +3419,7 @@ def test_to_dataframe_0length(self) -> None:
assert len(actual) == 0
assert_array_equal(actual.index.names, list("ABC"))

@requires_dask_expr
@requires_dask
def test_to_dask_dataframe(self) -> None:
arr_np = np.arange(3 * 4).reshape(3, 4)
Expand Down
Loading