Fix some operators when the tensor.numel() > INT32_MAX #46767

Merged · 2 commits · Oct 12, 2022
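This PR widens index arithmetic in several CUDA kernels from int to int64_t so they stay correct when a tensor's element count exceeds INT32_MAX (2,147,483,647). As a minimal, hedged illustration (the shapes and names below are made up for this note, not taken from the PR), the host-side sketch shows how the flattened element count of a large attention-score tensor no longer fits in a 32-bit index:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical attention-score shape [batch, heads, seq, seq];
  // illustrative values only.
  int64_t batch = 16, heads = 16, seq = 4096;
  int64_t numel = batch * heads * seq * seq;  // 4,294,967,296 > INT32_MAX

  printf("numel = %lld, INT32_MAX = %d\n",
         static_cast<long long>(numel), INT32_MAX);

  // A linear offset near the end of the tensor no longer fits in int:
  int64_t last = numel - 1;
  int wrapped = static_cast<int>(last);  // narrows; typically wraps negative
  printf("64-bit offset %lld vs 32-bit offset %d\n",
         static_cast<long long>(last), wrapped);
  return 0;
}
```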
73 changes: 45 additions & 28 deletions paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cu
@@ -127,26 +127,27 @@ __device__ __forceinline__ void warp_reduce_upper_tri(T* sum) {
template <typename T, int pow2_index>
__global__ void SoftmaxMaskFuseUpperTriangleGPUKernel(const T* src,
T* dst,
- int batch_count,
- int key_seq_len) {
+ int64_t batch_count,
+ int64_t key_seq_len) {
constexpr int next_pow2 = 1 << pow2_index;
constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4);
constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1;
constexpr int kOneLoadingCounts = 4;
- int key_seq_len_pow_2 = key_seq_len * key_seq_len;
+ int64_t key_seq_len_pow_2 = key_seq_len * key_seq_len;

- int first_idx =
- (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * kLocalBatchSize +
+ int64_t first_idx =
+ (static_cast<int64_t>(blockDim.y) * blockIdx.y + threadIdx.y) *
+ gridDim.x * kLocalBatchSize +
blockIdx.x;
- int local_block_idx = blockIdx.x + 1;
- int warp_iter_upper_bound =
+ int64_t local_block_idx = blockIdx.x + 1;
+ int64_t warp_iter_upper_bound =
(local_block_idx + kOneLoadingCounts * warp_size - 1) / warp_size;

- int local_batches = batch_count - first_idx;
+ int64_t local_batches = batch_count - first_idx;
if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize;

- int local_idx = threadIdx.x;
+ int64_t local_idx = threadIdx.x;

src += first_idx * key_seq_len + kOneLoadingCounts * local_idx;
dst += first_idx * key_seq_len + kOneLoadingCounts * local_idx;
@@ -156,11 +157,11 @@ __global__ void SoftmaxMaskFuseUpperTriangleGPUKernel(const T* src,

#pragma unroll
for (int i = 0; i < kLocalBatchSize; ++i) {
- int batch_total_number = (i >= local_batches) ? 0 : local_block_idx;
+ auto batch_total_number = (i >= local_batches) ? 0 : local_block_idx;

#pragma unroll
for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
- int element_index = kOneLoadingCounts * local_idx + ii * warp_size;
+ auto element_index = kOneLoadingCounts * local_idx + ii * warp_size;

if (element_index < batch_total_number) {
load_data_upper_tri(temp_in,
@@ -215,7 +216,7 @@ __global__ void SoftmaxMaskFuseUpperTriangleGPUKernel(const T* src,
if (i >= local_batches) break;
#pragma unroll
for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
- int element_index = kOneLoadingCounts * local_idx + ii * warp_size;
+ auto element_index = kOneLoadingCounts * local_idx + ii * warp_size;

if (element_index < local_block_idx) {
#pragma unroll
@@ -241,31 +242,32 @@ template <typename T, int pow2_index>
__global__ void SoftmaxMaskFuseUpperTriangleGradGPUKernel(const T* grad_input,
T* grad_output,
const T* softmax_rst,
- int batch_count,
- int key_seq_len) {
+ int64_t batch_count,
+ int64_t key_seq_len) {
constexpr int next_pow2 = 1 << pow2_index;
constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4);
constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1;
constexpr int kOneLoadingCounts = 4;
- int key_seq_len_pow_2 = key_seq_len * key_seq_len;
+ int64_t key_seq_len_pow_2 = key_seq_len * key_seq_len;

- int first_idx =
- (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * kLocalBatchSize +
+ int64_t first_idx =
+ (static_cast<int64_t>(blockDim.y) * blockIdx.y + threadIdx.y) *
+ gridDim.x * kLocalBatchSize +
blockIdx.x;
- int local_block_idx = blockIdx.x + 1;
+ int64_t local_block_idx = blockIdx.x + 1;

// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
- int local_batches = batch_count - first_idx;
+ int64_t local_batches = batch_count - first_idx;
if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize;

// there might be multiple batches per warp. compute the index within the
// batch
- int local_idx = threadIdx.x;
+ int64_t local_idx = threadIdx.x;

// the first element to process by the current thread
- int offset = first_idx * key_seq_len + kOneLoadingCounts * local_idx;
+ int64_t offset = first_idx * key_seq_len + kOneLoadingCounts * local_idx;
grad_input += offset;
grad_output += offset;
softmax_rst += offset;
@@ -278,11 +280,11 @@ __global__ void SoftmaxMaskFuseUpperTriangleGradGPUKernel(const T* grad_input,

#pragma unroll
for (int i = 0; i < kLocalBatchSize; ++i) {
- int batch_total_number = (i >= local_batches) ? 0 : local_block_idx;
+ auto batch_total_number = (i >= local_batches) ? 0 : local_block_idx;

#pragma unroll
for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
- int element_index = kOneLoadingCounts * local_idx + ii * warp_size;
+ auto element_index = kOneLoadingCounts * local_idx + ii * warp_size;
if (element_index < batch_total_number) {
load_data_upper_tri(
temp_grad_input,
@@ -327,7 +329,7 @@ __global__ void SoftmaxMaskFuseUpperTriangleGradGPUKernel(const T* grad_input,
if (i >= local_batches) break;
#pragma unroll
for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
- int element_index = kOneLoadingCounts * local_idx + ii * warp_size;
+ auto element_index = kOneLoadingCounts * local_idx + ii * warp_size;
if (element_index < key_seq_len) {
// compute gradients
T samples_out[kOneLoadingCounts];
@@ -368,10 +370,10 @@ class SoftmaxMaskFuseUpperTriangleKernel : public framework::OpKernel<T> {
key_seq_len,
query_seq_len));

- PADDLE_ENFORCE_EQ(key_seq_len >= 32 && key_seq_len < 8192,
+ PADDLE_ENFORCE_EQ(key_seq_len >= 32 && key_seq_len <= 16384,
true,
platform::errors::InvalidArgument(
- "Input x's last dim must be between [32, 8192) "
+ "Input x's last dim must be between [32, 16384] "
"received the last dimension of x is %d",
key_seq_len));

Expand All @@ -380,7 +382,7 @@ class SoftmaxMaskFuseUpperTriangleKernel : public framework::OpKernel<T> {

int pow2_index = get_pow2_index_value(key_seq_len);
const int next_pow2 = 1 << pow2_index;
- int batch_count = attn_mul_batch * query_seq_len;
+ int64_t batch_count = attn_mul_batch * query_seq_len;
int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
int batches_per_warp = (next_pow2 <= 128) ? 2 : 1;
constexpr int threads_per_block = 128;
@@ -447,7 +449,13 @@ class SoftmaxMaskFuseUpperTriangleKernel : public framework::OpKernel<T> {
<<<blocks, threads, 0, stream>>>(
x_data, y_data, batch_count, key_seq_len);
break;
+ case 14:  // 16384
+ SoftmaxMaskFuseUpperTriangleGPUKernel<T, 14>
+ <<<blocks, threads, 0, stream>>>(
+ x_data, y_data, batch_count, key_seq_len);
+ break;
default:
+ PADDLE_THROW(phi::errors::Unimplemented("Too large sequence length."));
break;
}
}
@@ -479,7 +487,7 @@ class SoftmaxMaskFuseUpperTriangleGradKernel : public framework::OpKernel<T> {

int pow2_index = get_pow2_index_value(key_seq_len);
const int next_pow2 = 1 << pow2_index;
- int batch_count = attn_mul_batch * query_seq_len;
+ int64_t batch_count = attn_mul_batch * query_seq_len;
int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
int batches_per_warp = (next_pow2 <= 128) ? 2 : 1;
// use 128 threads per block to maximum gpu utilization
@@ -565,7 +573,16 @@ class SoftmaxMaskFuseUpperTriangleGradKernel : public framework::OpKernel<T> {
batch_count,
key_seq_len);
break;
+ case 14:
+ SoftmaxMaskFuseUpperTriangleGradGPUKernel<T, 14>
+ <<<blocks, threads, 0, stream>>>(grad_y_data,
+ grad_x_data,
+ softmax_rst_data,
+ batch_count,
+ key_seq_len);
+ break;
default:
+ PADDLE_THROW(phi::errors::Unimplemented("Too large sequence length."));
break;
}
}
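The hunks above follow one pattern: every index quantity that can scale with batch_count * key_seq_len (first_idx, offset, local_batches, local_block_idx, ...) is now computed in int64_t, and the pow2_index dispatch gains a case 14 so key_seq_len = 16384 (2^14) is accepted instead of throwing. A minimal CUDA sketch of the 64-bit offset pattern, using a hypothetical kernel rather than the operator's actual code, is:

```cuda
#include <cstdint>

// Hypothetical kernel illustrating the indexing pattern used above:
// one warp row per (batch, query) pair, with the row offset computed in
// 64-bit so batch_count * key_seq_len may exceed INT32_MAX.
__global__ void scale_rows(const float* src, float* dst,
                           int64_t batch_count, int64_t key_seq_len) {
  // Cast the first operand so the whole expression is evaluated in 64-bit,
  // mirroring static_cast<int64_t>(blockDim.y) in the diff.
  int64_t row = static_cast<int64_t>(blockDim.y) * blockIdx.y + threadIdx.y;
  if (row >= batch_count) return;

  int64_t row_offset = row * key_seq_len;  // may not fit in 32 bits
  for (int64_t col = threadIdx.x; col < key_seq_len; col += blockDim.x) {
    dst[row_offset + col] = src[row_offset + col] * 0.5f;
  }
}
```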
6 changes: 4 additions & 2 deletions paddle/phi/kernels/funcs/elementwise_base.h
File mode changed: 100755 → 100644
@@ -760,8 +760,10 @@ __global__ void VectorizedElementwiseKernel(
kps::IndexType main_offset,
int read_lens,
Functor func) {
- kps::IndexType data_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens;
- kps::IndexType stride = BLOCK_NUM_X * GRID_NUM_X * read_lens;
+ kps::IndexType data_offset =
+ static_cast<kps::IndexType>(BLOCK_ID_X) * BLOCK_NUM_X * read_lens;
+ kps::IndexType stride =
+ static_cast<kps::IndexType>(BLOCK_NUM_X) * GRID_NUM_X * read_lens;
for (; data_offset < main_offset; data_offset += stride) {
VectorizedElementwiseKernelImpl<OutT,
Functor,
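The elementwise_base.h change addresses the same class of bug in a subtler form: data_offset and stride are declared as kps::IndexType (a wider type for large tensors), but the right-hand side BLOCK_ID_X * BLOCK_NUM_X * read_lens was evaluated in 32-bit int first, so it could wrap before the assignment widened it. The self-contained sketch below (made-up launch numbers, plain int64_t standing in for kps::IndexType) shows why the destination type alone does not help and why casting the first operand does:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Made-up launch configuration, large enough that the product exceeds INT32_MAX.
  int block_id = 2200000;   // stands in for BLOCK_ID_X
  int block_num = 256;      // stands in for BLOCK_NUM_X
  int read_lens = 4;

  // Wrong: the right-hand side is evaluated entirely in int; 2,252,800,000
  // overflows (formally undefined behavior) before being widened by the assignment.
  int64_t data_offset_wrapped = block_id * block_num * read_lens;

  // Right: cast the first operand, so every multiplication is done in 64-bit,
  // which is what static_cast<kps::IndexType>(BLOCK_ID_X) achieves in the diff.
  int64_t data_offset = static_cast<int64_t>(block_id) * block_num * read_lens;

  printf("wrapped: %lld\n", static_cast<long long>(data_offset_wrapped));
  printf("correct: %lld\n", static_cast<long long>(data_offset));
  return 0;
}
```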