From 447c0146f509c2e5301e573556387ff0cadcc956 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AA=91=E9=A9=AC=E5=B0=8F=E7=8C=AB?= <1435130236@qq.com> Date: Wed, 7 Sep 2022 17:59:09 +0800 Subject: [PATCH] [ModelingOutput]add more output for skep model (#3146) * update return_dict/label in skep model * complete skep add-more-output * refactor simple code Co-authored-by: Zhong Hui Co-authored-by: Guo Sheng Co-authored-by: liu zhengxi <380185688@qq.com> --- paddlenlp/transformers/skep/modeling.py | 227 +++++++++++++++++++---- tests/transformers/skep/test_modeling.py | 145 ++++++++++----- 2 files changed, 290 insertions(+), 82 deletions(-) diff --git a/paddlenlp/transformers/skep/modeling.py b/paddlenlp/transformers/skep/modeling.py index a65da0af5acc..9b1ddd71e5e2 100644 --- a/paddlenlp/transformers/skep/modeling.py +++ b/paddlenlp/transformers/skep/modeling.py @@ -25,6 +25,15 @@ else: from paddlenlp.layers.crf import ViterbiDecoder +from ..model_outputs import ( + BaseModelOutputWithPoolingAndCrossAttentions, + SequenceClassifierOutput, + TokenClassifierOutput, + QuestionAnsweringModelOutput, + MultipleChoiceModelOutput, + MaskedLMOutput, + CausalLMOutputWithCrossAttentions, +) from .. import PretrainedModel, register_base_model __all__ = [ @@ -284,7 +293,10 @@ def forward(self, input_ids, token_type_ids=None, position_ids=None, - attention_mask=None): + attention_mask=None, + output_hidden_states=False, + output_attentions=False, + return_dict=False): r""" The SkepModel forward method, overrides the `__call__()` special method. @@ -319,9 +331,23 @@ def forward(self, For example, its shape can be [batch_size, sequence_length], [batch_size, sequence_length, sequence_length], [batch_size, num_attention_heads, sequence_length, sequence_length]. Defaults to `None`, which means nothing needed to be prevented attention to. + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.ModelOutput` object. If `False`, the output + will be a tuple of tensors. Defaults to `False`. Returns: - tuple: Returns tuple (`sequence_output`, `pooled_output`). + An instance of :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPoolingAndCrossAttentions` if + `return_dict=True`. Otherwise it returns a tuple of tensors corresponding + to ordered and not None (depending on the input arguments) fields of + :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPoolingAndCrossAttentions`. + + if the reuslt is tuple: Returns tuple (`sequence_output`, `pooled_output`). 
With the fields: @@ -356,10 +382,26 @@ def forward(self, embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids) - encoder_outputs = self.encoder(embedding_output, attention_mask) - sequence_output = encoder_outputs + encoder_outputs = self.encoder( + embedding_output, + attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + + if paddle.is_tensor(encoder_outputs): + encoder_outputs = (encoder_outputs, ) + + sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) - return sequence_output, pooled_output + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions) def get_input_embeddings(self) -> nn.Embedding: """get skep input word embedding @@ -409,7 +451,11 @@ def forward(self, input_ids, token_type_ids=None, position_ids=None, - attention_mask=None): + attention_mask=None, + labels=None, + output_hidden_states=False, + output_attentions=False, + return_dict=False): r""" The SkepForSequenceClassification forward method, overrides the __call__() special method. @@ -422,10 +468,25 @@ def forward(self, See :class:`SkepModel`. attention_mask (Tensor, optional): See :class:`SkepModel`. + labels (Tensor of shape `(batch_size,)`, optional): + Labels for computing the sequence classification/regression loss. + Indices should be in `[0, ..., num_classes - 1]`. If `num_classes == 1` + a regression loss is computed (Mean-Square loss), If `num_classes > 1` + a classification loss is computed (Cross-Entropy). + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - Tensor: Returns tensor `logits`, a tensor of the input text classification logits. - Shape as `[batch_size, num_classes]` and dtype as float32. + An instance of :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput`. Example: .. 
code-block:: @@ -441,14 +502,46 @@ def forward(self, logits = model(**inputs) """ - _, pooled_output = self.skep(input_ids, - token_type_ids=token_type_ids, - position_ids=position_ids, - attention_mask=attention_mask) + outputs = self.skep(input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + + pooled_output = outputs[1] pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) - return logits + + loss = None + if labels is not None: + if self.num_classes == 1: + loss_fct = paddle.nn.MSELoss() + loss = loss_fct(logits, labels) + elif labels.dtype == paddle.int64 or labels.dtype == paddle.int32: + loss_fct = paddle.nn.CrossEntropyLoss() + loss = loss_fct(logits.reshape((-1, self.num_classes)), + labels.reshape((-1, ))) + else: + loss_fct = paddle.nn.BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits, ) + outputs[2:] + if loss is not None: + return (loss, ) + output + if len(output) == 1: + return output[0] + return output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) class SkepForTokenClassification(SkepPretrainedModel): @@ -482,7 +575,11 @@ def forward(self, input_ids, token_type_ids=None, position_ids=None, - attention_mask=None): + attention_mask=None, + labels=None, + output_hidden_states=False, + output_attentions=False, + return_dict=False): r""" The SkepForTokenClassification forward method, overrides the __call__() special method. @@ -495,10 +592,22 @@ def forward(self, See :class:`SkepModel`. attention_mask (Tensor, optional): See :class:`SkepModel`. + labels (Tensor of shape `(batch_size, sequence_length)`, optional): + Labels for computing the token classification loss. Indices should be in `[0, ..., num_classes - 1]`. + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - Tensor: Returns tensor `logits`, a tensor of the input token classification logits. - Shape as `[batch_size, sequence_length, num_classes]` and dtype as `float32`. + An instance of :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput`. Example: .. 
code-block:: @@ -514,14 +623,39 @@ def forward(self, logits = model(**inputs) """ - sequence_output, _ = self.skep(input_ids, - token_type_ids=token_type_ids, - position_ids=position_ids, - attention_mask=attention_mask) + outputs = self.skep(input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + + sequence_output = outputs[0] sequence_output = self.dropout(sequence_output) logits = self.classifier(sequence_output) - return logits + + loss = None + if labels is not None: + loss_fct = paddle.nn.CrossEntropyLoss() + loss = loss_fct(logits.reshape((-1, self.num_classes)), + labels.reshape((-1, ))) + + if not return_dict: + output = (logits, ) + outputs[2:] + if loss is not None: + return (loss, ) + output + if len(output) == 1: + return output[0] + return output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) class SkepCrfForTokenClassification(SkepPretrainedModel): @@ -564,7 +698,10 @@ def forward(self, position_ids=None, attention_mask=None, seq_lens=None, - labels=None): + labels=None, + output_hidden_states=False, + output_attentions=False, + return_dict=False): r""" The SkepCrfForTokenClassification forward method, overrides the __call__() special method. @@ -584,9 +721,22 @@ def forward(self, labels (Tensor, optional): The input label tensor. Its data type should be int64 and its shape is `[batch_size, sequence_length]`. + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - Tensor: Returns tensor `loss` if `labels` is not None. Otherwise, returns tensor `prediction`. + An instance of :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput`. + + if return_dict is False, Returns tensor `loss` if `labels` is not None. Otherwise, returns tensor `prediction`. - `loss` (Tensor): The crf loss. Its data type is float32 and its shape is `[batch_size]`. @@ -596,13 +746,15 @@ def forward(self, Its data type is int64 and its shape is `[batch_size, sequence_length]`. 
""" - sequence_output, _ = self.skep(input_ids, - token_type_ids=token_type_ids, - position_ids=position_ids, - attention_mask=attention_mask) - - bigru_output, _ = self.gru( - sequence_output) #, sequence_length=seq_lens) + outputs = self.skep(input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + + bigru_output, _ = self.gru(outputs[0]) #, sequence_length=seq_lens) emission = self.fc(bigru_output) if seq_lens is None: @@ -616,9 +768,22 @@ def forward(self, seq_lens = paddle.ones(shape=[input_ids_shape[0]], dtype=paddle.int64) * input_ids_shape[1] + loss, prediction = None, None if labels is not None: loss = self.crf_loss(emission, seq_lens, labels) - return loss else: _, prediction = self.viterbi_decoder(emission, seq_lens) + + # FIXME(wj-Mcat): the output of this old version model is single tensor when return_dict is False + if not return_dict: + # when loss is None, return prediction + if labels is not None: + return loss return prediction + + return TokenClassifierOutput( + loss=loss, + logits=prediction, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/tests/transformers/skep/test_modeling.py b/tests/transformers/skep/test_modeling.py index 03e2ed87cefe..b3016eaf2c58 100644 --- a/tests/transformers/skep/test_modeling.py +++ b/tests/transformers/skep/test_modeling.py @@ -17,6 +17,7 @@ from typing import Optional, Tuple, Dict, Any import paddle from paddle import Tensor +from parameterized import parameterized_class from dataclasses import dataclass, asdict, fields, Field from paddlenlp.transformers import ( @@ -70,6 +71,8 @@ class SkepTestConfig(SkepTestModelConfig): # used for sequence classification num_classes: int = 3 + num_choices: int = 3 + type_sequence_label_size: int = 3 class SkepModelTester: @@ -82,6 +85,11 @@ def __init__(self, parent, config: Optional[SkepTestConfig] = None): self.is_training = self.config.is_training + def __getattr__(self, key: str): + if not hasattr(self.config, key): + raise AttributeError(f'attribute <{key}> not exist') + return getattr(self.config, key) + def prepare_config_and_inputs( self) -> Tuple[Dict[str, Any], Tensor, Tensor, Tensor]: config = self.config @@ -98,23 +106,36 @@ def prepare_config_and_inputs( token_type_ids = ids_tensor([config.batch_size, config.seq_length], config.type_vocab_size) - return config.model_kwargs, input_ids, token_type_ids, input_mask + sequence_labels = None + token_labels = None + choice_labels = None + + if self.parent.use_labels: + sequence_labels = ids_tensor([self.batch_size], + self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], + self.num_classes) + choice_labels = ids_tensor([self.batch_size], self.num_choices) - def create_and_check_model( - self, - config, - input_ids: Tensor, - token_type_ids: Tensor, - input_mask: Tensor, - ): + config = self.get_config() + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def create_and_check_model(self, config, input_ids: Tensor, + token_type_ids: Tensor, input_mask: Tensor, + sequence_labels: Tensor, token_labels: Tensor, + choice_labels: Tensor): model = SkepModel(**config) model.eval() result = model(input_ids, attention_mask=input_mask, - token_type_ids=token_type_ids) - result = model(input_ids, token_type_ids=token_type_ids) - result = model(input_ids) + 
token_type_ids=token_type_ids, + return_dict=self.parent.return_dict) + result = model(input_ids, + token_type_ids=token_type_ids, + return_dict=self.parent.return_dict) + result = model(input_ids, return_dict=self.parent.return_dict) + self.parent.assertEqual(result[0].shape, [ self.config.batch_size, self.config.seq_length, self.config.hidden_size @@ -123,60 +144,83 @@ def create_and_check_model( result[1].shape, [self.config.batch_size, self.config.hidden_size]) def create_and_check_for_sequence_classification( - self, - config, - input_ids: Tensor, - token_type_ids: Tensor, - input_mask: Tensor, - ): + self, config, input_ids: Tensor, token_type_ids: Tensor, + input_mask: Tensor, sequence_labels: Tensor, token_labels: Tensor, + choice_labels: Tensor): model = SkepForSequenceClassification( SkepModel(**config), num_classes=self.config.num_classes) model.eval() - result = model( - input_ids, - attention_mask=input_mask, - token_type_ids=token_type_ids, - ) + result = model(input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + return_dict=self.parent.return_dict, + labels=sequence_labels) + + if token_labels is not None: + result = result[1:] + elif paddle.is_tensor(result): + result = [result] + self.parent.assertEqual( - result.shape, [self.config.batch_size, self.config.num_classes]) + result[0].shape, [self.config.batch_size, self.config.num_classes]) def create_and_check_for_token_classification( - self, - config, - input_ids, - token_type_ids, - input_mask, - ): + self, config, input_ids: Tensor, token_type_ids: Tensor, + input_mask: Tensor, sequence_labels: Tensor, token_labels: Tensor, + choice_labels: Tensor): model = SkepForTokenClassification(SkepModel(**config), num_classes=self.config.num_classes) model.eval() result = model(input_ids, attention_mask=input_mask, - token_type_ids=token_type_ids) - self.parent.assertEqual(result.shape, [ + token_type_ids=token_type_ids, + return_dict=self.parent.return_dict, + labels=token_labels) + + if token_labels is not None: + result = result[1:] + elif paddle.is_tensor(result): + result = [result] + + self.parent.assertEqual(result[0].shape, [ self.config.batch_size, self.config.seq_length, self.config.num_classes ]) def create_and_check_for_crf_token_classification( - self, - config, - input_ids, - token_type_ids, - input_mask, - ): + self, config, input_ids: Tensor, token_type_ids: Tensor, + input_mask: Tensor, sequence_labels: Tensor, token_labels: Tensor, + choice_labels: Tensor): model = SkepCrfForTokenClassification( SkepModel(**config), num_classes=self.config.num_classes) model.eval() result = model(input_ids, attention_mask=input_mask, - token_type_ids=token_type_ids) - self.parent.assertEqual( - result.shape, [self.config.batch_size, self.config.seq_length]) + token_type_ids=token_type_ids, + return_dict=self.parent.return_dict, + labels=token_labels) + # TODO(wj-Mcat): the output of SkepCrfForTokenClassification is wrong + if paddle.is_tensor(result): + result = [result] + + if token_labels is not None: + self.parent.assertEqual(result[0].shape, [self.config.batch_size]) + else: + self.parent.assertEqual( + result[0].shape, + [self.config.batch_size, self.config.seq_length]) def prepare_config_and_inputs_for_common(self): - config, input_ids, token_type_ids, input_mask = self.prepare_config_and_inputs( - ) + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs inputs_dict = 
{ "input_ids": input_ids, "token_type_ids": token_type_ids, @@ -193,12 +237,19 @@ def get_config(self) -> dict: return self.config.model_kwargs +@parameterized_class(("return_dict", "use_labels"), [ + [False, False], + [False, True], + [True, False], + [True, True], +]) class SkepModelTest(ModelTesterMixin, unittest.TestCase): base_model_class = SkepModel + return_dict = False + use_labels = False all_model_classes = ( SkepModel, - # TODO(wj-Mcat): to activate this model later SkepCrfForTokenClassification, SkepForSequenceClassification, SkepForTokenClassification, @@ -207,9 +258,6 @@ class SkepModelTest(ModelTesterMixin, unittest.TestCase): def setUp(self): self.model_tester = SkepModelTester(self) - def get_config(): - pass - def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) @@ -225,11 +273,6 @@ def test_for_token_classification(self): *config_and_inputs) def test_for_crf_token_classification(self): - # TODO(wj-Mcat): to activate this method later - # self.skipTest( - # "skip for crf token classification: there are contains something wrong in paddle.text.viterib_decode" - # ) - # return config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_crf_token_classification( *config_and_inputs)
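The hunks above extend `SkepModel.forward` and the task heads with `output_hidden_states`, `output_attentions`, `return_dict` and (for the heads) `labels`. A minimal usage sketch of the new dataclass return path follows; the `skep_ernie_1.0_large_ch` checkpoint name, the two-class setup and the sample sentence are illustrative assumptions, not part of this patch:

    .. code-block::

        import paddle
        from paddlenlp.transformers import SkepForSequenceClassification, SkepTokenizer

        tokenizer = SkepTokenizer.from_pretrained("skep_ernie_1.0_large_ch")
        model = SkepForSequenceClassification.from_pretrained(
            "skep_ernie_1.0_large_ch", num_classes=2)
        model.eval()

        inputs = {
            k: paddle.to_tensor([v])
            for k, v in tokenizer("Welcome to use PaddlePaddle and PaddleNLP!").items()
        }
        labels = paddle.to_tensor([1], dtype="int64")

        # New behavior introduced by this patch: with return_dict=True the head
        # returns a SequenceClassifierOutput dataclass instead of a bare tensor.
        output = model(**inputs,
                       labels=labels,
                       return_dict=True,
                       output_hidden_states=True,
                       output_attentions=True)

        print(output.loss)                # cross-entropy loss (int64 labels, num_classes > 1)
        print(output.logits.shape)        # [batch_size, num_classes]
        print(len(output.hidden_states))  # hidden states collected from the encoder
        print(len(output.attentions))     # per-layer attention weights

        # Legacy call style still works and still returns just the logits tensor.
        logits = model(**inputs)

Because `return_dict` defaults to `False` and `outputs[2:]` is empty unless extra outputs are requested, a call without `labels` still returns the bare logits tensor, so pre-patch user code keeps working unchanged.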
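`SkepCrfForTokenClassification` keeps its historical single-tensor behaviour when `return_dict=False`, as the FIXME in the hunk above notes: it returns the CRF loss when `labels` are given and the Viterbi prediction otherwise. With `return_dict=True` it wraps the result in a `TokenClassifierOutput`, reusing the `logits` field for the prediction; since the prediction is only decoded in the no-label branch, `logits` is `None` when labels are supplied. A sketch of that contract, again with the illustrative checkpoint and a hypothetical five-tag label set:

    .. code-block::

        import paddle
        from paddlenlp.transformers import (SkepCrfForTokenClassification, SkepModel,
                                            SkepTokenizer)

        tokenizer = SkepTokenizer.from_pretrained("skep_ernie_1.0_large_ch")
        model = SkepCrfForTokenClassification(
            SkepModel.from_pretrained("skep_ernie_1.0_large_ch"), num_classes=5)
        model.eval()

        inputs = {
            k: paddle.to_tensor([v])
            for k, v in tokenizer("Welcome to use PaddlePaddle and PaddleNLP!").items()
        }
        seq_length = inputs["input_ids"].shape[1]
        labels = paddle.randint(0, 5, shape=[1, seq_length], dtype="int64")

        # Legacy behaviour (return_dict=False): a single tensor either way.
        loss = model(**inputs, labels=labels)   # CRF loss, shape [batch_size]
        prediction = model(**inputs)            # Viterbi tags, shape [batch_size, seq_length]

        # New behaviour: a TokenClassifierOutput whose `logits` field carries the
        # Viterbi prediction (and is None when labels are passed).
        output = model(**inputs, return_dict=True)
        print(output.logits.shape)              # [batch_size, seq_length]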
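The test changes run every case four times via `parameterized_class`, once per combination of `return_dict` and `use_labels`, so both the tuple and the dataclass return paths are exercised. A self-contained sketch of that pattern; the class and test names here are illustrative, not taken from the test file:

    .. code-block::

        import unittest
        from parameterized import parameterized_class


        @parameterized_class(("return_dict", "use_labels"), [
            [False, False],
            [False, True],
            [True, False],
            [True, True],
        ])
        class OutputContractTest(unittest.TestCase):
            # Defaults, overridden by parameterized_class for each generated class.
            return_dict = False
            use_labels = False

            def test_flags(self):
                # Each generated class sees its own flag combination, which the
                # SkepModelTester helpers read via `self.parent.return_dict`
                # and `self.parent.use_labels` in the patch above.
                self.assertIsInstance(self.return_dict, bool)
                self.assertIsInstance(self.use_labels, bool)


        if __name__ == "__main__":
            unittest.main()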