Skip to content

Commit

Permalink
Merge pull request #77 from shang1017/fix_query_seq_slots
Browse files Browse the repository at this point in the history
Fix the comparison bounds check in the fused query seq tensor op
  • Loading branch information
qingshui committed Jan 24, 2024
2 parents 61a9d60 + 7acd059 commit 5d50595
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/operators/fused/fused_seq_tensor_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ class FusedSeqTensorOp : public framework::OperatorWithKernel {
ad_slot_num, 0,
platform::errors::InvalidArgument(
"ad_slot_num [%ld] <= 0", ad_slot_num));
PADDLE_ENFORCE_LT(
PADDLE_ENFORCE_LE(
ad_slot_offset, slot_num - 1,
platform::errors::InvalidArgument(
"ad_slot_num [%ld] > slot_num - 1 [%ld]", ad_slot_offset, slot_num));
"ad_slot_num [%ld] > slot_num - 1 [%ld]", ad_slot_offset, slot_num));
PADDLE_ENFORCE_GE(
ad_slot_offset, 0,
platform::errors::InvalidArgument(
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/operators/fused/fused_seq_tensor_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ __device__ void warpReduce(volatile T* cache, int tid) {

#define THREAD_PER_BLOCK 128
template <typename T>
__global__ void reduce_sum_max_length(const T* input, // 1
T* mask_output, // mask
__global__ void reduce_sum_max_length(const T* input,
T* mask_output,
const size_t batch_count,
const size_t ins_num,
const size_t slot_num,
Expand Down

0 comments on commit 5d50595

Please sign in to comment.