Skip to content

Commit

Permalink
Use opt_einsum by default if installed. (#8373)
Browse files Browse the repository at this point in the history
* Use `opt_einsum` by default if installed.

Closes #7764
Closes #8017

* docstring update

* _

* _

Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com>

* Update xarray/core/computation.py

Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com>

* Fix docs?

* Add use_opt_einsum option.

* mypy ignore

* one more test ignore

* Disable navigation_with_keys

* remove intersphinx

* One more skip

---------

Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com>
  • Loading branch information
dcherian and max-sixty committed Oct 28, 2023
1 parent bb489fa commit d40609a
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 15 deletions.
3 changes: 2 additions & 1 deletion ci/install-upstream-wheels.sh
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,5 @@ python -m pip install \
git+https://github.com/intake/filesystem_spec \
git+https://github.com/SciTools/nc-time-axis \
git+https://github.com/xarray-contrib/flox \
git+https://github.com/h5netcdf/h5netcdf
git+https://github.com/h5netcdf/h5netcdf \
git+https://github.com/dgasmith/opt_einsum
1 change: 1 addition & 0 deletions ci/requirements/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dependencies:
- numbagg
- numexpr
- numpy
- opt_einsum
- packaging
- pandas
- pint<0.21
Expand Down
2 changes: 2 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@
use_repository_button=True,
use_issues_button=True,
home_page_in_toc=False,
navigation_with_keys=False,
extra_footer="""<p>Xarray is a fiscally sponsored project of <a href="https://numfocus.org">NumFOCUS</a>,
a nonprofit dedicated to supporting the open-source scientific computing community.<br>
Theme by the <a href="https://ebp.jupyterbook.org">Executable Book Project</a></p>""",
Expand Down Expand Up @@ -327,6 +328,7 @@
"sparse": ("https://sparse.pydata.org/en/latest/", None),
"cubed": ("https://tom-e-white.com/cubed/", None),
"datatree": ("https://xarray-datatree.readthedocs.io/en/latest/", None),
# "opt_einsum": ("https://dgasmith.github.io/opt_einsum/", None),
}


Expand Down
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ v2023.10.2 (unreleased)
New Features
~~~~~~~~~~~~

- Use `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_ for :py:func:`xarray.dot` by default if installed.
By `Deepak Cherian <https://github.com/dcherian>`_. (:issue:`7764`, :pull:`8373`).

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ source-code = "https://github.com/pydata/xarray"
dask = "xarray.core.daskmanager:DaskManager"

[project.optional-dependencies]
accel = ["scipy", "bottleneck", "numbagg", "flox"]
accel = ["scipy", "bottleneck", "numbagg", "flox", "opt_einsum"]
complete = ["xarray[accel,io,parallel,viz]"]
io = ["netCDF4", "h5netcdf", "scipy", 'pydap; python_version<"3.10"', "zarr", "fsspec", "cftime", "pooch"]
parallel = ["dask[complete]"]
Expand Down Expand Up @@ -106,6 +106,7 @@ module = [
"numbagg.*",
"netCDF4.*",
"netcdftime.*",
"opt_einsum.*",
"pandas.*",
"pooch.*",
"PseudoNetCDF.*",
Expand Down
19 changes: 15 additions & 4 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1690,8 +1690,8 @@ def dot(
dims: Dims = None,
**kwargs: Any,
):
"""Generalized dot product for xarray objects. Like np.einsum, but
provides a simpler interface based on array dimensions.
"""Generalized dot product for xarray objects. Like ``np.einsum``, but
provides a simpler interface based on array dimension names.
Parameters
----------
Expand All @@ -1701,13 +1701,24 @@ def dot(
Which dimensions to sum over. Ellipsis ('...') sums over all dimensions.
If not specified, then all the common dimensions are summed over.
**kwargs : dict
Additional keyword arguments passed to numpy.einsum or
dask.array.einsum
Additional keyword arguments passed to ``numpy.einsum`` or
``dask.array.einsum``
Returns
-------
DataArray
See Also
--------
numpy.einsum
dask.array.einsum
opt_einsum.contract
Notes
-----
We recommend installing the optional ``opt_einsum`` package, or alternatively passing ``optimize=True``,
which is passed through to ``np.einsum``, and works for most array backends.
Examples
--------
>>> da_a = xr.DataArray(np.arange(3 * 2).reshape(3, 2), dims=["a", "b"])
Expand Down
12 changes: 11 additions & 1 deletion xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from numpy import any as array_any # noqa
from numpy import ( # noqa
around, # noqa
einsum,
gradient,
isclose,
isin,
Expand Down Expand Up @@ -48,6 +47,17 @@ def get_array_namespace(x):
return np


def einsum(*args, **kwargs):
from xarray.core.options import OPTIONS

if OPTIONS["use_opt_einsum"] and module_available("opt_einsum"):
import opt_einsum

return opt_einsum.contract(*args, **kwargs)
else:
return np.einsum(*args, **kwargs)


def _dask_or_eager_func(
name,
eager_module=np,
Expand Down
6 changes: 6 additions & 0 deletions xarray/core/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"warn_for_unclosed_files",
"use_bottleneck",
"use_numbagg",
"use_opt_einsum",
"use_flox",
]

Expand All @@ -52,6 +53,7 @@ class T_Options(TypedDict):
use_bottleneck: bool
use_flox: bool
use_numbagg: bool
use_opt_einsum: bool


OPTIONS: T_Options = {
Expand All @@ -75,6 +77,7 @@ class T_Options(TypedDict):
"use_bottleneck": True,
"use_flox": True,
"use_numbagg": True,
"use_opt_einsum": True,
}

_JOIN_OPTIONS = frozenset(["inner", "outer", "left", "right", "exact"])
Expand Down Expand Up @@ -102,6 +105,7 @@ def _positive_integer(value: int) -> bool:
"keep_attrs": lambda choice: choice in [True, False, "default"],
"use_bottleneck": lambda value: isinstance(value, bool),
"use_numbagg": lambda value: isinstance(value, bool),
"use_opt_einsum": lambda value: isinstance(value, bool),
"use_flox": lambda value: isinstance(value, bool),
"warn_for_unclosed_files": lambda value: isinstance(value, bool),
}
Expand Down Expand Up @@ -237,6 +241,8 @@ class set_options:
use_numbagg : bool, default: True
Whether to use ``numbagg`` to accelerate reductions.
Takes precedence over ``use_bottleneck`` when both are True.
use_opt_einsum : bool, default: True
Whether to use ``opt_einsum`` to accelerate dot products.
warn_for_unclosed_files : bool, default: False
Whether or not to issue a warning when unclosed files are
deallocated. This is mostly useful for debugging.
Expand Down
19 changes: 11 additions & 8 deletions xarray/tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -1502,10 +1502,11 @@ def test_dot_dataarray(dtype):
data_array = xr.DataArray(data=array1, dims=("x", "y"))
other = xr.DataArray(data=array2, dims=("y", "z"))

expected = attach_units(
xr.dot(strip_units(data_array), strip_units(other)), {None: unit_registry.m}
)
actual = xr.dot(data_array, other)
with xr.set_options(use_opt_einsum=False):
expected = attach_units(
xr.dot(strip_units(data_array), strip_units(other)), {None: unit_registry.m}
)
actual = xr.dot(data_array, other)

assert_units_equal(expected, actual)
assert_identical(expected, actual)
Expand Down Expand Up @@ -2465,8 +2466,9 @@ def test_binary_operations(self, func, dtype):
data_array = xr.DataArray(data=array)

units = extract_units(func(array))
expected = attach_units(func(strip_units(data_array)), units)
actual = func(data_array)
with xr.set_options(use_opt_einsum=False):
expected = attach_units(func(strip_units(data_array)), units)
actual = func(data_array)

assert_units_equal(expected, actual)
assert_identical(expected, actual)
Expand Down Expand Up @@ -3829,8 +3831,9 @@ def test_computation(self, func, variant, dtype):
if not isinstance(func, (function, method)):
units.update(extract_units(func(array.reshape(-1))))

expected = attach_units(func(strip_units(data_array)), units)
actual = func(data_array)
with xr.set_options(use_opt_einsum=False):
expected = attach_units(func(strip_units(data_array)), units)
actual = func(data_array)

assert_units_equal(expected, actual)
assert_identical(expected, actual)
Expand Down

0 comments on commit d40609a

Please sign in to comment.