diff --git a/autograd/numpy/numpy_wrapper.py b/autograd/numpy/numpy_wrapper.py index 85bebade..fa7c5378 100644 --- a/autograd/numpy/numpy_wrapper.py +++ b/autograd/numpy/numpy_wrapper.py @@ -4,7 +4,11 @@ from autograd.extend import primitive, notrace_primitive import numpy as _np import autograd.builtins as builtins -from numpy.core.einsumfunc import _parse_einsum_input + +if _np.lib.NumpyVersion(_np.__version__) >= '2.0.0': + from numpy._core.einsumfunc import _parse_einsum_input +else: + from numpy.core.einsumfunc import _parse_einsum_input notrace_functions = [ _np.ndim, _np.shape, _np.iscomplexobj, _np.result_type