diff --git a/mmengine/structures/instance_data.py b/mmengine/structures/instance_data.py index 36c60fd155..ae77633a5b 100644 --- a/mmengine/structures/instance_data.py +++ b/mmengine/structures/instance_data.py @@ -1,16 +1,28 @@ # Copyright (c) OpenMMLab. All rights reserved. import itertools from collections.abc import Sized -from typing import List, Union +from typing import Any, List, Union import numpy as np import torch +from mmengine.device import get_device from .base_data_element import BaseDataElement -IndexType = Union[str, slice, int, list, torch.LongTensor, - torch.cuda.LongTensor, torch.BoolTensor, - torch.cuda.BoolTensor, np.ndarray] +BoolTypeTensor: Union[Any] +LongTypeTensor: Union[Any] + +if get_device() == 'npu': + BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor, + torch.npu.BoolTensor] + LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor, + torch.npu.LongTensor] +else: + BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor] + LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor] + +IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor, + np.ndarray] # Modified from @@ -165,9 +177,8 @@ def __getitem__(self, item: IndexType) -> 'InstanceData': # More details in https://github.com/numpy/numpy/issues/9464 item = item.astype(np.int64) if item.dtype == np.int32 else item item = torch.from_numpy(item) - assert isinstance( - item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor, - torch.BoolTensor, torch.cuda.BoolTensor)) + assert isinstance(item, (str, slice, int, LongTypeTensor.__args__, + BoolTypeTensor.__args__)) if isinstance(item, str): return getattr(self, item) @@ -183,7 +194,7 @@ def __getitem__(self, item: IndexType) -> 'InstanceData': if isinstance(item, torch.Tensor): assert item.dim() == 1, 'Only support to get the' \ ' values along the first dimension.' - if isinstance(item, (torch.BoolTensor, torch.cuda.BoolTensor)): + if isinstance(item, BoolTypeTensor.__args__): assert len(item) == len(self), 'The shape of the ' \ 'input(BoolTensor) ' \ f'{len(item)} ' \ @@ -202,8 +213,7 @@ def __getitem__(self, item: IndexType) -> 'InstanceData': v, (str, list, tuple)) or (hasattr(v, '__getitem__') and hasattr(v, 'cat')): # convert to indexes from BoolTensor - if isinstance(item, - (torch.BoolTensor, torch.cuda.BoolTensor)): + if isinstance(item, BoolTypeTensor.__args__): indexes = torch.nonzero(item).view( -1).cpu().numpy().tolist() else: