Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
momo609 committed Sep 14, 2023
1 parent a01c37c commit 13be022
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions mmcv/ops/csrc/common/pytorch_npu_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,15 @@
REGISTER_DEVICE_IMPL(key, PrivateUse1, value)
#endif

#if MMCV_WITH_XLA
#ifdef MMCV_WITH_XLA
#define CHECK_NPU(x) \
TORCH_CHECK( \
x.device().type() == at::kXLA, #x " must be a NPU tensor")
#if MMCV_WITH_KPRIVATE
x.device().type() == at::kXLA, \
#x " must be a NPU tensor")
#else
#define CHECK_NPU(x)
TORCH_CHECK( \
x.device().type() == at::kPrivateUse1, #x " must be a NPU tensor")
x.device().type() == at::kPrivateUse1, \
#x " must be a NPU tensor")
#endif
#endif // PYTORCH_NPU_HELPER_HPP_

0 comments on commit 13be022

Please sign in to comment.