diff --git a/paddlenlp/transformers/pegasus/modeling.py b/paddlenlp/transformers/pegasus/modeling.py index b7511cfcf1ac..07674e9ed87a 100644 --- a/paddlenlp/transformers/pegasus/modeling.py +++ b/paddlenlp/transformers/pegasus/modeling.py @@ -42,7 +42,9 @@ def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id): if pad_token_id is None: raise ValueError("self.model.config.pad_token_id has to be defined.") - shifted_input_ids = paddle.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) + shifted_input_ids = paddle.where( + shifted_input_ids == -100, paddle.full_like(shifted_input_ids, pad_token_id), shifted_input_ids + ) return shifted_input_ids