cast numpy scalars to arrays in as_compatible_data #9403

Open
wants to merge 8 commits into
base: main
from

Conversation

keewis
Collaborator

@keewis keewis commented Aug 27, 2024

As mentioned in #9399, numpy recently added __array_namespace__ to its scalars, which causes Variable to wrap the scalar. This PR casts numpy scalars to numpy.ndarray whenever they reach as_compatible_data.

This will most likely also pop up in NamedArray / xarray.namedarray.from_array, so we might have to figure out what to do there. cc @andersy005 and @Illviljan for awareness
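As a rough illustration of the cast described above (a minimal sketch, not the code in this PR), np.asarray turns a numpy scalar into a 0d numpy.ndarray while keeping the dtype and value. Whether the scalar exposes __array_namespace__ depends on the installed numpy version, so that line only prints the result.

```python
import numpy as np

scalar = np.float64(1.5)

# Version dependent: newer numpy releases define this on scalars.
print(hasattr(scalar, "__array_namespace__"))

# A numpy scalar is an instance of np.generic, not of np.ndarray.
is_generic = isinstance(scalar, np.generic)
is_ndarray = isinstance(scalar, np.ndarray)

# Casting yields a 0d ndarray with the same dtype and value.
as_array = np.asarray(scalar)
print(type(as_array), as_array.shape, as_array.dtype)
```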

@TomNicholas TomNicholas added the topic-arrays related to flexible array support label Aug 27, 2024
@TomNicholas
Contributor

This will most likely also pop up in NamedArray / xarray.namedarray.from_array, so we might have to figure out what to do there.

I think the equivalent check is here

if isinstance(data, _arrayfunction_or_api):

This does a runtime_checkable isinstance check against the _arrayfunction_or_api protocol type. That protocol includes __array_namespace__ but also shape, so if the numpy scalar doesn't have a shape attribute, then maybe this will already work correctly?
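To make the kind of check being discussed concrete, here is a minimal sketch of a runtime-checkable protocol that requires both shape and __array_namespace__. The protocol and the two dummy classes are hypothetical stand-ins, not the actual definitions in xarray.namedarray.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class HasShapeAndNamespace(Protocol):
    """Hypothetical protocol: requires `shape` and `__array_namespace__`."""

    @property
    def shape(self) -> tuple[int, ...]: ...

    def __array_namespace__(self) -> Any: ...


class WithBoth:
    # Defines both members, so the isinstance check passes.
    shape = ()

    def __array_namespace__(self, api_version=None):
        return None


class NamespaceOnly:
    # Lacks `shape`, so the isinstance check fails.
    def __array_namespace__(self, api_version=None):
        return None


both_ok = isinstance(WithBoth(), HasShapeAndNamespace)
namespace_only_ok = isinstance(NamespaceOnly(), HasShapeAndNamespace)
print(both_ok, namespace_only_ok)
```

The check only looks for the presence of the members, so any object that defines both would pass, whatever its type.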

@keewis
Collaborator Author

keewis commented Aug 27, 2024

numpy scalars do have a shape (and ndim / dtype): just like for 0d arrays, it is (). It feels like numpy scalars have become something like a combination of arrays and python scalars.
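A small sketch of that overlap, comparing a numpy scalar with a 0d array: the array-like attributes match, while the scalar also behaves like a Python float (it subclasses float and is hashable, which a 0d array is not).

```python
import numpy as np

scalar = np.float64(1.5)
zero_d = np.asarray(1.5)

# Array-like side: same shape, ndim and dtype as a 0d array.
print(scalar.shape, scalar.ndim, scalar.dtype)
print(zero_d.shape, zero_d.ndim, zero_d.dtype)

# Python-scalar side: np.float64 subclasses float and is hashable.
is_python_float = isinstance(scalar, float)
scalar_hash = hash(scalar)

# A 0d ndarray is mutable and therefore not hashable.
try:
    hash(zero_d)
    zero_d_hashable = True
except TypeError:
    zero_d_hashable = False
```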

@Illviljan
Contributor

Returning scalars is not supposed to happen if the array api is respected. Hopefully numpy considers this a bug too.

import numpy as np
import array_api_strict as xps


a = xps.asarray([1,2.0])
xps.mean(a)
Out[21]: Array(1.5, dtype=array_api_strict.float64)


a = np.asarray([1,2.0])
np.mean(a)
Out[23]: np.float64(1.5)

@keewis
Collaborator Author

keewis commented Aug 27, 2024

Returning scalars is not supposed to happen if the array api is respected.

Indeed. However, that is only if you read it as "only arrays can be array API compliant". It appears that numpy went the other way and made numpy scalars behave like 0d arrays, thus becoming array API compliant. In other words, numpy scalars now define ndim, dtype, and shape, as well as __array_namespace__ (and any other required properties / methods I might have forgotten).

@TomNicholas
Contributor

thus becoming array API compliant

It surprised me that this is even allowed within the array API. For example

In [1]: import numpy as np

In [2]: s = np.float64(4.1)

In [3]: result = np.broadcast_to(s, (2, 2))

In [4]: result
Out[4]: 
array([[4.1, 4.1],
       [4.1, 4.1]])

In [5]: type(result)
Out[5]: numpy.ndarray

In [6]: type(s)
Out[6]: numpy.float64

This means that, in the broadcast_to specification in the array API standard, the class represented by the array type hint on the input differs from the class represented by the array type hint on the return value. But it turns out that is actually allowed!

I wonder if there are other places in xarray where we assume that for an array-API-compliant duck type, the type going into an array API function/method is going to be the same as the type coming out (because apparently it doesn't have to be).
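One way to probe that assumption is a small check of whether a function returns exactly the type it was given. This is only a sketch; same_type_in_and_out is a hypothetical helper, not something that exists in xarray.

```python
import numpy as np


def same_type_in_and_out(func, x, *args, **kwargs):
    """Hypothetical helper: True if func(x, ...) returns exactly type(x)."""
    return type(func(x, *args, **kwargs)) is type(x)


arr = np.asarray([1.0, 2.0])
scalar = np.float64(4.1)

# ndarray in, ndarray out
array_broadcast = same_type_in_and_out(np.broadcast_to, arr, (2, 2))
# numpy scalar in, ndarray out
scalar_broadcast = same_type_in_and_out(np.broadcast_to, scalar, (2, 2))
# ndarray in, numpy scalar out
array_mean = same_type_in_and_out(np.mean, arr)

print(array_broadcast, scalar_broadcast, array_mean)
```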

@keewis
Collaborator Author

keewis commented Sep 1, 2024

@Illviljan, should I try to apply something like this to namedarray, as well? Or do you want to figure out a different way to exclude numpy scalars in namedarray? In any case, feel free to push to this PR if so.

Otherwise I believe this should be ready for a final review?

@@ -311,7 +311,7 @@ def as_compatible_data(
         else:
             data = np.asarray(data)
 
-    if not isinstance(data, np.ndarray) and (
+    if not isinstance(data, np.ndarray | np.generic) and (
Member


This is worth adding a comment, noting that we want to cast numpy scalars to arrays.

Labels
run-upstream Run upstream CI topic-arrays related to flexible array support
Development

Successfully merging this pull request may close these issues.

Variable may contain numpy scalars with numpy>=2.1
5 participants