cherry-pick PaddlePaddle#46767
sneaxiy committed Oct 19, 2022
1 parent 247ef47 commit a053932
Showing 2 changed files with 49 additions and 30 deletions.
73 changes: 45 additions & 28 deletions paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cu
@@ -128,26 +128,27 @@ __device__ __forceinline__ void warp_reduce_upper_tri(T* sum) {
template <typename T, int pow2_index>
__global__ void SoftmaxMaskFuseUpperTriangleGPUKernel(const T* src,
T* dst,
int batch_count,
int key_seq_len) {
int64_t batch_count,
int64_t key_seq_len) {
constexpr int next_pow2 = 1 << pow2_index;
constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4);
constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1;
constexpr int kOneLoadingCounts = 4;
int key_seq_len_pow_2 = key_seq_len * key_seq_len;
int64_t key_seq_len_pow_2 = key_seq_len * key_seq_len;

int first_idx =
(blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * kLocalBatchSize +
int64_t first_idx =
(static_cast<int64_t>(blockDim.y) * blockIdx.y + threadIdx.y) *
gridDim.x * kLocalBatchSize +
blockIdx.x;
int local_block_idx = blockIdx.x + 1;
int warp_iter_upper_bound =
int64_t local_block_idx = blockIdx.x + 1;
int64_t warp_iter_upper_bound =
(local_block_idx + kOneLoadingCounts * warp_size - 1) / warp_size;

int local_batches = batch_count - first_idx;
int64_t local_batches = batch_count - first_idx;
if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize;

int local_idx = threadIdx.x;
int64_t local_idx = threadIdx.x;

src += first_idx * key_seq_len + kOneLoadingCounts * local_idx;
dst += first_idx * key_seq_len + kOneLoadingCounts * local_idx;
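
Note: the hunk above switches the kernel's index variables to int64_t and casts the first factor of the first_idx product, so the whole offset is evaluated in 64-bit arithmetic. The motivation is that first_idx * key_seq_len is added to the src/dst pointers just above, and for large inputs that product no longer fits in a signed 32-bit int. A small stand-alone illustration with hypothetical sizes (not taken from the commit):

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical shape: 16 attention matrices of 16384 x 16384 each.
  const int64_t key_seq_len = 16384;
  const int64_t batch_count = 16 * key_seq_len;    // attn_mul_batch * query_seq_len
  const int64_t first_idx   = batch_count - 1;     // last row handled by some warp

  const int64_t offset = first_idx * key_seq_len;  // pointer offset into src/dst
  std::printf("offset = %lld, INT32_MAX = %lld\n",
              static_cast<long long>(offset),
              static_cast<long long>(INT32_MAX));
  // offset = 4,294,950,912 > INT32_MAX, so the same expression evaluated in
  // 32-bit int (as in the old code) would overflow before the pointer add.
  return 0;
}
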
@@ -157,11 +158,11 @@ __global__ void SoftmaxMaskFuseUpperTriangleGPUKernel(const T* src,

#pragma unroll
for (int i = 0; i < kLocalBatchSize; ++i) {
int batch_total_number = (i >= local_batches) ? 0 : local_block_idx;
auto batch_total_number = (i >= local_batches) ? 0 : local_block_idx;

#pragma unroll
for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
int element_index = kOneLoadingCounts * local_idx + ii * warp_size;
auto element_index = kOneLoadingCounts * local_idx + ii * warp_size;

if (element_index < batch_total_number) {
load_data_upper_tri(temp_in,
@@ -216,7 +217,7 @@ __global__ void SoftmaxMaskFuseUpperTriangleGPUKernel(const T* src,
if (i >= local_batches) break;
#pragma unroll
for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
int element_index = kOneLoadingCounts * local_idx + ii * warp_size;
auto element_index = kOneLoadingCounts * local_idx + ii * warp_size;

if (element_index < local_block_idx) {
#pragma unroll
@@ -242,31 +243,32 @@ template <typename T, int pow2_index>
__global__ void SoftmaxMaskFuseUpperTriangleGradGPUKernel(const T* grad_input,
T* grad_output,
const T* softmax_rst,
int batch_count,
int key_seq_len) {
int64_t batch_count,
int64_t key_seq_len) {
constexpr int next_pow2 = 1 << pow2_index;
constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4);
constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1;
constexpr int kOneLoadingCounts = 4;
int key_seq_len_pow_2 = key_seq_len * key_seq_len;
int64_t key_seq_len_pow_2 = key_seq_len * key_seq_len;

int first_idx =
(blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * kLocalBatchSize +
int64_t first_idx =
(static_cast<int64_t>(blockDim.y) * blockIdx.y + threadIdx.y) *
gridDim.x * kLocalBatchSize +
blockIdx.x;
int local_block_idx = blockIdx.x + 1;
int64_t local_block_idx = blockIdx.x + 1;

// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to be computed within this WARP.
int local_batches = batch_count - first_idx;
int64_t local_batches = batch_count - first_idx;
if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize;

// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
int64_t local_idx = threadIdx.x;

// the first element to process by the current thread
int offset = first_idx * key_seq_len + kOneLoadingCounts * local_idx;
int64_t offset = first_idx * key_seq_len + kOneLoadingCounts * local_idx;
grad_input += offset;
grad_output += offset;
softmax_rst += offset;
@@ -279,11 +281,11 @@ __global__ void SoftmaxMaskFuseUpperTriangleGradGPUKernel(const T* grad_input,

#pragma unroll
for (int i = 0; i < kLocalBatchSize; ++i) {
int batch_total_number = (i >= local_batches) ? 0 : local_block_idx;
auto batch_total_number = (i >= local_batches) ? 0 : local_block_idx;

#pragma unroll
for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
int element_index = kOneLoadingCounts * local_idx + ii * warp_size;
auto element_index = kOneLoadingCounts * local_idx + ii * warp_size;
if (element_index < batch_total_number) {
load_data_upper_tri(
temp_grad_input,
@@ -328,7 +330,7 @@ __global__ void SoftmaxMaskFuseUpperTriangleGradGPUKernel(const T* grad_input,
if (i >= local_batches) break;
#pragma unroll
for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
int element_index = kOneLoadingCounts * local_idx + ii * warp_size;
auto element_index = kOneLoadingCounts * local_idx + ii * warp_size;
if (element_index < key_seq_len) {
// compute gradients
T samples_out[kOneLoadingCounts];
@@ -369,10 +371,10 @@ class SoftmaxMaskFuseUpperTriangleKernel : public framework::OpKernel<T> {
key_seq_len,
query_seq_len));

PADDLE_ENFORCE_EQ(key_seq_len >= 32 && key_seq_len < 8192,
PADDLE_ENFORCE_EQ(key_seq_len >= 32 && key_seq_len <= 16384,
true,
platform::errors::InvalidArgument(
"Input x's last dim must be between [32, 8192) "
"Input x's last dim must be between [32, 16384] "
"received the last dimension of x is %d",
key_seq_len));
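
Note: the accepted range for the last dimension widens from [32, 8192) to [32, 16384]. A single 16384 x 16384 attention matrix still holds fewer than 2^31 elements, so the per-matrix index key_seq_len_pow_2 stays int-safe, but the flattened tensor crosses INT32_MAX once a handful of such matrices are stacked, which is what the int64_t batch_count below guards against. A quick back-of-the-envelope check (illustrative only):

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t key_seq_len = 16384;                     // new maximum last dimension
  const int64_t per_matrix  = key_seq_len * key_seq_len; // 268,435,456 elements, still < 2^31
  const int64_t limit       = int64_t{1} << 31;          // first offset that no longer fits in int32
  std::printf("matrices until 32-bit offsets overflow: %lld\n",
              static_cast<long long>(limit / per_matrix));  // prints 8
  return 0;
}
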

Expand All @@ -381,7 +383,7 @@ class SoftmaxMaskFuseUpperTriangleKernel : public framework::OpKernel<T> {

int pow2_index = get_pow2_index_value(key_seq_len);
const int next_pow2 = 1 << pow2_index;
int batch_count = attn_mul_batch * query_seq_len;
int64_t batch_count = attn_mul_batch * query_seq_len;
int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
int batches_per_warp = (next_pow2 <= 128) ? 2 : 1;
constexpr int threads_per_block = 128;
@@ -448,7 +450,13 @@ class SoftmaxMaskFuseUpperTriangleKernel : public framework::OpKernel<T> {
<<<blocks, threads, 0, stream>>>(
x_data, y_data, batch_count, key_seq_len);
break;
case 14: // 16384
SoftmaxMaskFuseUpperTriangleGPUKernel<T, 14>
<<<blocks, threads, 0, stream>>>(
x_data, y_data, batch_count, key_seq_len);
break;
default:
PADDLE_THROW(phi::errors::Unimplemented("Too large sequence length."));
break;
}
}
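
Note: the new case 14 gives sequence lengths up to 16384 (2^14) a concrete template instantiation instead of falling through to the Unimplemented error. A hedged sketch of how key_seq_len is presumably mapped to the pow2_index driving this switch (get_pow2_index_value's actual body is not part of this diff, so this is assumed behaviour only):

// Assumed behaviour, not the library's code: smallest e with (1 << e) >= n.
int pow2_index_sketch(int64_t n) {
  int e = 5;                                    // the supported range starts at 2^5 = 32
  while ((int64_t{1} << e) < n) ++e;
  return e;                                     // 8192 -> 13, 10000 -> 14, 16384 -> 14
}
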
@@ -478,7 +486,7 @@ class SoftmaxMaskFuseUpperTriangleGradKernel : public framework::OpKernel<T> {

int pow2_index = get_pow2_index_value(key_seq_len);
const int next_pow2 = 1 << pow2_index;
int batch_count = attn_mul_batch * query_seq_len;
int64_t batch_count = attn_mul_batch * query_seq_len;
int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
int batches_per_warp = (next_pow2 <= 128) ? 2 : 1;
// use 128 threads per block to maximize GPU utilization
@@ -564,7 +572,16 @@ class SoftmaxMaskFuseUpperTriangleGradKernel : public framework::OpKernel<T> {
batch_count,
key_seq_len);
break;
case 14:
SoftmaxMaskFuseUpperTriangleGradGPUKernel<T, 14>
<<<blocks, threads, 0, stream>>>(grad_y_data,
grad_x_data,
softmax_rst_data,
batch_count,
key_seq_len);
break;
default:
PADDLE_THROW(phi::errors::Unimplemented("Too large sequence length."));
break;
}
}
6 changes: 4 additions & 2 deletions paddle/phi/kernels/funcs/elementwise_base.h
100755 → 100644
@@ -760,8 +760,10 @@ __global__ void VectorizedElementwiseKernel(
kps::IndexType main_offset,
int read_lens,
Functor func) {
kps::IndexType data_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens;
kps::IndexType stride = BLOCK_NUM_X * GRID_NUM_X * read_lens;
kps::IndexType data_offset =
static_cast<kps::IndexType>(BLOCK_ID_X) * BLOCK_NUM_X * read_lens;
kps::IndexType stride =
static_cast<kps::IndexType>(BLOCK_NUM_X) * GRID_NUM_X * read_lens;
for (; data_offset < main_offset; data_offset += stride) {
VectorizedElementwiseKernelImpl<OutT,
Functor,

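Note: the second file applies the same fix to the generic vectorized elementwise launcher. BLOCK_ID_X, BLOCK_NUM_X, GRID_NUM_X and read_lens are 32-bit quantities, so without the cast their product is evaluated in int and can wrap before it is stored into the 64-bit kps::IndexType; casting the leading factor promotes the whole expression. A minimal stand-alone illustration with made-up launch numbers:

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical launch: block 70,000 of the grid, 256 threads per block, 128 elements per thread.
  const int block_id = 70000, block_num = 256, read_lens = 128;
  const int64_t offset = static_cast<int64_t>(block_id) * block_num * read_lens;
  std::printf("offset = %lld, fits in int32? %s\n",
              static_cast<long long>(offset),
              offset <= INT32_MAX ? "yes" : "no");  // offset = 2,293,760,000 -> "no"
  // Without the static_cast the product would be computed in 32-bit int and
  // overflow before being widened, which is exactly what the diff prevents.
  return 0;
}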