diff --git a/paddle/fluid/operators/fused/fused_dropout_helper.h b/paddle/fluid/operators/fused/fused_dropout_helper.h
index 4aa8b65635e7a..782c5d70ee077 100644
--- a/paddle/fluid/operators/fused/fused_dropout_helper.h
+++ b/paddle/fluid/operators/fused/fused_dropout_helper.h
@@ -284,11 +284,30 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
                                         P* d_layernorm_bias, T* d_dropout_src,
                                         T* d_bias, T* d_residual) {
     using U = LayerNormParamType<T>;
-    LayerNormBackward<T, U, is_same_type>(
-        layernorm_src, d_out, gamma, mean, variance, d_layernorm_src, d_scale,
-        d_layernorm_bias, epsilon_, this->rows_, this->cols_, ctx);
-    this->ResidualDropoutBiasGrad(ctx, d_layernorm_src, mask, d_dropout_src,
-                                  d_residual, d_bias);
+    bool can_call_1024_kernel = false;
+    // Fast path for the case where cols is 1024 and linear_bias is nullptr.
+    // A non-null linear_bias would also be feasible for this implementation,
+    // but it is not supported here.
+    if (this->cols_ == 1024 && d_bias == nullptr && d_scale != nullptr &&
+        d_layernorm_bias != nullptr && sizeof(T) <= 4) {
+      can_call_1024_kernel = true;
+    }
+    VLOG(6) << "LaunchLayernormResidualDropoutGrad = " << can_call_1024_kernel;
+
+    if (can_call_1024_kernel) {
+      LaunchLayernormResidualDropoutGrad<T, U, MaskType, is_same_type>(
+          ctx, this->rows_, this->cols_, epsilon_,
+          this->dropout_param_.dropout_prob,
+          this->dropout_param_.is_upscale_in_train, d_out, layernorm_src, gamma,
+          mean, variance, mask, d_scale, d_layernorm_bias, d_residual,
+          d_dropout_src);
+    } else {
+      LayerNormBackward<T, U, is_same_type>(
+          layernorm_src, d_out, gamma, mean, variance, d_layernorm_src, d_scale,
+          d_layernorm_bias, epsilon_, this->rows_, this->cols_, ctx);
+      this->ResidualDropoutBiasGrad(ctx, d_layernorm_src, mask, d_dropout_src,
+                                    d_residual, d_bias);
+    }
   }
 
 protected:
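As a quick reference (not part of the patch), a minimal standalone sketch of the dispatch condition added above; the helper name is hypothetical:

// Hypothetical helper mirroring the dispatch condition in the hunk above: the
// fast path requires cols == 1024, no linear-bias gradient, non-null scale and
// bias gradients, and an element type of at most 4 bytes (float or half).
template <typename T>
inline bool CanUseFastLnBwd1024(int cols, const void* d_bias,
                                const void* d_scale,
                                const void* d_layernorm_bias) {
  return cols == 1024 && d_bias == nullptr && d_scale != nullptr &&
         d_layernorm_bias != nullptr && sizeof(T) <= 4;
}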
diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
index babf1c657f232..911c2cda57504 100644
--- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
+++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
@@ -441,5 +441,30 @@ void LaunchLayernormResidualDropoutBias(
   }
 }
 
+template <typename T, typename U, typename MaskType = uint8_t,
+          bool ScaleBiasWithSameTypeX = false>
+void LaunchLayernormResidualDropoutGrad(
+    const platform::CUDADeviceContext &dev_ctx, const uint32_t rows,
+    const uint32_t cols, const float epsilon, const float dropout_prob,
+    const bool is_upscale_in_train, const T *d_out, const T *layernorm_src,
+    const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
+    const LayerNormParamType<T> *mean, const LayerNormParamType<T> *var,
+    const MaskType *mask_data,
+    LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale,
+    LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_layernorm_bias,
+    T *d_residual, T *d_dropout_src) {
+  const T zero = static_cast<T>(0.0f);
+  auto factor = dropout_prob == static_cast<float>(1.0f)
+                    ? zero
+                    : static_cast<T>(1.0f / (1.0f - dropout_prob));
+  if (!is_upscale_in_train) {
+    factor = static_cast<T>(1.0f);
+  }
+  ln_bwd_1024_kernel_driver<
+      T, U, LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>, MaskType>(
+      dev_ctx, rows, cols, epsilon, layernorm_src, scale, mean, var, d_out,
+      d_residual, d_scale, d_layernorm_bias, mask_data, factor, d_dropout_src);
+}
+
 }  // namespace operators
 }  // namespace paddle
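As a reading aid for the kernels added below (not part of the patch), the quantities they compute are the standard layer norm backward formulas. With r_i = 1/sqrt(sigma_i^2 + epsilon), \hat{x}_{ij} = (x_{ij} - mu_i) r_i and C = 1024, the code's sum_loss1, sum_loss2, dgamma_sum and dbeta_sum correspond to s_1, s_2, d\gamma and d\beta:

\[ d\gamma_j = \sum_i dy_{ij}\,\hat{x}_{ij}, \qquad d\beta_j = \sum_i dy_{ij}, \]
\[ s_1^{(i)} = \frac{1}{C}\sum_j \gamma_j\,dy_{ij}, \qquad s_2^{(i)} = \frac{1}{C}\sum_j \gamma_j\,dy_{ij}\,\hat{x}_{ij}, \]
\[ dx_{ij} = r_i\,\bigl(\gamma_j\,dy_{ij} - \hat{x}_{ij}\,s_2^{(i)} - s_1^{(i)}\bigr). \]

For the fused dropout + residual case, d_dropout_src_{ij} = dx_{ij} * mask_{ij} * factor, where factor = 1/(1 - dropout_prob) when is_upscale_in_train and 1 otherwise.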
diff --git a/paddle/fluid/operators/layer_norm_kernel.cu.h b/paddle/fluid/operators/layer_norm_kernel.cu.h
index bc00d875cd1dd..da4932543a0e6 100644
--- a/paddle/fluid/operators/layer_norm_kernel.cu.h
+++ b/paddle/fluid/operators/layer_norm_kernel.cu.h
@@ -385,6 +385,471 @@ __inline__ __device__ void cuLoadAddStridedInputs(
   }
 }
 
+#ifdef PADDLE_WITH_CUDA
+template <
+    bool isFusedDropoutResidualLn, typename T, typename U, typename ScaleT = U,
+    typename MaskType = uint8_t, int VecSize = 8, int WARPS_M = 4,
+    int WARPS_N = 1, int BYTES_PER_LDG = 16, int ELTS_PER_ROW = 1024,
+    int THREADS_PER_WARP = 32, int THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP,
+    int THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW, int ROWS_PER_CTA = WARPS_M,
+    int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW * VecSize,
+    int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA>
+__global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
+    const int rows, float epsilon, const T *__restrict__ x_ptr,
+    const ScaleT *__restrict__ gamma_ptr, const U *__restrict__ mean_ptr,
+    const U *__restrict__ var_ptr, const T *__restrict__ dout_ptr,
+    U *__restrict__ dgamma_temp_ptr, U *__restrict__ dbeta_temp_ptr,
+    T *__restrict__ dx_ptr, const MaskType *mask_ptr = nullptr,
+    T factor = static_cast<T>(0), T *d_dropout_src_ptr = nullptr) {
+  using Vec = platform::AlignedVector<T, VecSize>;
+  using Vec_scale = platform::AlignedVector<ScaleT, VecSize>;
+  using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
+
+  const int tidx = threadIdx.x;
+  const int bidx = blockIdx.x;
+  const int lane = tidx % THREADS_PER_WARP;  // 0, 1, ..., 31
+  const int warp = tidx / THREADS_PER_WARP;  // 0, 1, 2, 3
+  const int warp_m = warp / WARPS_N;         // 0, 1, 2, 3
+  const int warp_n = warp % WARPS_N;         // 0
+  const int tid_r = warp_n * THREADS_PER_WARP + lane;  // 0, 1, ..., 31
+
+  const int r = bidx * ROWS_PER_CTA + warp_m;
+  const int c = warp_n * THREADS_PER_WARP + lane;
+
+  static_assert(LN_NUM_COLS == THREADS_PER_ROW * LDGS * VecSize, "");
+
+  // smem for the column reduction
+  __shared__ U smem_[ROWS_PER_CTA * LN_NUM_COLS];
+
+  U dgamma_sum[LDGS * VecSize];
+  U dbeta_sum[LDGS * VecSize];
+
+  memset(dgamma_sum, 0, sizeof(U) * LDGS * VecSize);
+  memset(dbeta_sum, 0, sizeof(U) * LDGS * VecSize);
+
+  // Note: these buffers are not used when WARPS_N == 1.
+  __shared__ U smem_sum_loss1[ROWS_PER_CTA * WARPS_N];  // 4
+  __shared__ U smem_sum_loss2[ROWS_PER_CTA * WARPS_N];  // 4
+  U *sum_loss1_shared = &smem_sum_loss1[warp_m * WARPS_N];
+  U *sum_loss2_shared = &smem_sum_loss2[warp_m * WARPS_N];
+
+  // step-1: compute dx and the local partial results of dscale and dbias
+  constexpr float rn = 1.f / static_cast<float>(LN_NUM_COLS);
+  Vec_scale gamma[LDGS];
+  int col = c;
+#pragma unroll
+  for (int it = 0; it < LDGS; it++) {
+    platform::Load(gamma_ptr + col * VecSize, &gamma[it]);
+    col += THREADS_PER_ROW;
+  }
+
+#pragma unroll 1
+  for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {
+    const U mean_cur_row = mean_ptr[row];
+    const U var_cur_row = rsqrt_(var_ptr[row] + epsilon);
+    Vec dout[LDGS], x[LDGS];
+    MaskLoadT mask_vec[LDGS];
+    int col = c;
+#pragma unroll
+    for (int it = 0; it < LDGS; it++) {
+      platform::Load(dout_ptr + row * LN_NUM_COLS + col * VecSize,
+                     &dout[it]);
+      platform::Load(x_ptr + row * LN_NUM_COLS + col * VecSize, &x[it]);
+      if (isFusedDropoutResidualLn) {
+        platform::Load(mask_ptr + row * LN_NUM_COLS + col * VecSize,
+                       &mask_vec[it]);
+      }
+
+      col += THREADS_PER_ROW;
+    }
+
+    // local reductions
+    U dy[LDGS * VecSize];
+    U y[LDGS * VecSize];
+
+    U sum_loss1 = 0.f;
+    U sum_loss2 = 0.f;
+#pragma unroll
+    for (int it = 0; it < LDGS; it++) {
+#pragma unroll
+      for (int jt = 0; jt < VecSize; jt++) {
+        U x_tmp = x[it][jt];
+        U y_tmp = var_cur_row * (x_tmp - mean_cur_row);
+        U dy_tmp = static_cast<U>(gamma[it][jt]) *
+                   static_cast<U>(dout[it][jt]);  // scale * dy
+        U dout_tmp = dout[it][jt];                // dy
+
+        // used to get dx (row reduction)
+        sum_loss1 += dy_tmp;          // scale * dy, sum_1
+        sum_loss2 += dy_tmp * y_tmp;  // scale * dy * y, sum_2
+
+        dy[it * VecSize + jt] = dy_tmp;  // scale * dy
+        y[it * VecSize + jt] = y_tmp;    // y
+
+        // used to get dscale and dbias (column reduction)
+        dgamma_sum[it * VecSize + jt] += dout_tmp * y_tmp;  // dy * y
+        dbeta_sum[it * VecSize + jt] += dout_tmp;           // dy
+      }
+    }
+
+    // reduction across the row for sum_loss1, sum_loss2
+    if (WARPS_N == 1) {
+#pragma unroll
+      // row reduction among 32 threads.
+      for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
+        sum_loss1 += __shfl_xor_sync(uint32_t(-1), sum_loss1, it);
+        sum_loss2 += __shfl_xor_sync(uint32_t(-1), sum_loss2, it);
+      }
+      sum_loss1 *= rn;
+      sum_loss2 *= rn;
+    } else {
+#pragma unroll
+      for (int it = 16; it > 0; it /= 2) {
+        sum_loss1 += __shfl_down_sync(uint32_t(-1), sum_loss1, it);
+        sum_loss2 += __shfl_down_sync(uint32_t(-1), sum_loss2, it);
+      }
+
+      if (lane == 0) {
+        sum_loss1_shared[warp_n] = sum_loss1;
+        sum_loss2_shared[warp_n] = sum_loss2;
+      }
+
+      __syncthreads();
+      if (warp_n == 0 && lane == 0) {
+        sum_loss1 = 0.f;
+        sum_loss2 = 0.f;
+        for (int it = 0; it < WARPS_N; it++) {
+          sum_loss1 += sum_loss1_shared[it];
+          sum_loss2 += sum_loss2_shared[it];
+        }
+        sum_loss1_shared[0] = sum_loss1;
+        sum_loss2_shared[0] = sum_loss2;
+      }
+      __syncthreads();
+
+      sum_loss1 = sum_loss1_shared[0] * rn;
+      sum_loss2 = sum_loss2_shared[0] * rn;
+    }
+
+#pragma unroll
+    for (int it = 0; it < LDGS; it++) {
+#pragma unroll
+      for (int jt = 0; jt < VecSize; jt++) {
+        U dy_tmp = dy[it * VecSize + jt];  // scale * dy
+        U y_tmp = y[it * VecSize + jt];    // y
+        // dx = var * (scale * dy - sum_loss2 * y - sum_loss1)
+        U dx_tmp = var_cur_row * (dy_tmp - sum_loss2 * y_tmp - sum_loss1);
+        // Note: reuse the x and dout vec registers to store dx and
+        // d_dropout_src.
+        x[it][jt] = static_cast<T>(dx_tmp);
+        if (isFusedDropoutResidualLn) {
+          dout[it][jt] = x[it][jt] * static_cast<T>(mask_vec[it][jt]) * factor;
+        }
+      }
+    }
+
+    // store dx to global memory
+    col = c;
+#pragma unroll
+    for (int it = 0; it < LDGS; it++) {
+      platform::Store(x[it], dx_ptr + row * LN_NUM_COLS + col * VecSize);
+      if (isFusedDropoutResidualLn) {
+        platform::Store(dout[it],
+                        d_dropout_src_ptr + row * LN_NUM_COLS + col * VecSize);
+      }
+      col += THREADS_PER_ROW;
+    }
+  }
+
+  // step-2: column reduction of dscale and dbias for each thread block.
+  // each block's sum: [4 * 1024] -> [1 * 1024]
+  enum { NUM_RES = LN_NUM_COLS / THREADS_PER_CTA };  // 1024 / 128 = 8
+  static_assert(NUM_RES * THREADS_PER_CTA == LN_NUM_COLS, "");
+
+  U *smem_write;
+
+  smem_write = &smem_[warp_m * LN_NUM_COLS + tid_r * VecSize];  // [4 * 1024]
+#pragma unroll
+  for (int it = 0; it < LDGS; it++) {
+#pragma unroll
+    for (int jt = 0; jt < VecSize; jt++) {
+      smem_write[jt] = dbeta_sum[it * VecSize + jt];
+    }
+    smem_write += THREADS_PER_ROW * VecSize;  // 32 * 8
+  }
+  __syncthreads();
+  U cta_dbeta_sum[NUM_RES];
+  memset(cta_dbeta_sum, 0, sizeof(U) * NUM_RES);
+  // column reduction for the elems in smem: 4 * 1024 -> 1 * 1024.
+  for (int it = 0; it < ROWS_PER_CTA; it++) {
+    for (int jt = 0; jt < NUM_RES; jt++) {
+      cta_dbeta_sum[jt] +=
+          smem_[it * LN_NUM_COLS + tidx + jt * THREADS_PER_CTA];
+    }
+  }
+  __syncthreads();
+
+  smem_write = &smem_[warp_m * LN_NUM_COLS + tid_r * VecSize];
+#pragma unroll
+  for (int it = 0; it < LDGS; it++) {
+#pragma unroll
+    for (int jt = 0; jt < VecSize; jt++) {
+      smem_write[jt] = dgamma_sum[it * VecSize + jt];
+    }
+    smem_write += THREADS_PER_ROW * VecSize;
+  }
+  __syncthreads();
+  U cta_dgamma_sum[NUM_RES];
+  memset(cta_dgamma_sum, 0, sizeof(U) * NUM_RES);
+  for (int it = 0; it < ROWS_PER_CTA; it++) {
+    for (int jt = 0; jt < NUM_RES; jt++) {
+      cta_dgamma_sum[jt] +=
+          smem_[it * LN_NUM_COLS + tidx + jt * THREADS_PER_CTA];
+    }
+  }
+
+  // the shape of the results: (#blocks, 1024)
+  U *dgamma_part =
+      static_cast<U *>(dgamma_temp_ptr) + bidx * LN_NUM_COLS + tidx;
+  for (int jt = 0; jt < NUM_RES; jt++) {
+    *dgamma_part = cta_dgamma_sum[jt];
+    dgamma_part += THREADS_PER_CTA;
+  }
+
+  U *dbeta_part = static_cast<U *>(dbeta_temp_ptr) + bidx * LN_NUM_COLS + tidx;
+  for (int jt = 0; jt < NUM_RES; jt++) {
+    *dbeta_part = cta_dbeta_sum[jt];
+    dbeta_part += THREADS_PER_CTA;
+  }
+}
+
+/* This function carries out a column reduction whose input is [rows, 1024]
+ * and whose output is [1, 1024].
+ * #blocks: 32
+ * #threads: 512
+*/
+// TODO(@limin29): think about whether there are better implementation
+// strategies.
+template <
+    typename U, typename ScaleT = U, int VecSize = 1, int WARPS_M = 16,
+    int WARPS_N = 1, int BYTES_PER_LDG = 4, int ELTS_PER_ROW = 1024,
+    int THREADS_PER_WARP = 32, int THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP,
+    int THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW, int ROWS_PER_CTA = WARPS_M,
+    int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW * VecSize,
+    int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA,
+    int VEC_COLS = ELTS_PER_ROW / VecSize>
+__global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
+    const int rows, U *__restrict__ dg_part_, U *__restrict__ db_part_,
+    ScaleT *__restrict__ dg_, ScaleT *__restrict__ db_) {
+  using Vec = platform::AlignedVector<U, VecSize>;
+  static_assert(VEC_COLS == LN_NUM_COLS / VecSize, "");
+
+  const int tidx = threadIdx.x;
+  const int bidx = blockIdx.x;
+  const int lane = tidx % THREADS_PER_WARP;
+  const int warp = tidx / THREADS_PER_WARP;
+  const int warp_m = warp / WARPS_N;
+  const int warp_n = warp % WARPS_N;
+  const int tid_c = warp_n * THREADS_PER_WARP + lane;
+
+  const int c = bidx * THREADS_PER_ROW + tid_c;
+  const int r = warp_m;
+
+  __shared__ U smem_space[(WARPS_M - 1) * THREADS_PER_ROW * VecSize];
+
+  for (int col = c; col < VEC_COLS; col += gridDim.x * THREADS_PER_ROW) {
+    const U *dg_part_ptr = (dg_part_) + r * LN_NUM_COLS + col * VecSize;
+    const U *db_part_ptr = (db_part_) + r * LN_NUM_COLS + col * VecSize;
+
+    U dg_sum[VecSize];
+    U db_sum[VecSize];
+    memset(dg_sum, 0, sizeof(U) * VecSize);
+    memset(db_sum, 0, sizeof(U) * VecSize);
+#pragma unroll
+    for (int row = r; row < rows; row += ROWS_PER_CTA) {
+      Vec dg;
+      Vec db;
+      platform::Load(dg_part_ptr, &dg);
+      platform::Load(db_part_ptr, &db);
+      dg_part_ptr += ROWS_PER_CTA * LN_NUM_COLS;
+      db_part_ptr += ROWS_PER_CTA * LN_NUM_COLS;
+
+#pragma unroll
+      for (int jt = 0; jt < VecSize; jt++) {
+        dg_sum[jt] += dg[jt];
+        db_sum[jt] += db[jt];
+      }
+    }
+
+    // reduction across the rows of the thread block
+    U *smem_write;
+    smem_write = smem_space + (warp_m - 1) * THREADS_PER_ROW * VecSize + tid_c;
+
+    if (warp_m > 0) {
+#pragma unroll
+      for (int jt = 0; jt < VecSize; jt++) {
+        *smem_write = dg_sum[jt];
+        smem_write += THREADS_PER_ROW;
+      }
+    }
+    __syncthreads();
+
+    U *smem_read;
+    smem_read = smem_space + tid_c;
+    if (warp_m == 0) {
+#pragma unroll
+      for (int it = 0; it < WARPS_M - 1; it++) {
+#pragma unroll
+        for (int jt = 0; jt < VecSize; jt++) {
+          dg_sum[jt] += *smem_read;
+          smem_read += THREADS_PER_ROW;
+        }
+      }
+    }
+
+    __syncthreads();
+
+    smem_write = smem_space + (warp_m - 1) * THREADS_PER_ROW * VecSize + tid_c;
+
+    if (warp_m > 0) {
+#pragma unroll
+      for (int jt = 0; jt < VecSize; jt++) {
+        *smem_write = db_sum[jt];
+        smem_write += THREADS_PER_ROW;
+      }
+    }
+    __syncthreads();
+
+    smem_read = smem_space + tid_c;
+    if (warp_m == 0) {
+#pragma unroll
+      for (int it = 0; it < WARPS_M - 1; it++) {
+#pragma unroll
+        for (int jt = 0; jt < VecSize; jt++) {
+          db_sum[jt] += *smem_read;
+          smem_read += THREADS_PER_ROW;
+        }
+      }
+
+      union {
+        ScaleT raw;
+        ScaleT elt[VecSize];
+      } dg_out, db_out;
+
+#pragma unroll
+      for (int jt = 0; jt < VecSize; jt++) {
+        dg_out.elt[jt] = dg_sum[jt];
+        db_out.elt[jt] = db_sum[jt];
+      }
+      ScaleT *dg_ptr = reinterpret_cast<ScaleT *>(dg_) + col;
+      ScaleT *db_ptr = reinterpret_cast<ScaleT *>(db_) + col;
+      *dg_ptr = dg_out.raw;
+      *db_ptr = db_out.raw;
+    }
+  }
+}
+
+/* This function supports two kinds of computations (for float and fp16 types
+* only):
+*
+* Case-1: compute layer_norm_grad for the layer_norm op by setting mask_ptr
+*   and d_dropout_src_ptr to nullptr. Here, dx_ptr returns the gradient of
+*   the layer_norm input.
+*
+* Case-2: compute layer_norm_grad + residual_grad + dropout_grad for the
+*   fused_dropout_residual_layernorm op. Here, dx_ptr returns residual_grad.
+*
+*/
+template <typename T, typename U, typename ScaleT = U,
+          typename MaskType = uint8_t>
+void ln_bwd_1024_kernel_driver(
+    const platform::CUDADeviceContext &dev_ctx, const int rows, const int cols,
+    float epsilon, const T *x_ptr, const ScaleT *scale_ptr, const U *mean_ptr,
+    const U *var_ptr, const T *dout_ptr, T *dx_ptr, ScaleT *dscale_ptr,
+    ScaleT *dbias_ptr, const MaskType *mask_ptr = nullptr,
+    T factor = static_cast<T>(0), T *d_dropout_src_ptr = nullptr) {
+  auto stream = dev_ctx.stream();
+  if (cols == 1024) {
+    // step-1: compute dx and the partial results of dscale and dbias.
+    const int WARPS_M = 4;
+    const int WARPS_N = 1;
+    const int BYTES_PER_LDG = 16;
+    const int VecSize = BYTES_PER_LDG / sizeof(T);
+
+    const int THREADS_PER_WARP = 32;
+    const int THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP;
+    const int THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW;
+    const int ROWS_PER_CTA = WARPS_M;
+
+    // 4 * 1024 * 4
+    const int SMEM_BYTES = ROWS_PER_CTA * cols * sizeof(U);
+
+    // #blocks = 2 * #SM
+    const int gridx = 2 * dev_ctx.GetSMCount();
+
+    // get temp space for dscale and dbias.
+    framework::Tensor dscale_temp;
+    dscale_temp.Resize({gridx, cols});
+    dscale_temp.mutable_data<U>(dev_ctx.GetPlace());
+    U *dscale_temp_ptr = dscale_temp.data<U>();
+
+    framework::Tensor dbias_temp;
+    dbias_temp.Resize({gridx, cols});
+    dbias_temp.mutable_data<U>(dev_ctx.GetPlace());
+    U *dbias_temp_ptr = dbias_temp.data<U>();
+
+    if (mask_ptr != nullptr) {
+      if (d_dropout_src_ptr == nullptr) {
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "To compute fused_dropout_residual_ln grad, d_dropout_src_ptr "
+            "can't be null"));
+      }
+      fused_ln_bwd_1024_kernel<
+          true, T, U, ScaleT, MaskType, VecSize, WARPS_M, WARPS_N,
+          BYTES_PER_LDG><<<gridx, THREADS_PER_CTA, 0, stream>>>(
+          rows, epsilon, x_ptr, scale_ptr, mean_ptr, var_ptr, dout_ptr,
+          dscale_temp_ptr, dbias_temp_ptr, dx_ptr, mask_ptr, factor,
+          d_dropout_src_ptr);
+
+    } else {
+      fused_ln_bwd_1024_kernel<
+          false, T, U, ScaleT, MaskType, VecSize, WARPS_M, WARPS_N,
+          BYTES_PER_LDG><<<gridx, THREADS_PER_CTA, 0, stream>>>(
+          rows, epsilon, x_ptr, scale_ptr, mean_ptr, var_ptr, dout_ptr,
+          dscale_temp_ptr, dbias_temp_ptr, dx_ptr);
+    }
+    const int WARPS_M_2 = 16;
+    const int WARPS_N_2 = 1;
+    const int BYTES_PER_LDG_2 = 4;
+    const int VecSize_2 =
+        std::max(1, static_cast<int>(BYTES_PER_LDG_2 / sizeof(U)));  // 1
+
+    const int THREADS_PER_WARP_2 = 32;
+    const int THREADS_PER_ROW_2 = WARPS_N_2 * THREADS_PER_WARP_2;  // 32
+    const int THREADS_PER_CTA_2 =
+        WARPS_M_2 * THREADS_PER_ROW_2;          // 16 * 32 = 512
+    const int ROWS_PER_CTA_2 = WARPS_M_2;       // 16
+
+    const int gridx_2 = static_cast<int>(
+        std::ceil(1024 / static_cast<float>(THREADS_PER_ROW_2 * VecSize_2)));
+    // #blocks: 32, #threads_per_block: 512
+    // Note: the double type is not supported.
+    if (sizeof(U) > 4) {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Only support float and fp16 type"));
+    } else {
+      ln_bwd_1024_final_kernel<
+          U, ScaleT, VecSize_2, WARPS_M_2, WARPS_N_2,
+          BYTES_PER_LDG_2><<<gridx_2, THREADS_PER_CTA_2, 0, stream>>>(
+          gridx, dscale_temp_ptr, dbias_temp_ptr, dscale_ptr, dbias_ptr);
+    }
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Fast layer_norm kernel is only used when feature_size is 1024"));
+  }
+}
+#endif
+
 template <typename T, typename U, int BDIMX, int BDIMY, int VPTX>
 __global__ void LayerNormBackwardPartGradGammaBeta(
     const T *__restrict__ dout, const T *__restrict__ input, const int64_t n1,
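A small sanity check (not part of the patch, constant names here are made up) of the launch arithmetic chosen by the driver above for cols == 1024:

// Step-1 kernel: 4 warps x 32 lanes = 128 threads per CTA, 4 rows per CTA,
// grid = 2 * SM count, static shared memory = 4 * 1024 * sizeof(U)
// (16 KB when U is float).
// Step-2 kernel: 16 warps x 32 lanes = 512 threads per CTA, VecSize_2 = 1,
// so grid = ceil(1024 / 32) = 32 blocks.
constexpr int kStep1ThreadsPerCta = 4 * 1 * 32;                // 128
constexpr int kStep1SmemBytesFp32 = 4 * 1024 * sizeof(float);  // 16384
constexpr int kStep2ThreadsPerCta = 16 * 1 * 32;               // 512
constexpr int kStep2Blocks = (1024 + 32 - 1) / 32;             // 32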
@@ -983,42 +1448,62 @@ static void LayerNormBackward(
       break;
     case 7:  // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
     {
-      constexpr int VPT = 4;
-      constexpr int BDIMX2 = 32;
-      constexpr int BDIMY2 = 4;
-      dim3 threads2(BDIMX2, BDIMY2, 1);
-      constexpr int part_size = BDIMY2 * VPT;
-      const dim3 blocks2((feature_size + BDIMX2 - 1) / BDIMX2, part_size, 1);
-
-      auto part_grad_gamma_ptr =
-          memory::Alloc(dev_ctx, part_size * feature_size * sizeof(U));
-      auto part_grad_beta_ptr =
-          memory::Alloc(dev_ctx, part_size * feature_size * sizeof(U));
-      U *part_grad_gamma = reinterpret_cast<U *>(part_grad_gamma_ptr->ptr());
-      U *part_grad_beta = reinterpret_cast<U *>(part_grad_beta_ptr->ptr());
-
-      LayerNormBackwardPartGradGammaBeta<T, U, BDIMX2, BDIMY2,
-                                         VPT><<<blocks2, threads2, 0, stream>>>(
-          d_y, x, batch_size, feature_size, mean, var, epsilon, part_grad_gamma,
-          part_grad_beta);  // compute part_grad_gamma, beta
-
-      constexpr int BDIMX3 = 32;
-      constexpr int BDIMY3 = 8;
-      dim3 threads3(BDIMX3, BDIMY3, 1);
-      const dim3 blocks3((feature_size + BDIMX2 - 1) / BDIMX2, 1, 1);
-      LayerNormBackwardSumGradGammaBeta<
-          T, U, BDIMX3, BDIMY3,
-          ScaleBiasWithSameTypeX><<<blocks3, threads3, 0, stream>>>(
-          part_grad_gamma, part_grad_beta, part_size, batch_size, feature_size,
-          d_scale, d_bias);
-
-      constexpr int BDIMX1 = 32;
-      constexpr int BDIMY1 = 4;
-      dim3 threads1(BDIMX1, BDIMY1, 1);
-      LayerNormBackwardComputeGradInput<
-          T, U, BDIMX1, BDIMY1,
-          ScaleBiasWithSameTypeX><<<batch_size, threads1, 0, stream>>>(
-          d_y, x, batch_size, feature_size, mean, var, epsilon, scale, d_x);
+#ifdef PADDLE_WITH_CUDA
+      bool can_call_1024_kernel = false;
+      // TODO: rule out the double type.
+      if (feature_size == 1024 && sizeof(T) <= 4) {
+        can_call_1024_kernel = true;
+      }
+      VLOG(6) << "can_call_1024_kernel = " << can_call_1024_kernel;
+
+      if (can_call_1024_kernel) {
+        ln_bwd_1024_kernel_driver<
+            T, U, LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>>(
+            dev_ctx, batch_size, feature_size, epsilon, x, scale, mean, var,
+            d_y, d_x, d_scale, d_bias);
+      } else {
+#endif
+        constexpr int VPT = 4;
+        constexpr int BDIMX2 = 32;
+        constexpr int BDIMY2 = 4;
+        dim3 threads2(BDIMX2, BDIMY2, 1);
+        constexpr int part_size = BDIMY2 * VPT;
+        const dim3 blocks2((feature_size + BDIMX2 - 1) / BDIMX2, part_size, 1);
+
+        auto part_grad_gamma_ptr =
+            memory::Alloc(dev_ctx, part_size * feature_size * sizeof(U));
+        auto part_grad_beta_ptr =
+            memory::Alloc(dev_ctx, part_size * feature_size * sizeof(U));
+        U *part_grad_gamma = reinterpret_cast<U *>(part_grad_gamma_ptr->ptr());
+        U *part_grad_beta = reinterpret_cast<U *>(part_grad_beta_ptr->ptr());
+
+        LayerNormBackwardPartGradGammaBeta<
+            T, U, BDIMX2, BDIMY2, VPT><<<blocks2, threads2, 0, stream>>>(
+            d_y, x, batch_size, feature_size, mean, var, epsilon,
+            part_grad_gamma,
+            part_grad_beta);  // compute part_grad_gamma, beta
+
+        constexpr int BDIMX3 = 32;
+        constexpr int BDIMY3 = 8;
+        dim3 threads3(BDIMX3, BDIMY3, 1);
+        const dim3 blocks3((feature_size + BDIMX2 - 1) / BDIMX2, 1, 1);
+        LayerNormBackwardSumGradGammaBeta<
+            T, U, BDIMX3, BDIMY3,
+            ScaleBiasWithSameTypeX><<<blocks3, threads3, 0, stream>>>(
+            part_grad_gamma, part_grad_beta, part_size, batch_size,
+            feature_size, d_scale, d_bias);
+
+        constexpr int BDIMX1 = 32;
+        constexpr int BDIMY1 = 4;
+        dim3 threads1(BDIMX1, BDIMY1, 1);
+        LayerNormBackwardComputeGradInput<
+            T, U, BDIMX1, BDIMY1,
+            ScaleBiasWithSameTypeX><<<batch_size, threads1, 0, stream>>>(
+            d_y, x, batch_size, feature_size, mean, var, epsilon, scale, d_x);
+#ifdef PADDLE_WITH_CUDA
+      }
+#endif
+
       break;
     }
     default:
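For intuition (not part of the patch), a minimal CPU reference of the column reduction that ln_bwd_1024_final_kernel performs on the per-block dscale/dbias partials; the buffer layout is assumed row-major [num_blocks, 1024] and the function name is made up:

#include <vector>

// Sums a row-major [rows, cols] partial buffer column-wise into a single
// cols-element vector, which is what the second-stage kernel produces for
// dscale and dbias (with cols == 1024).
std::vector<float> ReduceColumns(const std::vector<float>& part, int rows,
                                 int cols = 1024) {
  std::vector<float> out(cols, 0.f);
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      out[c] += part[r * cols + c];
    }
  }
  return out;
}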