From 8bfd45adb14e47b9f157050c36438064175fbd77 Mon Sep 17 00:00:00 2001 From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com> Date: Mon, 17 Oct 2022 10:41:21 +0800 Subject: [PATCH] [Cherry-Pick]Move valid check from python to kernel (#46980) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 为了提升性能,将label的边界检查从python端转移到kernel内,减少额外op的调用,如min、max和同步拷贝等 当前的模板参数IgnoreIndex仅在ignore_index取值范围在[0, dim)时才生效,但是当某个label值超出了边界,ignore_index等于该label,这种情况下是应该仍然能正常计算。虽然当前的计算逻辑在结果上不会出错,但逻辑上仍是有问题的,且模板参数IgnoreIndex是没有必要的 --- .../phi/kernels/gpu/cross_entropy_kernel.cu | 392 ++++++++---------- python/paddle/nn/functional/loss.py | 8 - 2 files changed, 164 insertions(+), 236 deletions(-) diff --git a/paddle/phi/kernels/gpu/cross_entropy_kernel.cu b/paddle/phi/kernels/gpu/cross_entropy_kernel.cu index 1a4559d5cd6b5..76201a1077edb 100644 --- a/paddle/phi/kernels/gpu/cross_entropy_kernel.cu +++ b/paddle/phi/kernels/gpu/cross_entropy_kernel.cu @@ -170,7 +170,7 @@ __global__ void CrossEntropySoftLabel(T* loss, /* Hard label cross entropy. */ -template +template __global__ void CrossEntropyHardLabel(T* loss, const T* softmax, const LabelT* labels, @@ -185,21 +185,17 @@ __global__ void CrossEntropyHardLabel(T* loss, // thread ids compute loss[ids] using softmax[idx] if (ids < n * d) { auto lbl = static_cast(labels[ids]); - if (lbl < 0) { // label is negative + PADDLE_ENFORCE(lbl >= 0 && lbl < dim || lbl == ignore_idx, + "The value of label expected >= 0 and < %d, or == %d, " + "but got %ld. Please check label value.", + dim, + ignore_idx, + lbl); + if (lbl == ignore_idx) { loss[ids] = static_cast(0.0); - } else { // label is positive of zero + } else { int64_t idx = idx_n * dim * d + lbl * d + idx_d; - if (IgnoreIndex == true) { - // IgnoreIndex is true - if (lbl == ignore_idx) { - loss[ids] = static_cast(0.0); - } else { - loss[ids] = -Log(softmax[idx]); - } - } else { - // IgnoreIndex is false - loss[ids] = -Log(softmax[idx]); - } + loss[ids] = -Log(softmax[idx]); } } } @@ -209,7 +205,7 @@ __global__ void CrossEntropyHardLabel(T* loss, Input: log softmax Output: loss and exp(input) */ -template +template __global__ void CrossEntropyExpHardLabel(T* loss, T* softmax, const LabelT* labels, @@ -225,23 +221,17 @@ __global__ void CrossEntropyExpHardLabel(T* loss, if (idx < n * dim * d) { auto lbl = static_cast(labels[ids]); - if (IgnoreIndex == true) { - // IgnoreIndex is true - if (idx_dim == lbl) { - if (lbl == ignore_idx) { - loss[ids] = static_cast(0.0); - } else { - loss[ids] = -softmax[idx]; - } - } + PADDLE_ENFORCE(lbl >= 0 && lbl < dim || lbl == ignore_idx, + "The value of label expected >= 0 and < %d, or == %d, " + "but got %ld. Please check label value.", + dim, + ignore_idx, + lbl); + if (lbl == ignore_idx) { + loss[ids] = static_cast(0.0); } else { - // IgnoreIndex is false - if (lbl >= 0 && lbl < dim) { - if (lbl == idx_dim) { - loss[ids] = -softmax[idx]; - } - } else { - loss[ids] = static_cast(0.0); + if (lbl == idx_dim) { + loss[ids] = -softmax[idx]; } } softmax[idx] = Exp(softmax[idx]); @@ -290,7 +280,7 @@ __device__ __forceinline__ AccT ThreadReduce(const T* input, return val; } -template +template __device__ __forceinline__ void ComputeLoss(T* loss, const T loss_value, const int label_id, @@ -300,14 +290,8 @@ __device__ __forceinline__ void ComputeLoss(T* loss, const int offset, const int ignore_index) { int loss_id = vec_size * tid + offset; - if (IgnoreIndex) { - if (label_value == loss_id) { - if (label_value == ignore_index) { - loss[label_id] = static_cast(0.0f); - } else { - loss[label_id] = loss_value; - } - } + if (label_value == ignore_index) { + loss[label_id] = static_cast(0.0f); } else { if (label_value == loss_id) { loss[label_id] = loss_value; @@ -315,11 +299,7 @@ __device__ __forceinline__ void ComputeLoss(T* loss, } } -template +template __device__ __forceinline__ void VectorizedSoftmaxForwardImpl( T* loss, T* softmax, @@ -333,7 +313,13 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl( int tid = threadIdx.x; int label_id = blockIdx.x; auto label_value = static_cast(label[label_id]); - const bool label_valid = label_value >= 0 && label_value < size; + PADDLE_ENFORCE( + label_value >= 0 && label_value < size || label_value == ignore_index, + "The value of label expected >= 0 and < %d, or == %d, " + "but got %ld. Please check label value.", + size, + ignore_index, + label_value); int loss_id_offset = 0; if (offset > 0) { @@ -345,16 +331,14 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl( AccT log_softmax = func(static_cast(logits[tid])); softmax[tid] = static_cast(std::exp(log_softmax)); // loss - if (label_valid) { - ComputeLoss(loss, - static_cast(-log_softmax), - label_id, - label_value, - tid, - 1, - loss_id_offset, - ignore_index); - } + ComputeLoss(loss, + static_cast(-log_softmax), + label_id, + label_value, + tid, + 1, + loss_id_offset, + ignore_index); } size -= blockDim.x; logits += blockDim.x; @@ -380,16 +364,14 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl( outs[i] = static_cast(std::exp(log_softmax)); // loss - if (label_valid) { - ComputeLoss(loss, - static_cast(-log_softmax), - label_id, - label_value, - tid, - VecSize, - loss_id_offset + i, - ignore_index); - } + ComputeLoss(loss, + static_cast(-log_softmax), + label_id, + label_value, + tid, + VecSize, + loss_id_offset + i, + ignore_index); } // write @@ -403,29 +385,18 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl( softmax[tid] = static_cast(std::exp(log_softmax)); // loss - if (label_valid) { - ComputeLoss(loss, - static_cast(-log_softmax), - label_id, - label_value, - tid, - 1, - loss_id_offset, - ignore_index); - } - } - - // invalid label, write once - if (!label_valid && threadIdx.x == 0) { - loss[label_id] = static_cast(0.0f); + ComputeLoss(loss, + static_cast(-log_softmax), + label_id, + label_value, + tid, + 1, + loss_id_offset, + ignore_index); } } -template +template __device__ __forceinline__ void ScalarSoftmaxForwardImpl( T* loss, T* softmax, @@ -438,7 +409,13 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl( int remain = size % (VecSize * blockDim.x); int label_id = blockIdx.x; auto label_value = static_cast(label[label_id]); - const bool label_valid = label_value >= 0 && label_value < size; + PADDLE_ENFORCE( + label_value >= 0 && label_value < size || label_value == ignore_index, + "The value of label expected >= 0 and < %d, or == %d, " + "but got %ld. Please check label value.", + size, + ignore_index, + label_value); // main part for (; tid < (size - remain); tid += VecSize * blockDim.x) { @@ -453,16 +430,14 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl( AccT log_softmax = func(static_cast(ins[i])); softmax[tid + i * blockDim.x] = static_cast(std::exp(log_softmax)); // loss - if (label_valid) { - ComputeLoss(loss, - static_cast(-log_softmax), - label_id, - label_value, - tid, - VecSize, - i, - ignore_index); - } + ComputeLoss(loss, + static_cast(-log_softmax), + label_id, + label_value, + tid, + VecSize, + i, + ignore_index); } } @@ -471,29 +446,18 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl( AccT log_softmax = func(static_cast(logits[tid])); softmax[tid] = static_cast(std::exp(log_softmax)); // loss - if (label_valid) { - ComputeLoss(loss, - static_cast(-log_softmax), - label_id, - label_value, - tid, - 1, - 0, - ignore_index); - } - } - - // invalid label, write once - if (!label_valid && threadIdx.x == 0) { - loss[label_id] = static_cast(0.0f); + ComputeLoss(loss, + static_cast(-log_softmax), + label_id, + label_value, + tid, + 1, + 0, + ignore_index); } } -template +template __global__ void VectorizedSoftmaxForward(T* loss, T* softmax, const T* logits, @@ -533,17 +497,16 @@ __global__ void VectorizedSoftmaxForward(T* loss, // 3. softmax phi::LogSoftmaxForwardFunctor func(max, sum); if (input_offset == output_offset) { - VectorizedSoftmaxForwardImpl( - loss, - softmax, - logits, - label, - mid_dim, - input_offset, - func, - ignore_index); + VectorizedSoftmaxForwardImpl(loss, + softmax, + logits, + label, + mid_dim, + input_offset, + func, + ignore_index); } else { - ScalarSoftmaxForwardImpl( + ScalarSoftmaxForwardImpl( loss, softmax, logits, label, mid_dim, func, ignore_index); } } @@ -556,8 +519,8 @@ The computation includes - Compute: sum of - sum_{j}{ label_{i,j} * (src_{i,j} - maxvalue_{i} - log(sum[i]))} One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize). -For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle -api to compute max (sum) in one warp. +For reduction max (sum), firstly compute max (sum) to one warp, then use +shuffle api to compute max (sum) in one warp. */ template __global__ void WarpSoftmaxForwardSoftLabel(T* loss, @@ -876,8 +839,7 @@ template + SoftmaxMode mode> __global__ void WarpSoftmaxForward(T* loss, T* softmax, const T* src, @@ -1029,23 +991,21 @@ __global__ void WarpSoftmaxForward(T* loss, // label int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize; auto lbl = static_cast(label[first_batch + i]); - if (IgnoreIndex == true) { - // IgnoreIndex is true - if (lbl == loss_idx) { - if (lbl != ignore_index) { - loss[first_batch + i] = -logsoftmax; - } else { - loss[first_batch + i] = static_cast(0.0); - } - } + if (lbl == ignore_index) { + loss[first_batch + i] = static_cast(0.0); } else { - // IgnoreIndex is false if (lbl >= 0 && lbl < element_count) { if (lbl == loss_idx) { loss[first_batch + i] = -logsoftmax; } } else { - loss[first_batch + i] = static_cast(0.0); + PADDLE_ENFORCE( + false, + "The value of label expected >= 0 and < %d, or == %d, " + "but got %ld. Please check label value.", + element_count, + ignore_index, + lbl); } } } else { // softmax @@ -1072,19 +1032,21 @@ __global__ void WarpSoftmaxForward(T* loss, // label int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize + s; auto lbl = static_cast(label[first_batch + i]); - if (IgnoreIndex == true) { - // IgnoreIndex is true - if (lbl == loss_idx && lbl != ignore_index) { - loss[first_batch + i] = -logsoftmax; - } + if (lbl == ignore_index) { + loss[first_batch + i] = static_cast(0.0); } else { - // IgnoreIndex is false if (lbl >= 0 && lbl < element_count) { if (lbl == loss_idx) { loss[first_batch + i] = -logsoftmax; } } else { - loss[first_batch + i] = static_cast(0.0); + PADDLE_ENFORCE( + false, + "The value of label expected >= 0 and < %d, or == %d, " + "but got %ld. Please check label value.", + element_count, + ignore_index, + lbl); } } } else { // softmax @@ -1101,23 +1063,23 @@ __global__ void WarpSoftmaxForward(T* loss, } } -#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, LabelT, VecT, AccT) \ - case Log2Elements: \ - WarpSoftmaxForward \ - <<>>(loss, \ - softmax, \ - src, \ - label, \ - batch_size, \ - stride, \ - element_count, \ - ignore_index); \ +#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, LabelT, VecT, AccT) \ + case Log2Elements: \ + WarpSoftmaxForward \ + <<>>(loss, \ + softmax, \ + src, \ + label, \ + batch_size, \ + stride, \ + element_count, \ + ignore_index); \ break; /* Wrapper of softmax with cross entropy forward hard label. */ -template +template void SwitchWarpSoftmaxForward(T* loss, T* softmax, const T* src, @@ -1156,7 +1118,7 @@ void SwitchWarpSoftmaxForward(T* loss, } } -template +template void LaunchVectorizedSoftmaxForward(T* loss, T* softmax, const T* logits, @@ -1180,7 +1142,7 @@ void LaunchVectorizedSoftmaxForward(T* loss, block_size = std::max(block_size, kps::details::kWarpSize); dim3 grids(high_dim); dim3 blocks(block_size); - VectorizedSoftmaxForward + VectorizedSoftmaxForward <<>>( loss, softmax, logits, label, high_dim, mid_dim, ignore_index); } @@ -1191,7 +1153,7 @@ void LaunchVectorizedSoftmaxForward(T* loss, - LaunchVectorizedSoftmaxForward for large size when axis == -1 - cudnn function for axis != -1 */ -template +template static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx, int rank, int axis, @@ -1208,24 +1170,24 @@ static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx, if (D == 1) { if (dim <= max_dim) { // small size const SoftmaxMode mode = SoftmaxMode::kCrossEntropy; - SwitchWarpSoftmaxForward(loss_data, - softmax_data, - logits_data, - labels_data, - N, - dim, - dim, - ignore_index, - stream); + SwitchWarpSoftmaxForward(loss_data, + softmax_data, + logits_data, + labels_data, + N, + dim, + dim, + ignore_index, + stream); } else { // large size - LaunchVectorizedSoftmaxForward(loss_data, - softmax_data, - logits_data, - labels_data, - N, - dim, - ignore_index, - stream); + LaunchVectorizedSoftmaxForward(loss_data, + softmax_data, + logits_data, + labels_data, + N, + dim, + ignore_index, + stream); } } else { ScopedTensorDescriptor desc; @@ -1269,9 +1231,8 @@ static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx, int threads = 128; int blocks = (N * dim * D + threads - 1) / threads; // compute cross entropy, input is log softmax - CrossEntropyExpHardLabel - <<>>( - loss_data, softmax_data, labels_data, N, dim, D, ignore_index); + CrossEntropyExpHardLabel<<>>( + loss_data, softmax_data, labels_data, N, dim, D, ignore_index); } } @@ -1367,25 +1328,14 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx, auto* labels_data = labels.data(); int threads = 128; int blocks = (n * d / axis_dim + threads - 1) / threads; - if (ignore_index >= 0 && ignore_index < axis_dim) { - CrossEntropyHardLabel - <<>>(loss_data, - logits_data, - labels_data, - n, - axis_dim, - d / axis_dim, - ignore_index); - } else { - CrossEntropyHardLabel - <<>>(loss_data, - logits_data, - labels_data, - n, - axis_dim, - d / axis_dim, - ignore_index); - } + CrossEntropyHardLabel + <<>>(loss_data, + logits_data, + labels_data, + n, + axis_dim, + d / axis_dim, + ignore_index); } // cause of input is softmax @@ -1450,31 +1400,17 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx, } else { auto* logits_data = logits.data(); auto* labels_data = label.data(); - if (ignore_index >= 0 && ignore_index < axis_dim) { - SoftmaxWithCrossEntropyHardLabel(dev_ctx, - rank, - axis_v, - logits_data, - labels_data, - loss_data, - softmax_data, - n, - axis_dim, - d / axis_dim, - ignore_index); - } else { - SoftmaxWithCrossEntropyHardLabel(dev_ctx, - rank, - axis_v, - logits_data, - labels_data, - loss_data, - softmax_data, - n, - axis_dim, - d / axis_dim, - ignore_index); - } + SoftmaxWithCrossEntropyHardLabel(dev_ctx, + rank, + axis_v, + logits_data, + labels_data, + loss_data, + softmax_data, + n, + axis_dim, + d / axis_dim, + ignore_index); } } } diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index c0742bdbf407f..fb9c4a56ab40a 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -2386,14 +2386,6 @@ def cross_entropy(input, if soft_label == False: valid_label = paddle.cast(label != ignore_index, dtype=label.dtype) * label - label_min = paddle.min(valid_label) - label_max = paddle.max(valid_label) - if label_min < 0: - raise ValueError("Target {} is out of lower bound.".format( - label_min.item())) - if label_max >= input.shape[axis]: - raise ValueError("Target {} is out of upper bound.".format( - label_max.item())) if core.is_compiled_with_npu() or core.is_compiled_with_mlu(): if soft_label == False: _, _, out = _legacy_C_ops.softmax_with_cross_entropy(