Skip to content

Commit

Permalink
fix fallback isdtype method (#9250)
Browse files Browse the repository at this point in the history
  • Loading branch information
headtr1ck committed Jul 17, 2024
1 parent 71fce9b commit 5d9d984
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions xarray/core/npcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,18 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from __future__ import annotations

from typing import Any

try:
# requires numpy>=2.0
from numpy import isdtype # type: ignore[attr-defined,unused-ignore]
except ImportError:
import numpy as np
from numpy.typing import DTypeLike

dtype_kinds = {
kind_mapping = {
"bool": np.bool_,
"signed integer": np.signedinteger,
"unsigned integer": np.unsignedinteger,
Expand All @@ -45,16 +49,25 @@
"numeric": np.number,
}

def isdtype(dtype, kind):
def isdtype(
dtype: np.dtype[Any] | type[Any], kind: DTypeLike | tuple[DTypeLike, ...]
) -> bool:
kinds = kind if isinstance(kind, tuple) else (kind,)
str_kinds = {k for k in kinds if isinstance(k, str)}
type_kinds = {k.type for k in kinds if isinstance(k, np.dtype)}

unknown_dtypes = [kind for kind in kinds if kind not in dtype_kinds]
if unknown_dtypes:
raise ValueError(f"unknown dtype kinds: {unknown_dtypes}")
if unknown_kind_types := set(kinds) - str_kinds - type_kinds:
raise TypeError(
f"kind must be str, np.dtype or a tuple of these, got {unknown_kind_types}"
)
if unknown_kinds := {k for k in str_kinds if k not in kind_mapping}:
raise ValueError(
f"unknown kind: {unknown_kinds}, must be a np.dtype or one of {list(kind_mapping)}"
)

# verified the dtypes already, no need to check again
translated_kinds = [dtype_kinds[kind] for kind in kinds]
translated_kinds = {kind_mapping[k] for k in str_kinds} | type_kinds
if isinstance(dtype, np.generic):
return any(isinstance(dtype, kind) for kind in translated_kinds)
return isinstance(dtype, translated_kinds)
else:
return any(np.issubdtype(dtype, kind) for kind in translated_kinds)
return any(np.issubdtype(dtype, k) for k in translated_kinds)

0 comments on commit 5d9d984

Please sign in to comment.