From 1c67f9eb22c5528266c2748d146429c2c7336e19 Mon Sep 17 00:00:00 2001 From: Yinlei Sun Date: Mon, 10 Apr 2023 16:31:31 +0800 Subject: [PATCH] [Enhancement] Support BoolTensor and LongTensor on Ascend NPU (#1011) --- mmengine/structures/instance_data.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/mmengine/structures/instance_data.py b/mmengine/structures/instance_data.py index 36c60fd155..1ceac9ad24 100644 --- a/mmengine/structures/instance_data.py +++ b/mmengine/structures/instance_data.py @@ -1,16 +1,26 @@ # Copyright (c) OpenMMLab. All rights reserved. import itertools from collections.abc import Sized -from typing import List, Union +from typing import Any, List, Union import numpy as np import torch +from mmengine.device import get_device from .base_data_element import BaseDataElement -IndexType = Union[str, slice, int, list, torch.LongTensor, - torch.cuda.LongTensor, torch.BoolTensor, - torch.cuda.BoolTensor, np.ndarray] +BoolTypeTensor: Union[Any] +LongTypeTensor: Union[Any] + +if get_device() == 'npu': + BoolTypeTensor = Union[torch.BoolTensor, torch.npu.BoolTensor] + LongTypeTensor = Union[torch.LongTensor, torch.npu.LongTensor] +else: + BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor] + LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor] + +IndexType: Union[Any] = Union[str, slice, int, list, LongTypeTensor, + BoolTypeTensor, np.ndarray] # Modified from @@ -156,6 +166,7 @@ def __getitem__(self, item: IndexType) -> 'InstanceData': Returns: :obj:`InstanceData`: Corresponding values. """ + assert isinstance(item, IndexType.__args__) if isinstance(item, list): item = np.array(item) if isinstance(item, np.ndarray): @@ -165,9 +176,6 @@ def __getitem__(self, item: IndexType) -> 'InstanceData': # More details in https://github.com/numpy/numpy/issues/9464 item = item.astype(np.int64) if item.dtype == np.int32 else item item = torch.from_numpy(item) - assert isinstance( - item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor, - torch.BoolTensor, torch.cuda.BoolTensor)) if isinstance(item, str): return getattr(self, item) @@ -183,7 +191,7 @@ def __getitem__(self, item: IndexType) -> 'InstanceData': if isinstance(item, torch.Tensor): assert item.dim() == 1, 'Only support to get the' \ ' values along the first dimension.' - if isinstance(item, (torch.BoolTensor, torch.cuda.BoolTensor)): + if isinstance(item, BoolTypeTensor.__args__): assert len(item) == len(self), 'The shape of the ' \ 'input(BoolTensor) ' \ f'{len(item)} ' \ @@ -202,8 +210,7 @@ def __getitem__(self, item: IndexType) -> 'InstanceData': v, (str, list, tuple)) or (hasattr(v, '__getitem__') and hasattr(v, 'cat')): # convert to indexes from BoolTensor - if isinstance(item, - (torch.BoolTensor, torch.cuda.BoolTensor)): + if isinstance(item, BoolTypeTensor.__args__): indexes = torch.nonzero(item).view( -1).cpu().numpy().tolist() else: