Trying to add support for GPT2 as decoder in EncoderDecoder model #4483

Open
dimi1357 opened this issue May 20, 2020 · 31 comments

Comments

@dimi1357

dimi1357 commented May 20, 2020

🚀 Feature request

Hi,
I am trying to add the option of using GPT2 as the decoder in the EncoderDecoder model, which currently only supports BERT as the decoder.

Motivation

For a generation problem, it is usually better to use GPT2 as the decoder, over BERT.

Your contribution

I've made the following changes in modeling_gpt2.py file:

  • Added crossattention layer if the model is a decoder, to the Block class:
class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=False):
        super().__init__()
        nx = config.n_embd
        self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)
        self.is_decoder = config.is_decoder
        if self.is_decoder:
            self.crossattention = Attention(nx, n_ctx, config, scale)
...
    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, encoder_hidden_states=None,
                encoder_attention_mask=None):
        output_attn = self.attn(
            self.ln_1(x),
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
        )
        a = output_attn[0]  # output_attn: a, present, (attentions)
        outputs = []
        if self.is_decoder and encoder_hidden_states is not None:
            cross_attention_outputs = self.crossattention(
                a, layer_past, attention_mask, head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
            )
            a = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights

        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m

        outputs = [x] + output_attn[1:] + outputs

        return outputs  # x, present, (attentions)
  • Added 3 Linear layers (query, key, value) in place of the fused Conv1D layer:
class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
...
        # self.c_attn = Conv1D(n_state * 3, nx)
        self.query = nn.Linear(n_state, nx)
        self.key = nn.Linear(n_state, nx)
        self.value = nn.Linear(n_state, nx)
...
  • Added encoder_attention_mask and encoder_hidden_states to the forward function of the Attention class, and using them for the key and the value if they are provided:
    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, encoder_hidden_states=None,
                encoder_attention_mask=None):
        query = self.query(x)
        if encoder_hidden_states is not None:
            key = self.key(encoder_hidden_states)
            value = self.value(encoder_hidden_states)
            attention_mask = encoder_attention_mask
        else:
            key = self.key(x)
            value = self.value(x)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
...
  • Added the encoder_attention_mask and encoder_hidden_states arguments to the GPT2Model forward function, and processed encoder_attention_mask same as attention_mask:
class GPT2Model(GPT2PreTrainedModel):
...
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=True,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
...
        # Encoder attention mask. (same action as for regular attention mask)
        if encoder_attention_mask is not None:
            assert batch_size > 0, "batch_size has to be defined and > 0"
            encoder_attention_mask = encoder_attention_mask.view(batch_size, -1)
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1).unsqueeze(2)
            encoder_attention_mask = encoder_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            encoder_attention_mask = (1.0 - encoder_attention_mask) * -10000.0
...
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                use_cache=use_cache,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
            )
...
  • Added the encoder_attention_mask and encoder_hidden_states arguments to the GPT2LMHeadModel forward function, as well as lm_labels and masked_lm_labels for EncoderDecoder model compatibility (probably it's better to use GPT2DoubleHeadsModel):
class GPT2LMHeadModel(GPT2PreTrainedModel):
...
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=True,
        lm_labels=None,
        masked_lm_labels=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
...
        if lm_labels is not None:
            if labels is not None:
                raise ValueError("You cannot specify both labels and lm_labels at the same time")
            labels = lm_labels

        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
...
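The third and fourth bullets boil down to cross-attention: queries come from the decoder states, keys and values come from `encoder_hidden_states`, and the encoder padding mask is turned into an additive 0 / -10000 mask as in `GPT2Model.forward`. A minimal, framework-free sketch of that computation, assuming a single head, no learned projections (the encoder states are used directly as keys and values) and plain Python lists instead of tensors; `cross_attention` is a hypothetical helper, not a transformers function. It applies no causal mask, since every decoder position may look at the whole encoder input; whether the elided `_attn` code above still applies GPT-2's causal `bias` buffer to the cross-attention scores is worth checking.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(decoder_states, encoder_states, encoder_mask):
    """Single-head scaled dot-product cross-attention (illustrative sketch).

    decoder_states: list of query vectors, one per decoder position
    encoder_states: list of vectors, one per encoder position (used as keys and values)
    encoder_mask:   1 for real encoder tokens, 0 for padding
    """
    dim = len(decoder_states[0])
    # Same conversion as GPT2Model.forward: 0.0 where attended, -10000.0 where padded.
    additive_mask = [(1.0 - m) * -10000.0 for m in encoder_mask]
    outputs, weights = [], []
    for query in decoder_states:
        # No causal mask: each decoder position scores every encoder position.
        scores = [
            sum(q * k for q, k in zip(query, key)) / math.sqrt(dim) + bias
            for key, bias in zip(encoder_states, additive_mask)
        ]
        probs = softmax(scores)
        # Weighted sum of the encoder vectors (the "values").
        outputs.append([
            sum(p * value[i] for p, value in zip(probs, encoder_states))
            for i in range(len(encoder_states[0]))
        ])
        weights.append(probs)
    return outputs, weights


encoder_states = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]  # third position is padding
decoder_states = [[1.0, 1.0], [0.5, -0.5]]
out, attn = cross_attention(decoder_states, encoder_states, [1, 1, 0])
```

The padded encoder position receives (numerically) zero attention weight from every decoder position, even though its raw dot product is the largest.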

My biggest concern is with the second bullet, and I wanted to ask you whether this implementation seems right (for now it looks like I am able to train and test an EncoderDecoder with a BERT2GPT architecture).
Of course, if needed, I can provide the full code for all of my changes, but they are all listed above.
Most (if not all) of the code I've added is adapted from the Hugging Face modeling_bert.py file, so all of the credit goes to them.
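On the second bullet: in transformers, GPT-2's `Conv1D` is a linear layer whose weight is stored as (nx, nf) and which computes x @ W + b, whereas `nn.Linear` stores (out_features, in_features) and computes x @ W.T + b. Swapping `c_attn` for three fresh `nn.Linear` layers is therefore mathematically equivalent, but the pretrained `c_attn` weights will not be loaded into `query`/`key`/`value` by name; they would have to be sliced and transposed. A minimal sketch of that mapping on plain Python lists, assuming the fused output is ordered query, key, value (as the `split(self.split_size, dim=2)` call in GPT-2's attention does); all function names are hypothetical:

```python
def conv1d_forward(x, weight, bias):
    """GPT-2 style Conv1D on one vector: y = x @ W + b, W shaped (nx, nf)."""
    return [
        sum(x[i] * weight[i][j] for i in range(len(x))) + bias[j]
        for j in range(len(bias))
    ]


def linear_forward(x, weight, bias):
    """nn.Linear style on one vector: y = x @ W.T + b, W shaped (out_features, in_features)."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def split_c_attn(weight, bias, n_state):
    """Slice a fused (nx, 3 * n_state) c_attn weight into query/key/value Linear weights."""
    pieces = []
    for part in range(3):
        cols = range(part * n_state, (part + 1) * n_state)
        # Transpose the column slice into (out_features, in_features) rows.
        linear_weight = [[row[j] for row in weight] for j in cols]
        linear_bias = [bias[j] for j in cols]
        pieces.append((linear_weight, linear_bias))
    return pieces  # [(query_w, query_b), (key_w, key_b), (value_w, value_b)]


# Toy sizes: nx = n_state = 2, so c_attn has 3 * 2 = 6 output features.
weight = [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
          [7.0, 8.0, 9.0, 10.0, 11.0, 12.0]]
bias = [0.5, -0.5, 1.0, -1.0, 2.0, -2.0]
x = [1.0, 2.0]

fused = conv1d_forward(x, weight, bias)
split = [linear_forward(x, w, b) for w, b in split_c_attn(weight, bias, n_state=2)]
```

Concatenating the three Linear outputs reproduces the fused Conv1D output exactly. The new `crossattention` module has no pretrained counterpart either way, so only the self-attention weights are affected by this mapping.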

Thanks

@sam-writer
Contributor

@dimi1357 out of curiosity, what does training this look like?

@dimi1357
Author

dimi1357 commented May 24, 2020

This is my training loop:

x, encoder_attention_mask, y, decoder_attention_mask, _ = batch
x = x.to(self.device)
y = y.to(self.device)
encoder_attention_mask = encoder_attention_mask.to(self.device)
decoder_attention_mask = decoder_attention_mask.to(self.device)
model_kwargs = {
    "attention_mask": encoder_attention_mask,
    "decoder_attention_mask": decoder_attention_mask,
    "lm_labels": y
}
self.optimizer.zero_grad()
outputs = self.model(input_ids=x, decoder_input_ids=y, **model_kwargs)
loss = outputs[0]
loss.backward()
self.optimizer.step()
if self.scheduler is not None:
    self.scheduler.step()

and I create the model this way:

from transformers import AutoConfig, EncoderDecoderConfig, EncoderDecoderModel

config_decoder = AutoConfig.from_pretrained(decoder_model_name, is_decoder=True)
config_encoder = AutoConfig.from_pretrained(encoder_model_name, is_decoder=False)
config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
res_model = EncoderDecoderModel(config=config)
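In this loop the same tensor `y` is passed as `decoder_input_ids` and as `lm_labels`. That is consistent because `GPT2LMHeadModel` shifts the labels internally (the logits at position t are scored against label t + 1), and `CrossEntropyLoss` skips targets equal to -100 by default. Padding in `y` is only excluded from the loss if it is set to -100; the decoder attention mask alone does not remove it from the loss. A plain-Python sketch of which (position, target) pairs end up in the loss; `next_token_pairs` is a hypothetical helper and the token ids are illustrative:

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def next_token_pairs(labels):
    """Return (position, target) pairs scored by a causal LM loss.

    Mirrors the internal shift: the logits at position t are compared with
    labels[t + 1], so the last logit and the first label are dropped.
    Positions whose target is IGNORE_INDEX do not contribute.
    """
    return [
        (position, target)
        for position, target in enumerate(labels[1:])
        if target != IGNORE_INDEX
    ]


y = [50256, 464, 3290, 50256]  # toy decoder ids: BOS, two tokens, EOS
padded = [50256, 464, IGNORE_INDEX, IGNORE_INDEX]  # trailing positions masked out
```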

@manzar96

manzar96 commented May 24, 2020

@dimi1357 Did you finally make it work? Can you provide me the "full changes" in some way? I am also interested in using the GPT2 model as decoder.

@patrickvonplaten
Contributor

Thanks for the Feature request and the in-detail code! I will think a bit more about how to implement this and get back to you!

@dimi1357
Author

I forgot to add the change I've made to Block class forward function (I've also edited the issue):

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, encoder_hidden_states=None,
                encoder_attention_mask=None):
        output_attn = self.attn(
            self.ln_1(x),
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
        )
        a = output_attn[0]  # output_attn: a, present, (attentions)
        outputs = []
        if self.is_decoder and encoder_hidden_states is not None:
            cross_attention_outputs = self.crossattention(
                a, layer_past, attention_mask, head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
            )
            a = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights

        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m

        outputs = [x] + output_attn[1:] + outputs

        return outputs  # x, present, (attentions)

@dimi1357
Author

You can add the code above to the transformers package where you've installed it, but I'm still not sure this implementation is correct, so I suggest waiting for the Hugging Face team to confirm whether it is okay.

@patrickvonplaten
Contributor

Hey @dimi1357. So I think the Encoder Decoder roadmap is as follows:

  • In ~2 weeks, we will open-source a clean notebook showing how a Bert2Bert model can be fine-tuned
  • After that, we will take a deeper look into hooking GPT2 into the EncoderDecoder framework.

I will keep your code sample here in mind for this :-)

@stale

stale bot commented Jul 29, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the wontfix label Jul 29, 2020
@dimi1357
Author

Hi,
Are there any updates regarding the BERT2GPT implementation?
Thanks!

@stale stale bot removed the wontfix label Jul 29, 2020
@patrickvonplaten
Contributor

Hey, I will take a look at the BERT2GPT2 encoder-decoder, probably on Monday next week.

@iliemihai

@patrickvonplaten Can you please share a work in progress notebook/colab, or some code. I am willing to help with tests and datasets, in order to improve the BERT2GPT2 model. Thank you :D

@patrickvonplaten
Contributor

Will finish the PR tomorrow then it should be pretty easy to do BERT2GPT2.

@iliemihai

iliemihai commented Aug 17, 2020

Hi @patrickvonplaten. I've used your latest commit to train BERT2GPT2 following your BERT2BERT training tutorial. It was straightforward; I only had to replace "bert" with "gpt2" for the decoder. The training worked, but at inference time there was an error in prepare_inputs_for_generation at line 299:

/transformers/modeling_encoder_decoder.py
297         # first step
298         if type(past) is tuple:
299             encoder_outputs, _ = past  <----
300         else:
301             encoder_outputs = (past,)

ValueError: too many values to unpack (expected 2)

I do not know if the model requires a different evaluation approach.

@dimi1357
Author

Thanks for the implementation, I'm going to test it now.

@patrickvonplaten
Contributor

GPT2 is added and results on summarization look promising. Check out this model (Bert2GPT2 trained on CNN/Daily Mail), including the train and eval script: https://huggingface.co/patrickvonplaten/bert2gpt2-cnn_dailymail-fp16

@AmbiTyga
Contributor

AmbiTyga commented Sep 4, 2020

Hi @patrickvonplaten, I used this model card to train on my custom dataset, but again a TypeError is thrown: forward() got an unexpected keyword argument 'encoder_hidden_states'.
Here is my code:

import nlp
import logging
from transformers import BertTokenizer, GPT2Tokenizer, EncoderDecoderModel, Trainer, TrainingArguments

logging.basicConfig(level=logging.INFO)

model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "gpt2")
# cache is currently not supported by EncoderDecoder framework
model.decoder.config.use_cache = False
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# CLS token will work as BOS token
bert_tokenizer.bos_token = bert_tokenizer.cls_token

# SEP token will work as EOS token
bert_tokenizer.eos_token = bert_tokenizer.sep_token


# make sure GPT2 adds BOS at the beginning and EOS at the end
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
    outputs = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
    return outputs


GPT2Tokenizer.build_inputs_with_special_tokens = build_inputs_with_special_tokens
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# set pad_token_id to unk_token_id -> be careful here as unk_token_id == eos_token_id == bos_token_id
gpt2_tokenizer.pad_token = gpt2_tokenizer.unk_token


# set decoding params
model.config.decoder_start_token_id = gpt2_tokenizer.bos_token_id
model.config.eos_token_id = gpt2_tokenizer.eos_token_id
model.config.max_length = 142
model.config.min_length = 56
model.config.no_repeat_ngram_size = 3
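# generation settings are read from model.config, not from attributes set on the model itself
model.config.early_stopping = True
model.config.length_penalty = 2.0
model.config.num_beams = 4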

# load train and validation data
train_dataset = nlp.load_dataset('csv', data_files='data.csv',split='train[:80%]')
val_dataset = nlp.load_dataset('csv', data_files='data.csv',split='train[80%:]')

# load rouge for validation
rouge = nlp.load_metric("rouge", experiment_id=1)

encoder_length = 512
decoder_length = 128
batch_size = 16


# map data correctly
def map_to_encoder_decoder_inputs(batch):    # Tokenizer will automatically set [BOS] <text> [EOS] 
    # use bert tokenizer here for encoder
    # calling the tokenizer directly handles the batched list of strings (encode_plus expects a single example)
    inputs = bert_tokenizer(batch["Patient"], padding="max_length", truncation=True, max_length=encoder_length)
    # force summarization <= 128
    outputs = gpt2_tokenizer(batch["Doctor"], padding="max_length", truncation=True, max_length=decoder_length)

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask
    batch["decoder_input_ids"] = outputs.input_ids
    batch["labels"] = outputs.input_ids.copy()
    batch["decoder_attention_mask"] = outputs.attention_mask

    # complicated list comprehension here because pad_token_id alone is not good enough to know whether label should be excluded or not
    batch["labels"] = [
        [-100 if mask == 0 else token for mask, token in mask_and_tokens] for mask_and_tokens in [zip(masks, labels) for masks, labels in zip(batch["decoder_attention_mask"], batch["labels"])]
    ]

    assert all([len(x) == encoder_length for x in inputs.input_ids])
    assert all([len(x) == decoder_length for x in outputs.input_ids])

    return batch


def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    # all unnecessary tokens are removed
    pred_str = gpt2_tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = gpt2_tokenizer.eos_token_id
    label_str = gpt2_tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }


# make train dataset ready
train_dataset = train_dataset.map(
    map_to_encoder_decoder_inputs, batched=True, batch_size=batch_size, remove_columns=["Patient", "Doctor"],
)
train_dataset.set_format(
    type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

# same for validation dataset
val_dataset = val_dataset.map(
    map_to_encoder_decoder_inputs, batched=True, batch_size=batch_size, remove_columns=["Patient", "Doctor"],
)
val_dataset.set_format(
    type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

# set training arguments - these params are not really tuned, feel free to change
training_args = TrainingArguments(
    output_dir="./ambi",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    evaluate_during_training=True,
    do_train=True,
    do_eval=True,
    logging_steps=1000,
    save_steps=1000,
    eval_steps=1000,
    overwrite_output_dir=True,
    warmup_steps=2000,
    save_total_limit=10,
    fp16=True,
)

# instantiate trainer
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

# start training
trainer.train()

If you look carefully, you can see that an argument is missing from TrainingArguments: I always get an error when predict_from_generate is passed, and I could not find that attribute in training_args.py. Please clarify which version you are using; if it is above 2.11, why is the code above throwing this error?
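The comment in `map_to_encoder_decoder_inputs` about pad_token_id not being good enough is because, once `pad_token` is set to `unk_token`, GPT-2's pad, unk, BOS and EOS tokens all share one id (50256, `<|endoftext|>`, in the `gpt2` vocabulary). Comparing against the pad id would therefore also hide the real BOS/EOS tokens from the loss, so the decoder attention mask is used instead. A more readable equivalent of the nested comprehension, sketched on plain lists with hypothetical helper names:

```python
IGNORE_INDEX = -100
EOS_ID = 50256  # GPT-2's <|endoftext|>: BOS, EOS and, in this setup, also PAD/UNK


def mask_labels(batch_ids, batch_masks):
    """Set label positions whose attention mask is 0 to IGNORE_INDEX."""
    return [
        [token if mask == 1 else IGNORE_INDEX for token, mask in zip(ids, masks)]
        for ids, masks in zip(batch_ids, batch_masks)
    ]


def mask_by_pad_id(batch_ids, pad_token_id):
    """The naive alternative: also hides the real BOS/EOS when pad == eos."""
    return [[IGNORE_INDEX if t == pad_token_id else t for t in ids] for ids in batch_ids]


ids = [[EOS_ID, 464, 3290, EOS_ID, EOS_ID, EOS_ID]]  # BOS, two tokens, EOS, two pads
masks = [[1, 1, 1, 1, 0, 0]]
```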

@patrickvonplaten
Contributor

You need to switch to this branch: https://github.com/huggingface/transformers/tree/more_general_trainer_metric to make the training work. I am trying to integrate this branch into master soon :-)

@AmbiTyga
Contributor

AmbiTyga commented Sep 4, 2020

Thanks for letting me know.

@One-punch24

Sorry to ask a question after such a long time :-). I am still not very clear about the effect of the encoder attention mask in GPT2.

I understand that it is used only in the decoder of an Encoder-Decoder model to modify the cross-attention weights. I also noticed the operation defined in modeling_gpt2.py:
attention_mask = encoder_attention_mask
...
w = w + attention_mask

However, I am confused about why we need this encoder attention mask. Is it also because the decoder cannot see the whole sequence?

Thanks for help :-)
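As far as the transformers implementation goes, the two masks solve different problems. In decoder self-attention, GPT-2's lower-triangular (causal) mask keeps position i from seeing later decoder tokens. In cross-attention the keys are encoder states, which are all available at once, so no causal restriction is needed; what remains is padding. In a batch, shorter encoder inputs are padded, and `encoder_attention_mask` keeps the decoder from attending to those pad positions once it is turned into the additive form added to `w`. A small sketch contrasting the two masks as 0/1 matrices of shape (decoder length, key length), in plain Python with hypothetical helper names:

```python
def causal_mask(tgt_len):
    """Decoder self-attention: query i may attend decoder keys 0..i (1 = allowed)."""
    return [[1 if k <= q else 0 for k in range(tgt_len)] for q in range(tgt_len)]


def cross_attention_mask(tgt_len, encoder_mask):
    """Cross-attention: every decoder query sees the same encoder keys;
    only padded encoder positions (0 in encoder_mask) are hidden."""
    return [list(encoder_mask) for _ in range(tgt_len)]


def to_additive(mask):
    """0/1 mask -> additive mask for the attention scores w (0 keep, -10000 drop)."""
    return [[(1.0 - m) * -10000.0 for m in row] for row in mask]


self_mask = causal_mask(3)
cross_mask = cross_attention_mask(3, [1, 1, 1, 0])  # 3 real encoder tokens + 1 pad
```

Every row of the cross-attention mask is identical, which is why it can be stored once per example with shape (batch, 1, 1, src_len) and broadcast over heads and decoder positions.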

@Eurus-W

Eurus-W commented Mar 12, 2022

@AmbiTyga @patrickvonplaten Is this error fixed? I have switched to the branch "more_general_trainer_metric", but the error still seems to occur when I run the code from https://huggingface.co/patrickvonplaten/bert2gpt2-cnn_dailymail-fp16.

@patrickvonplaten
Contributor

The code there is a bit outdated. You should be able to simply use the https://github.com/huggingface/transformers/tree/master/examples/pytorch/summarization example. To create a BERT2GPT2 checkpoint, you can use code similar to this one: https://huggingface.co/docs/transformers/v4.17.0/en/model_doc/encoder-decoder#transformers.EncoderDecoderModel.forward

(just replace one BERT by GPT2)

So to summarize,

  1. Create a warm-started bert-gpt2 checkpoint
  2. Save the checkpoint
  3. Use the summarization example to fine-tune the checkpoint

I'll keep this issue open for now since we should probably create a nice "How-to" guide for this

@Eurus-W

Eurus-W commented Mar 15, 2022

Thanks for your guidance! I tried this method to create and fine-tune a bert2gpt2 model, but the tokenizer seems to be a problem: I can't load a single suitable tokenizer for this model in the summarization example. So do I need to define tokenizer1 for BERT and tokenizer2 for GPT2, and then change all the code related to "tokenizer" to fix this? @patrickvonplaten

@patrickvonplaten
Contributor

It's fine to load two tokenizers, no?

@Eurus-W

Eurus-W commented Mar 18, 2022

Yeah, I used two tokenizers to replace "tokenizer" in run_summarization.py and also made some other changes; the code works now (although I don't know whether it is right...). Here are my changes:

  1. Changed the resize_token_embeddings call:
     # model.resize_token_embeddings(len(tokenizer))
     model.encoder.resize_token_embeddings(len(tokenizer1))
     model.decoder.resize_token_embeddings(len(tokenizer2))
  2. Set some special tokens following https://huggingface.co/patrickvonplaten/bert2gpt2-cnn_dailymail-fp16
  3. Ran into a problem like seq2seq BertGeneration model failed "ValueError: You have to specify either input_ids or inputs_embeds" #10646 (comment), and used the code at https://github.com/huggingface/transformers/blob/24e2fa1590faac894da3422daf56abf9770c9d81/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py#L555 (lines 554-555 and 147-162).
  4. In BERT base/large, "max_position_embeddings" is 512, while the default max_source_length in run_summarization.py is 1024, so if the input sequence is longer than 512 tokens we get an error like Script run_mlm_no_trainer.py error #15081 (comment). So I set max_source_length=512.
  5. Replaced tokenizer with tokenizer2 in all the target-side code segments of run_summarization.py (not sure about this):
        # Setup the tokenizer for targets
        with tokenizer2.as_target_tokenizer():
            labels = tokenizer2(targets, max_length=max_target_length, padding=padding, truncation=True)

        # If we are padding here, replace all tokenizer2.pad_token_id in the labels by -100 when we want to ignore
        # padding in the loss.
        if padding == "max_length" and data_args.ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(l if l != tokenizer2.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs
    ...
    def compute_metrics(eval_preds):
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        decoded_preds = tokenizer2.batch_decode(preds, skip_special_tokens=True)
        if data_args.ignore_pad_token_for_loss:
            # Replace -100 in the labels as we can't decode them (use the decoder-side pad id).
            labels = np.where(labels != -100, labels, tokenizer2.pad_token_id)
        decoded_labels = tokenizer2.batch_decode(labels, skip_special_tokens=True)
        ...
    ...
        if trainer.is_world_process_zero():
            if training_args.predict_with_generate:
                predictions = tokenizer2.batch_decode(
                    predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
                predictions = [pred.strip() for pred in predictions]
                output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
                with open(output_prediction_file, "w") as writer:
                    writer.write("\n".join(predictions))
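In `compute_metrics`, the -100 entries must be mapped back to a real token id before `batch_decode`, and that id should come from the tokenizer doing the decoding: `tokenizer2`, the GPT-2 side, whose pad token is set following the bert2gpt2 model card. If `tokenizer` still refers to the BERT tokenizer, its pad id (0) would be decoded by GPT-2 as the ordinary token "!" instead of being dropped as a special token. A plain-Python sketch of what the `np.where` line does; the helper name is hypothetical:

```python
IGNORE_INDEX = -100


def restore_pad_ids(label_rows, pad_token_id):
    """Plain-list equivalent of np.where(labels != -100, labels, pad_token_id)."""
    return [[pad_token_id if t == IGNORE_INDEX else t for t in row] for row in label_rows]


gpt2_pad_id = 50256  # pad set to <|endoftext|>, following the bert2gpt2 model card
labels = [[464, 3290, 50256, -100, -100]]
restored = restore_pad_ids(labels, gpt2_pad_id)
```

With `skip_special_tokens=True`, the restored `<|endoftext|>` ids are then removed during decoding.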

@IdoAmit198

Hey everyone,
Did this work go anywhere?
I need a pre-trained GPT2 model based on nn.Linear instead of Conv1D layers for research purposes. Is the implementation above merged anywhere, or is there some other GPT2 model based on nn.Linear?

@Forpee

Forpee commented Apr 27, 2023

Can I work on this issue as a good first issue or is there no point?

@sgugger
Collaborator

sgugger commented Apr 27, 2023

I don't think there is any point @Forpee

@Bachstelze

For a generation problem, it is usually better to use GPT2 as the decoder, over BERT.

Why should this be the case, if you have enough data to train the new cross-attention parameters?

The paper for the EncoderDecoderModel reports for the summarization task:
[Image: results table for summarization on CNN/Daily Mail (sum_daily_cnn)]

@sameearif

Hello, can you share the training code for Bert2GPT and Roberta2GPT please?

@Bachstelze

You can just use the current implementation described in the docs:

from transformers import EncoderDecoderModel

# initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
model = EncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2")

# saving model after fine-tuning
model.save_pretrained("./bert2gpt2")

# load fine-tuned model
model = EncoderDecoderModel.from_pretrained("./bert2gpt2")

Why do you want to do that, given the cited performance reduction?

@sameearif

I am trying to train it on a question generation task to compare the results.
