Skip to content

Commit

Permalink
Fix build error on low CUDA architectures (#44391)
Browse files: browse the repository at this point in the history
  • Loading branch information
RichardWooSJTU committed Jul 18, 2022
1 parent dd0a07f commit 08cada9
Showing 1 changed file with 4 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,12 @@ __global__ void ElementwiseMask(const T* a,
const T* b,
T* res,
int num_elements) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= num_elements) return;
const T zero = 0;
res[tid] = b[tid] >= zero ? a[tid] : zero;
#endif
}

template <typename T>
Expand Down Expand Up @@ -121,6 +123,7 @@ __global__ void ReduceSum2(
template <>
__global__ void ReduceSum2<half>(
const half* src, half* dst, int bsz, int nb_head, int max_seq_len) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int tid = threadIdx.x;
int bid = blockIdx.x;
int num_blocks_per_head = ((max_seq_len / blockDim.x) * max_seq_len);
Expand Down Expand Up @@ -152,6 +155,7 @@ __global__ void ReduceSum2<half>(
static_cast<size_t>(bsz * max_seq_len),
static_cast<platform::float16>(res_half[0]));
}
#endif
}

template <typename T>
Expand Down

0 comments on commit 08cada9

Please sign in to comment.