Commit
fix & add special handling for LukeTokenizer
lvdongyi committed Sep 17, 2024
1 parent 50c8631 commit 20f9c6c
Showing 3 changed files with 31 additions and 15 deletions.
30 changes: 26 additions & 4 deletions paddlenlp/transformers/luke/tokenizer.py
@@ -608,22 +608,44 @@ def _convert_token_to_id_with_added_voc(self, token):

return self._convert_token_to_id(token)

-    def add_special_tokens(self, token_list: Union[List[int], Dict]):
+    def add_special_tokens(self, token_list: Union[List[int], Dict], replace_additional_special_tokens: bool = True):
"""
Adding special tokens if you need.
Args:
token_list (List[int], Dict[List[int]]):
The special token list you provided. If you provide a Dict, the key of the Dict must
be "additional_special_tokens" and the value must be token list.
+            replace_additional_special_tokens (bool, optional, defaults to True):
+                If True, the existing list of additional special tokens will be replaced by the list provided in
+                `token_list`. Otherwise, `self._additional_special_tokens` is just extended. In the former
+                case, the tokens will NOT be removed from the tokenizer's full vocabulary - they are only being flagged
+                as non-special tokens. Remember, this only affects which tokens are skipped during decoding, not the
+                `added_tokens_encoder` and `added_tokens_decoder`. This means that the previous
+                `additional_special_tokens` are still added tokens, and will not be split by the model.
"""
if isinstance(token_list, dict):
token_list = token_list["additional_special_tokens"]

+        if not hasattr(self, "_additional_special_tokens"):
+            self._additional_special_tokens = []

+        if replace_additional_special_tokens:
+            self._additional_special_tokens = list(token_list)
+        else:
+            self._additional_special_tokens.extend(
+                [token for token in token_list if token not in self._additional_special_tokens]
+            )
encoder_dict = dict()
decoder_dict = dict()
-        for token in token_list:
-            encoder_dict[token] = len(self.encoder.keys())
-            decoder_dict[len(self.decoder.keys())] = token
+        current_encoder_length = len(self.encoder) + len(self.added_tokens_encoder)
+        current_decoder_length = len(self.decoder) + len(self.added_tokens_decoder)

+        for idx, token in enumerate(token_list):
+            if token not in self.added_tokens_encoder:
+                encoder_dict[token] = current_encoder_length + idx
+                decoder_dict[current_decoder_length + idx] = token

self.added_tokens_encoder.update(encoder_dict)
self.added_tokens_decoder.update(decoder_dict)

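The core of this fix is easiest to see in isolation: the old loop recomputed len(self.encoder.keys()) on every iteration, so every token in a batch received the same id, and ids already handed out via added_tokens_encoder were ignored. A minimal standalone sketch of the before/after behavior (toy vocabulary and function names are for illustration only, not the real LukeTokenizer state):

encoder = {"<s>": 0, "</s>": 1, "hello": 2}   # stand-in for the base vocabulary
added_tokens_encoder = {"<ent>": 3}           # stand-in for tokens added earlier

def assign_ids_old(token_list):
    # Old behavior: len(encoder.keys()) never changes inside the loop, so every
    # new token collides on the same id, which also clashes with "<ent>" (id 3).
    return {token: len(encoder.keys()) for token in token_list}

def assign_ids_fixed(token_list):
    # Fixed behavior: offset past both the base vocab and previously added
    # tokens, and give each genuinely new token its own consecutive id.
    current_encoder_length = len(encoder) + len(added_tokens_encoder)
    return {
        token: current_encoder_length + idx
        for idx, token in enumerate(token_list)
        if token not in added_tokens_encoder
    }

print(assign_ids_old(["<ent2>", "<ent3>"]))    # {'<ent2>': 3, '<ent3>': 3} -- collision
print(assign_ids_fixed(["<ent2>", "<ent3>"]))  # {'<ent2>': 4, '<ent3>': 5}
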
14 changes: 4 additions & 10 deletions paddlenlp/transformers/tokenizer_utils_base.py
@@ -875,11 +875,9 @@ def add_special_tokens(

to_add = []
for token in value:
-                    if isinstance(token, str):
-                        # for legacy purpose we default to stripping. `test_add_tokens_tokenizer` depends on this
-                        token = AddedToken(token, rstrip=False, lstrip=False, normalized=False, special=True)
-                    if replace_additional_special_tokens or str(token) not in self.additional_special_tokens:
-                        to_add.append(token)
+                    if not replace_additional_special_tokens and str(token) in self.additional_special_tokens:
+                        continue

+                    to_add.append(token)
if replace_additional_special_tokens and len(to_add) > 0:
setattr(self, key, list(to_add))
else:
@@ -889,11 +887,7 @@ def add_special_tokens(
else:
if not isinstance(value, (str, AddedToken)):
raise ValueError(f"Token {value} for key {key} should be a str or an AddedToken instance")

Check warning on line 889 in paddlenlp/transformers/tokenizer_utils_base.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/tokenizer_utils_base.py#L889

Added line #L889 was not covered by tests
-                if isinstance(value, (str)):
-                    # for legacy purpose we default to stripping. `False` depends on this
-                    value = AddedToken(value, rstrip=False, lstrip=False, normalized=False, special=True)
-                if isinstance(value, AddedToken):
-                    setattr(self, key, value)
+                setattr(self, key, value)
if value not in added_tokens:
added_tokens.append(value)

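The intended semantics of the replace_additional_special_tokens flag, as exercised by the test change below, can be sketched as follows. The checkpoint name and AutoTokenizer setup are assumptions for illustration; only the add_special_tokens calls and the final assertion mirror the diff and test_special_token_addition:

from paddlenlp.transformers import AutoTokenizer

# Assumed setup: any tokenizer initialized with two additional special tokens.
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-uncased",  # placeholder checkpoint
    additional_special_tokens=["<other>", "<another>"],
)

# Default (replace_additional_special_tokens=True): the previous list is replaced.
tokenizer.add_special_tokens({"additional_special_tokens": ["<tok>"]})
assert tokenizer.additional_special_tokens == ["<tok>"]

# With replace_additional_special_tokens=False the list is extended instead,
# and tokens already present are skipped rather than re-added.
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-uncased",
    additional_special_tokens=["<other>", "<another>"],
)
tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<tok>"]},
    replace_additional_special_tokens=False,
)
assert tokenizer.additional_special_tokens == ["<other>", "<another>", "<tok>"]
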
2 changes: 1 addition & 1 deletion tests/transformers/test_tokenizer_common.py
@@ -1155,6 +1155,7 @@ def test_maximum_encoding_length_pair_input(self):
# encoded_masked[mask_loc] = encoded_1[mask_loc]

# self.assertEqual(encoded_masked, encoded_1)

def test_special_token_addition(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
@@ -1180,7 +1181,6 @@
)
self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>", "<tok>"])


def test_special_tokens_mask(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers:
