[Fix] Using PyTorch WARP_SHFL_DOWN macro for half support (#2843)
zstreet87 committed Sep 3, 2023
1 parent 6e9ee26 commit c8a9ae7
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion mmcv/ops/csrc/common/cuda/carafe_cuda_kernel.cuh
@@ -2,6 +2,8 @@
 #ifndef CARAFE_CUDA_KERNEL_CUH
 #define CARAFE_CUDA_KERNEL_CUH
 
+#include <ATen/cuda/DeviceUtils.cuh>
+
 #ifdef MMCV_USE_PARROTS
 #include "parrots_cuda_helper.hpp"
 #else
@@ -56,7 +58,8 @@ template <>
 __device__ __forceinline__ phalf warpReduceSum(phalf val) {
   for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
 #ifdef MMCV_WITH_HIP
-    __PHALF(val) += __shfl_down(val, offset);
+    // Using PyTorch's macro for half support
+    __PHALF(val) += WARP_SHFL_DOWN(val, offset);
 #else
     __PHALF(val) +=
         __shfl_down_sync(FULL_MASK, __PHALF(val).operator __half(), offset);
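For context, here is a minimal standalone sketch of the pattern this commit relies on: a warp-level tree reduction routed through a single shuffle helper so the CUDA and HIP paths share one call site. The helper below only mirrors the shape of PyTorch's WARP_SHFL_DOWN from ATen/cuda/DeviceUtils.cuh; this self-contained copy, the __HIP_PLATFORM_AMD__ guard, and the warp_sums kernel are illustrative assumptions, not the exact upstream or mmcv code.

// Sketch of a warp shuffle helper in the spirit of PyTorch's
// WARP_SHFL_DOWN (ATen/cuda/DeviceUtils.cuh); illustrative only.
#define WARP_SIZE 32
#define FULL_MASK 0xffffffff

template <typename T>
__device__ __forceinline__ T WARP_SHFL_DOWN(T val, unsigned int delta) {
#ifdef __HIP_PLATFORM_AMD__
  // HIP has no *_sync shuffle variants; the mask is implicit.
  return __shfl_down(val, delta);
#else
  return __shfl_down_sync(FULL_MASK, val, delta);
#endif
}

// Tree reduction: after log2(WARP_SIZE) shuffle steps (offsets
// 16, 8, 4, 2, 1), lane 0 holds the sum of all 32 lanes' inputs.
template <typename T>
__device__ __forceinline__ T warpReduceSum(T val) {
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
    val += WARP_SHFL_DOWN(val, offset);
  return val;
}

// Usage sketch: each warp sums its lanes; lane 0 writes the result.
// Assumes blockDim.x is a multiple of WARP_SIZE and `in` has one
// element per launched thread.
__global__ void warp_sums(const float* in, float* out) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  float sum = warpReduceSum(in[tid]);
  if (threadIdx.x % WARP_SIZE == 0) out[tid / WARP_SIZE] = sum;
}

The reduction itself is identical on ROCm; only the shuffle intrinsic differs, which is why dispatching through one helper removes the half-precision special-casing from the kernel body.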
