diff --git a/paddlenlp/transformers/luke/tokenizer.py b/paddlenlp/transformers/luke/tokenizer.py
index 7653c96526c2..196dd7835d9e 100644
--- a/paddlenlp/transformers/luke/tokenizer.py
+++ b/paddlenlp/transformers/luke/tokenizer.py
@@ -608,7 +608,7 @@ def _convert_token_to_id_with_added_voc(self, token):
 
         return self._convert_token_to_id(token)
 
-    def add_special_tokens(self, token_list: Union[List[int], Dict]):
+    def add_special_tokens(self, token_list: Union[List[int], Dict], replace_additional_special_tokens: bool = True):
         """
         Adding special tokens if you need.
 
@@ -616,14 +616,36 @@ def add_special_tokens(self, token_list: Union[List[int], Dict]):
             token_list (List[int], Dict[List[int]]):
                 The special token list you provided. If you provide a Dict, the key of the Dict must
                 be "additional_special_tokens" and the value must be token list.
+            replace_additional_special_tokens (bool, optional, defaults to True):
+                If True, the existing list of additional special tokens will be replaced by the list provided in
+                `token_list`. Otherwise, `self._additional_special_tokens` is just extended. In the former
+                case, the tokens will NOT be removed from the tokenizer's full vocabulary - they are only being flagged
+                as non-special tokens. Remember, this only affects which tokens are skipped during decoding, not the
+                `added_tokens_encoder` and `added_tokens_decoder`. This means that the previous
+                `additional_special_tokens` are still added tokens, and will not be split by the model.
         """
         if isinstance(token_list, dict):
             token_list = token_list["additional_special_tokens"]
+
+        if not hasattr(self, "_additional_special_tokens"):
+            self._additional_special_tokens = []
+
+        if replace_additional_special_tokens:
+            self._additional_special_tokens = list(token_list)
+        else:
+            self._additional_special_tokens.extend(
+                [token for token in token_list if token not in self._additional_special_tokens]
+            )
 
         encoder_dict = dict()
         decoder_dict = dict()
-        for token in token_list:
-            encoder_dict[token] = len(self.encoder.keys())
-            decoder_dict[len(self.decoder.keys())] = token
+        current_encoder_length = len(self.encoder) + len(self.added_tokens_encoder)
+        current_decoder_length = len(self.decoder) + len(self.added_tokens_decoder)
+
+        for idx, token in enumerate(token_list):
+            if token not in self.added_tokens_encoder:
+                encoder_dict[token] = current_encoder_length + idx
+                decoder_dict[current_decoder_length + idx] = token
+
         self.added_tokens_encoder.update(encoder_dict)
         self.added_tokens_decoder.update(decoder_dict)
diff --git a/paddlenlp/transformers/tokenizer_utils_base.py b/paddlenlp/transformers/tokenizer_utils_base.py
index a731acaed2f6..6af5cc29e5d4 100644
--- a/paddlenlp/transformers/tokenizer_utils_base.py
+++ b/paddlenlp/transformers/tokenizer_utils_base.py
@@ -875,11 +875,9 @@ def add_special_tokens(
 
                 to_add = []
                 for token in value:
-                    if isinstance(token, str):
-                        # for legacy purpose we default to stripping. `test_add_tokens_tokenizer` depends on this
-                        token = AddedToken(token, rstrip=False, lstrip=False, normalized=False, special=True)
-                    if replace_additional_special_tokens or str(token) not in self.additional_special_tokens:
-                        to_add.append(token)
+                    if not replace_additional_special_tokens and str(token) in self.additional_special_tokens:
+                        continue
+                    to_add.append(token)
                 if replace_additional_special_tokens and len(to_add) > 0:
                     setattr(self, key, list(to_add))
                 else:
@@ -889,11 +887,7 @@ def add_special_tokens(
             else:
                 if not isinstance(value, (str, AddedToken)):
                     raise ValueError(f"Token {value} for key {key} should be a str or an AddedToken instance")
-                if isinstance(value, (str)):
-                    # for legacy purpose we default to stripping. `False` depends on this
-                    value = AddedToken(value, rstrip=False, lstrip=False, normalized=False, special=True)
-                if isinstance(value, AddedToken):
-                    setattr(self, key, value)
+                setattr(self, key, value)
                 if value not in added_tokens:
                     added_tokens.append(value)
 
diff --git a/tests/transformers/test_tokenizer_common.py b/tests/transformers/test_tokenizer_common.py
index 7c8514eeea5e..7d78bfb09e0f 100644
--- a/tests/transformers/test_tokenizer_common.py
+++ b/tests/transformers/test_tokenizer_common.py
@@ -1155,6 +1155,7 @@ def test_maximum_encoding_length_pair_input(self):
         # encoded_masked[mask_loc] = encoded_1[mask_loc]
 
         # self.assertEqual(encoded_masked, encoded_1)
+
     def test_special_token_addition(self):
         for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
             with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
@@ -1180,7 +1181,6 @@
                     )
                     self.assertEqual(tokenizer_2.additional_special_tokens, ["", "", ""])
 
-
     def test_special_tokens_mask(self):
         tokenizers = self.get_tokenizers(do_lower_case=False)
         for tokenizer in tokenizers:
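The sketch below is a minimal usage example of the behavior this patch introduces, not part of the diff itself. The `bert-base-uncased` checkpoint name and the `<mark_a>`/`<mark_b>` tokens are illustrative assumptions, and the commented results are the expected post-patch behavior rather than captured output.

```python
# Minimal usage sketch, assuming the patch above is applied.
from paddlenlp.transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Default behavior: each call replaces the previous additional_special_tokens list.
tokenizer.add_special_tokens({"additional_special_tokens": ["<mark_a>"]})
tokenizer.add_special_tokens({"additional_special_tokens": ["<mark_b>"]})
print(tokenizer.additional_special_tokens)  # expected: ["<mark_b>"]

# With replace_additional_special_tokens=False the list is extended instead, and
# tokens already flagged as special are skipped rather than duplicated.
tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<mark_a>"]},
    replace_additional_special_tokens=False,
)
print(tokenizer.additional_special_tokens)  # expected: ["<mark_b>", "<mark_a>"]

# Replaced tokens stay in added_tokens_encoder / added_tokens_decoder; only their
# "special" status changes, which affects e.g. decode(..., skip_special_tokens=True).
```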