From a6b37f915235fe75a28f971e06a867a9801f9fa7 Mon Sep 17 00:00:00 2001 From: gongel Date: Wed, 11 Jan 2023 08:58:29 +0000 Subject: [PATCH] fix windows bug --- paddlenlp/transformers/pegasus/modeling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/paddlenlp/transformers/pegasus/modeling.py b/paddlenlp/transformers/pegasus/modeling.py index b7511cfcf1ac..07674e9ed87a 100644 --- a/paddlenlp/transformers/pegasus/modeling.py +++ b/paddlenlp/transformers/pegasus/modeling.py @@ -42,7 +42,9 @@ def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id): if pad_token_id is None: raise ValueError("self.model.config.pad_token_id has to be defined.") - shifted_input_ids = paddle.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) + shifted_input_ids = paddle.where( + shifted_input_ids == -100, paddle.full_like(shifted_input_ids, pad_token_id), shifted_input_ids + ) return shifted_input_ids