Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Warning when passing padded input ids but no attention mask #17444

Conversation

patrickvonplaten
Copy link
Contributor

@patrickvonplaten patrickvonplaten commented May 26, 2022

What does this PR do?

One of the most common mistake users make in Transformers IMO is that input_ids are padded, but no attention_mask is provided (we see many examples of this). As discussed multiple times, we don't want to infer the attention_mask automatically as this creates a lot of unmaintainable, "not-possible-to-deal-with" complexity.

A while ago, we discussed to throw a warning in this case, making sure it's done only once to not spam the user when calling the model multiple times. I'm not sure we found a good conclusion, but IMO it's important that we warn the user as too users (IMO) think the attention_mask is inferred from the padding tokens. This PR is tries to solve this and shows how it'd be implemented for just BERT. We would have to implement it for all other models then as well. Would very much like to hear your opinion here @sgugger @LysandreJik @patil-suraj . Note that this PR will touch a lot of important functions / files, so it'd be very important to make the warning as clear as possible.
I do however have a strong conviction that we should display such a warning.

No the warning function can display the following warning messages for a toy BERT example of passing just three input ids.

Possible warning messages:

  1. Pad token present, no attention mask, eos, bos, sep all different from pad (that's VERY likely an error IMO):

Displayed warning:

The input IDs tensor([[0, 1, 1]]) contains the `pad_token_id` 0, but NO `attention_mask` is passed.
Padding the input IDs without passing an `attention_mask` leads to unexpected, possibly incorrect outputs.
  1. Pad token present, no attention mask, eos or bos or sep same as pad:

Displayed warning:

The input IDs tensor([[0, 1, 1]]) contains the `pad_token_id` 0, but NO `attention_mask` is passed.
We strongly recommend passing an `attention_mask` to avoid possibly incorrectly computing the attention weights. 
You can ignore this warning, if your `pad_token_id` 0 is identical to your `sep_token_id` 0 AND your input is NOT padded.
  1. Pad token present, no attention mask, two or more of eos, bos, sep identical to pad (don't think this exists actually):

Displayed warning:

The input IDs tensor([[0, 1, 1]]) contains the `pad_token_id` 0, but NO `attention_mask` is passed.
We strongly recommend passing an `attention_mask` to avoid possibly incorrectly computing the attention weights. 
You can ignore this warning, if your `pad_token_id` 0 is identical to your `bos_token_id` 0 AND your input is NOT padded.
We strongly recommend passing an `attention_mask` to avoid possibly incorrectly computing the attention weights. 
You can ignore this warning, if your `pad_token_id` 0 is identical to your `sep_token_id` 0 AND your input is NOT padded.
  1. Otherwise no warning.

Also note that the warning only appears at the first forward call.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@patrickvonplaten patrickvonplaten changed the title First draft Warning when passing padded input ids but no attention mask May 26, 2022
@patrickvonplaten patrickvonplaten changed the title Warning when passing padded input ids but no attention mask [WIP] Warning when passing padded input ids but no attention mask May 26, 2022
@patrickvonplaten
Copy link
Contributor Author

Relevant issues:
#4083
#278
#16136

f" {self.config.pad_token_id} is identical to your `bos_token_id` {self.config.bos_token_id} AND"
" your input is NOT padded."
)
if is_pad_token_equal_to_eos_token:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe use elif here and below ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be the case that pad_token == eos_token == bos_token -> would like to append the string then

Comment on lines +1014 to +1015
if not hasattr(self, "warnings_issued"):
self.warnings_issued = {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we try to avoid adding new instance attribute in __init__?

f" {self.config.pad_token_id} is identical to your `sep_token_id` {self.config.sep_token_id} AND"
" your input is NOT padded."
)
if not (is_pad_token_equal_to_bos_token or is_pad_token_equal_to_eos_token):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lack is_pad_token_equal_to_sep_token here, I guess?

@LysandreJik
Copy link
Member

I think the way you implemented it is clean and adds nice warnings. I agree with the idea behind it, and the better warnings we send, the better the models will perform for users.

I think handling it like it is done here based off of configuration attribute is not going to work very well across models, however. I feel like having the method be configurable by passing optional bos/eos tokens would likely make the method more versatile to the models which do not conform to the default approach.

@patrickvonplaten
Copy link
Contributor Author

patrickvonplaten commented May 31, 2022

I think handling it like it is done here based off of configuration attribute is not going to work very well across models, however. I feel like having the method be configurable by passing optional bos/eos tokens would likely make the method more versatile to the models which do not conform to the default approach.

Hmm, don't really agree here. Note that pad_token_id, bos_token_id, eos_token_id, sep_token_id must be present in every model's config since it's in configuration_utils.py.
Also we never pass any of the above attributes through the forward method, so one would only ever pass self.config.pad_token_id to the method. Wdyt @LysandreJik ? Also very curious to hear @sgugger's opinion here

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding those warnings (mainly the last one). I'd like to add more tests for the warnings when the pad token is the same as the eos/bos/sep tokens to avoid scaring a user for nothing (users are always scared of warnings) and it shouldn't hurt performance since they would only be run once.

As for @LysandreJik comment, I must admit I don't understand what you're suggesting Lysandre, since we only have those pad/eos/bos/sep token IDs from the config of the mdoel inside the forward.

Comment on lines +1046 to +1051
warn_string += (
"\nWe strongly recommend passing an `attention_mask` to avoid possibly incorrectly computing the"
" attention weights. \nYou can ignore this warning, if your `pad_token_id`"
f" {self.config.pad_token_id} is identical to your `bos_token_id` {self.config.bos_token_id} AND"
" your input is NOT padded."
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe here let's check if the pad token ID is only used once per input at the beginning before throwing the warning?

Comment on lines +1053 to +1058
warn_string += (
"\nWe strongly recommend passing an `attention_mask` to avoid possibly incorrectly computing the"
" attention weights. \nYou can ignore this warning, if your `pad_token_id`"
f" {self.config.pad_token_id} is identical to your `eos_token_id` {self.config.eos_token_id} AND"
" your input is NOT padded."
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

@@ -1006,6 +1006,72 @@ def get_input_embeddings(self) -> nn.Module:
else:
raise NotImplementedError

def warn_if_pad_token_in_input_ids_no_attention_mask(self, input_ids, attention_mask):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not push descriptive names too far ;-) I think padding_attention_mask_warning is more than enough!

@LysandreJik
Copy link
Member

Sounds good, I'm likely worrying for nothing then. Good for me like this, very easy to add kwargs afterwards anyway!

@huggingface huggingface deleted a comment from github-actions bot Jun 27, 2022
@huggingface huggingface deleted a comment from github-actions bot Jul 22, 2022
@github-actions
Copy link

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Aug 25, 2022
@LysandreJik LysandreJik reopened this Aug 30, 2022
@LysandreJik
Copy link
Member

I think this would be an impactful addition! @ydshieh, would you be interested in continuing this PR?

@ydshieh
Copy link
Collaborator

ydshieh commented Aug 30, 2022

I think this would be an impactful addition! @ydshieh, would you be interested in continuing this PR?

Sure. I will take a look and see if there is anything blocking.

@huggingface huggingface deleted a comment from github-actions bot Sep 23, 2022
@ydshieh ydshieh self-assigned this Sep 23, 2022
@huggingface huggingface deleted a comment from github-actions bot Oct 18, 2022
@ydshieh ydshieh added the WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress label Oct 18, 2022
@ydshieh
Copy link
Collaborator

ydshieh commented Feb 3, 2023

You can search elif input_ids is not None: that is in the base model classes like BertModel (already done by @patrickvonplaten), GPT2Model etc.

You don't need to replace all of them - it would be super nice already for a few of the most used modes 🚀 Thank you!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants