Wrong/inconsistent behaviour in EncoderDecoderModel and generate method #15479

Closed
LarsHill opened this issue Feb 2, 2022 · 7 comments

LarsHill commented Feb 2, 2022

Hi guys,

When creating my own EncoderDecoder abstraction to have more flexibility in initializing custom Hugging Face based models as encoder and decoder, I noticed a couple of issues that imho should be fixed or changed.

I think this is most relevant to @patrickvonplaten since he seems to be in charge of EncoderDecoderModel in modeling_encoder_decoder.py and generation_utils.py.
I tag @LysandreJik as well since I detected another potential improvement related to BERT (and maybe other models, if the behaviour is the same).

  1. Problem: When using EncoderDecoderModel, the underlying encoder and decoder may have different tokenizers and thus different pad_token_ids. That means self.config.encoder.pad_token_id and self.config.decoder.pad_token_id might differ. When generating with an encoder_decoder_instance and no attention_mask is provided to generate(), the attention mask is created internally. This happens in _prepare_attention_mask_for_generation() in generation_utils.py. However, this function does not distinguish the encoder-decoder case from the decoder-only case. Hence, it uses the pad_token_id registered in self.config.pad_token_id. This causes a problem if self.config.pad_token_id is not equal to self.config.encoder.pad_token_id.
    Proposed Solution: Imho, _prepare_attention_mask_for_generation() should check self.config.is_encoder_decoder and, if true, take the padding token from self.config.encoder.pad_token_id instead of self.config.pad_token_id (a sketch of a user-side workaround follows after this list).

  2. Problem: The decoder attention mask is created/updated on each generation step in generate() by calling prepare_inputs_for_generation(), which is implemented in the corresponding model class, e.g. encoder_decoder, bert, etc. However, BERT implements this function to simply create an all-ones mask that mimics the shape of the current input_ids, irrespective of previously predicted ids. Assuming that at some point in the generation process a pad_token_id is predicted, the attention mask update should take this into account and place a 0 at that position in the mask.
    Proposed Solution: All models that implement prepare_inputs_for_generation() should imho take their corresponding pad_token_id in input_ids into account when updating the attention_mask between generation steps (a rough sketch follows after this list).

  3. Problem: The attention_mask creation in e.g. BERT and in generate() is not aligned if the user does not provide a mask themselves. BERT (and maybe other models) simply creates an all-ones mask (same shape as input_ids). As described in 1., generate() takes the pad_token_id into account and creates a proper mask based on input_ids. At the moment I don't need to provide the mask to generate(), but I do have to provide it to a plain forward() during training, because otherwise an all-ones mask is created. I feel this should be aligned -> the model creates the correct mask internally if none is provided.
    Proposed Solution: Imho, each model should create the correct attention_mask if the user does not provide one. The pad_token_id is known to the model, so implementing this should be no problem. (The first sketch after this list applies here as well.)
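
For illustration, a minimal sketch of a workaround for problems 1 and 3: explicitly building the encoder attention mask from self.config.encoder.pad_token_id and passing it to generate(). The checkpoint names are just examples, and the point only matters when the encoder and decoder tokenizers actually use different pad tokens.

```python
from transformers import BertTokenizer, EncoderDecoderModel

# Example checkpoints; assume the encoder and decoder tokenizers use
# different pad token ids in the real setup.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased"
)
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

inputs = tokenizer(
    ["a short source text", "a much longer source text that forces padding"],
    padding=True,
    return_tensors="pt",
)

# Without an explicit attention_mask, generate() falls back to
# model.config.pad_token_id when building the mask; if that id differs from
# model.config.encoder.pad_token_id, padded positions get attended to.
# Building the mask from the encoder's pad token avoids the ambiguity.
# (The tokenizer also returns a mask here; the sketch just shows how to
# construct one when it is not available.)
encoder_pad_id = model.config.encoder.pad_token_id
attention_mask = (inputs.input_ids != encoder_pad_id).long()

generated = model.generate(inputs.input_ids, attention_mask=attention_mask)
```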
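
And for problem 2, a rough sketch of the proposed mask update between generation steps. This is the suggested behaviour, not what prepare_inputs_for_generation() currently does, and the helper name is made up:

```python
import torch


def build_decoder_attention_mask(decoder_input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # Proposed behaviour: instead of extending the mask with all ones on every
    # generation step, mask out any pad_token_id that appears in the ids
    # generated so far.
    return (decoder_input_ids != pad_token_id).long()
```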

Some feedback about these thoughts would be great. Maybe I missed something in the code base and the issues are not relevant after all.

Thanks for your work, and I'm looking forward to hearing from you.

Lars

@github-actions

github-actions bot commented Mar 4, 2022

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@patrickvonplaten
Contributor

Hey @LarsHill,

Thanks a lot for the feedback and I'm very sorry to be so late to answer here.

  1. Very good point! It's a tricky thing to solve because:
  • We don't want too much model-specific code in generate(), and this use case is quite edge-casy and model-specific.
  • On the other hand, we also don't want silent errors (they are the worst to debug). Your solution makes a lot of sense! The problem, however, is that most encoder-decoder models, like T5 and BART, don't have a config.encoder attribute, so we quickly arrive at two or three if statements.
    I think for now the best we can do here is to write some nice docs that explain how to use the encoder-decoder architecture. E.g., if we put a warning on the EncoderDecoderModel card and write a nice how-to guide, I'm quite sure that people won't run into this issue too much. Also, it's not really a bug in the code, but something the user should be made aware of IMO.
    Ideally, the user passes an attention_mask, in which case there is no need to have the padding token correctly defined.
  2. Also a good point, but this behavior really should not occur. The only time there could be padding token ids in the decoder_input_ids tensor is when the user passes those themselves, and that's quite an edge case for encoder-decoder models. I.e., passing both input_ids (like an article to summarize) and a prompt with which to start the summary is still quite an edge case for me. In this case, the user should (again) also pass the decoder_attention_mask, so that this problem would be solved correctly IMO. Note that no model should generate a pad_token_id -> this would be a modeling bug, and then we really can't expect the model to generate anything useful at all.

  3. Good point. I agree that we should probably align the methods, but it's quite difficult now because of possibly breaking backwards compatibility. We've more or less decided not to automatically create the attention_mask from the padding token because:

  • some models, like GPT2, don't have a padding token id; what do we do then?
  • a user might want to attend to a padding token, and by forcibly masking the padding token this use case would not be possible anymore (think of QA, where this can be the case).

    However, I do agree that in 95% of the cases the padding token should be masked, so it does make sense to create a warning in every model if no attention_mask is provided but the input_ids contain the padding token (see the sketch below).
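
Something along these lines, just as a rough illustration (the helper is made up and not part of the actual transformers code):

```python
import warnings

import torch


def warn_if_pad_without_mask(input_ids: torch.Tensor, pad_token_id, attention_mask=None):
    # Hypothetical check: warn when padding is present in input_ids but no
    # attention_mask was passed, so the padded positions would be attended to.
    if attention_mask is None and pad_token_id is not None and (input_ids == pad_token_id).any():
        warnings.warn(
            "input_ids contain the pad token but no attention_mask was passed; "
            "the padded positions will be attended to. Pass an attention_mask "
            "explicitly if this is not intended."
        )
```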

@patrickvonplaten
Contributor

Also cc #4483 (comment)

@patrickvonplaten
Contributor

@LarsHill - would you maybe be interested in tackling this issue by improving the docs a bit: #16135 ? :-)

@patrickvonplaten
Contributor

@LarsHill
Author

LarsHill commented Apr 1, 2022

@LarsHill - would you maybe be interested in tackling this issue by improving the docs a bit: #16135 ? :-)

Hi,
First of all, thanks for the extensive reply! I agree that most of my concerns could be tackled by improving the documentation. After all, it is not code-breaking since there are workarounds. It is just that I ran into some of the mentioned problems myself and had to dig deep into the code base to understand what was going on.

Regarding contributing to the documentation, I cannot promise to get to it any time soon, since I'm quite occupied with project and research work at the moment. But I'll keep this issue in mind and get back to it. If no one else has contributed in the meantime, I'll take a shot.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
