Skip to content

Commit

Permalink
Do not attempt to broadcast when global option ``arithmetic_broadcast…
Browse files Browse the repository at this point in the history
…=False`` (#8784)

* increase plot size

* added old tests

* Keep relevant test

* what's new

* PR comment

* remove unnecessary (?) check

* unnecessary line removal

* removal of variable reassignment to avoid type issue

* Update xarray/core/variable.py

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>

* Update xarray/core/variable.py

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update tests

* what's new

* Update doc/whats-new.rst

---------

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed Mar 13, 2024
1 parent 14fe7e0 commit 11f89ec
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 0 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ v2024.03.0 (unreleased)
New Features
~~~~~~~~~~~~

- Do not broadcast in arithmetic operations when global option ``arithmetic_broadcast=False``
(:issue:`6806`, :pull:`8784`).
By `Etienne Schalk <https://github.com/etienneschalk>`_ and `Deepak Cherian <https://github.com/dcherian>`_.
- Add the ``.oindex`` property to Explicitly Indexed Arrays for orthogonal indexing functionality. (:issue:`8238`, :pull:`8750`)
By `Anderson Banihirwe <https://github.com/andersy005>`_.

Expand Down
3 changes: 3 additions & 0 deletions xarray/core/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
]

class T_Options(TypedDict):
arithmetic_broadcast: bool
arithmetic_join: Literal["inner", "outer", "left", "right", "exact"]
cmap_divergent: str | Colormap
cmap_sequential: str | Colormap
Expand All @@ -59,6 +60,7 @@ class T_Options(TypedDict):


OPTIONS: T_Options = {
"arithmetic_broadcast": True,
"arithmetic_join": "inner",
"cmap_divergent": "RdBu_r",
"cmap_sequential": "viridis",
Expand Down Expand Up @@ -92,6 +94,7 @@ def _positive_integer(value: int) -> bool:


_VALIDATORS = {
"arithmetic_broadcast": lambda value: isinstance(value, bool),
"arithmetic_join": _JOIN_OPTIONS.__contains__,
"display_max_rows": _positive_integer,
"display_values_threshold": _positive_integer,
Expand Down
10 changes: 10 additions & 0 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2871,6 +2871,16 @@ def broadcast_variables(*variables: Variable) -> tuple[Variable, ...]:


def _broadcast_compat_data(self, other):
if not OPTIONS["arithmetic_broadcast"]:
if (isinstance(other, Variable) and self.dims != other.dims) or (
is_duck_array(other) and self.ndim != other.ndim
):
raise ValueError(
"Broadcasting is necessary but automatic broadcasting is disabled via "
"global option `'arithmetic_broadcast'`. "
"Use `xr.set_options(arithmetic_broadcast=True)` to enable automatic broadcasting."
)

if all(hasattr(other, attr) for attr in ["dims", "data", "shape", "encoding"]):
# `other` satisfies the necessary Variable API for broadcast_variables
new_self, new_other = _broadcast_compat_variables(self, other)
Expand Down
7 changes: 7 additions & 0 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,13 @@ def _importorskip(
has_pynio, requires_pynio = _importorskip("Nio")
has_cftime, requires_cftime = _importorskip("cftime")
has_dask, requires_dask = _importorskip("dask")
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="The current Dask DataFrame implementation is deprecated.",
category=DeprecationWarning,
)
has_dask_expr, requires_dask_expr = _importorskip("dask_expr")
has_bottleneck, requires_bottleneck = _importorskip("bottleneck")
has_rasterio, requires_rasterio = _importorskip("rasterio")
has_zarr, requires_zarr = _importorskip("zarr")
Expand Down
38 changes: 38 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
requires_bottleneck,
requires_cupy,
requires_dask,
requires_dask_expr,
requires_iris,
requires_numexpr,
requires_pint,
Expand Down Expand Up @@ -3203,6 +3204,42 @@ def test_align_str_dtype(self) -> None:
assert_identical(expected_b, actual_b)
assert expected_b.x.dtype == actual_b.x.dtype

def test_broadcast_on_vs_off_global_option_different_dims(self) -> None:
xda_1 = xr.DataArray([1], dims="x1")
xda_2 = xr.DataArray([1], dims="x2")

with xr.set_options(arithmetic_broadcast=True):
expected_xda = xr.DataArray([[1.0]], dims=("x1", "x2"))
actual_xda = xda_1 / xda_2
assert_identical(actual_xda, expected_xda)

with xr.set_options(arithmetic_broadcast=False):
with pytest.raises(
ValueError,
match=re.escape(
"Broadcasting is necessary but automatic broadcasting is disabled via "
"global option `'arithmetic_broadcast'`. "
"Use `xr.set_options(arithmetic_broadcast=True)` to enable automatic broadcasting."
),
):
xda_1 / xda_2

@pytest.mark.parametrize("arithmetic_broadcast", [True, False])
def test_broadcast_on_vs_off_global_option_same_dims(
self, arithmetic_broadcast: bool
) -> None:
# Ensure that no error is raised when arithmetic broadcasting is disabled,
# when broadcasting is not needed. The two DataArrays have the same
# dimensions of the same size.
xda_1 = xr.DataArray([1], dims="x")
xda_2 = xr.DataArray([1], dims="x")
expected_xda = xr.DataArray([2.0], dims=("x",))

with xr.set_options(arithmetic_broadcast=arithmetic_broadcast):
assert_identical(xda_1 + xda_2, expected_xda)
assert_identical(xda_1 + np.array([1.0]), expected_xda)
assert_identical(np.array([1.0]) + xda_1, expected_xda)

def test_broadcast_arrays(self) -> None:
x = DataArray([1, 2], coords=[("a", [-1, -2])], name="x")
y = DataArray([1, 2], coords=[("b", [3, 4])], name="y")
Expand Down Expand Up @@ -3381,6 +3418,7 @@ def test_to_dataframe_0length(self) -> None:
assert len(actual) == 0
assert_array_equal(actual.index.names, list("ABC"))

@requires_dask_expr
@requires_dask
def test_to_dask_dataframe(self) -> None:
arr_np = np.arange(3 * 4).reshape(3, 4)
Expand Down

0 comments on commit 11f89ec

Please sign in to comment.