From 01be7792c846cfd14a104eff5ed163f7447e0f0f Mon Sep 17 00:00:00 2001 From: westfish Date: Tue, 11 Oct 2022 10:38:37 +0000 Subject: [PATCH 1/4] add qg-taskflow --- docs/model_zoo/taskflow.md | 51 ++- paddlenlp/taskflow/question_generation.py | 490 ++++++++++++++++++++++ paddlenlp/taskflow/taskflow.py | 17 + paddlenlp/transformers/unimo/modeling.py | 21 + paddlenlp/transformers/unimo/tokenizer.py | 5 + 5 files changed, 583 insertions(+), 1 deletion(-) create mode 100644 paddlenlp/taskflow/question_generation.py diff --git a/docs/model_zoo/taskflow.md b/docs/model_zoo/taskflow.md index ee14dc79c12d..299dd703b2b9 100644 --- a/docs/model_zoo/taskflow.md +++ b/docs/model_zoo/taskflow.md @@ -44,7 +44,7 @@ PaddleNLP提供**开箱即用**的产业级NLP预置任务能力,无需训练 | [文图生成](#文图生成) | `Taskflow("text_to_image")` | ✅ | ✅ | ✅ | | | 文图生成大模型 | | [文本摘要](#文本摘要) | `Taskflow("text_summarization")` | ✅ | ✅ | ✅ | ✅ | | 文本摘要大模型 | | [文档智能](#文档智能) | `Taskflow("document_intelligence")` | ✅ | ✅ | ✅ | ✅ | | 基于跨模态通用文档预训练模型ERNIE-LayoutX | - +| [问题生成](#问题生成) | `Taskflow("question_generation")` | ✅ | ✅ | ✅ | ✅ | | 问题生成大模型 | ## QuickStart @@ -1620,6 +1620,55 @@ from paddlenlp import Taskflow +### 问题生成 +
  通过UNIMO-Text模型来根据上下文和答案生成问题
+
+#### 支持单条、批量预测
+
+```python
+>>> from paddlenlp import Taskflow
+# 默认模型为 unimo-text-1.0-dureader_qg-template1
+>>> question_generator = Taskflow("question_generation")
+# 单条输入
+>>> question_generator([
+  {"context": "奇峰黄山千米以上的山峰有77座,整座黄山就是一座花岗岩的峰林,自古有36大峰,36小峰,最高峰莲花峰、最险峰天都峰和观日出的最佳点光明顶构成黄山的三大主峰。", "answer": "莲花峰"}
+  ])
+'''
+  ['黄山最高峰是什么']
+'''
+# 多条输入
+>>> question_generator([
+  {"context": "奇峰黄山千米以上的山峰有77座,整座黄山就是一座花岗岩的峰林,自古有36大峰,36小峰,最高峰莲花峰、最险峰天都峰和观日出的最佳点光明顶构成黄山的三大主峰。", "answer": "莲花峰"},
+  {"context": "弗朗索瓦·韦达外文名:franciscusvieta国籍:法国出生地:普瓦图出生日期:1540年逝世日期:1603年12月13日职业:数学家主要成就:为近代数学的发展奠定了基础。", "answer": "法国"}
+  ])
+'''
+  ['黄山最高峰是什么', '弗朗索瓦是哪里人']
+'''
+```
+
+#### 可配置参数说明
+* `model`:可选模型,默认为unimo-text-1.0-dureader_qg-template1,支持的模型有["unimo-text-1.0", "unimo-text-1.0-dureader_qg-template1"]。
+* `device`:运行设备,默认为"gpu"。
+* `template`:模版,可选项有[0, 1, 2, 3],1表示使用默认模版,0表示不使用模版。
+* `batch_size`:批处理大小,请结合机器情况进行调整,默认为1。
+* `output_scores`:是否要输出解码得分,默认为False。
+* `is_select_from_num_return_sequences`:是否对多个返回序列挑选最优项输出,当为True时,若num_return_sequences不为1则自动根据解码得分选择得分最高的序列作为最终结果,否则返回num_return_sequences个序列,默认为True。
+* `max_length`:生成问题的最大长度,默认为50。
+* `min_length`:生成问题的最小长度,默认为3。
+* `decode_strategy`:解码策略,支持beam_search和sampling,默认为beam_search。
+* `temperature`:解码参数temperature,默认为1.0。
+* `top_k`:解码参数top_k,默认为0。
+* `top_p`:解码参数top_p,默认为1.0。
+* `num_beams`:解码参数num_beams,表示beam_search解码的beam size,默认为6。
+* `num_beam_groups`:解码参数num_beam_groups,默认为1。
+* `diversity_rate`:解码参数diversity_rate,默认为0.0。
+* `length_penalty`:解码长度控制值,默认为1.2。
+* `num_return_sequences`:解码返回序列数,默认为1。
+* `repetition_penalty`:解码重复惩罚值,默认为1。
+* `use_faster`:表示是否开启基于FasterTransformer的高性能预测,注意FasterTransformer的高性能预测仅支持gpu,默认为False。
+* `use_fp16_decoding`: 表示在开启高性能预测的时候是否使用fp16来完成预测过程,若不使用则使用fp32,默认为True。
+
+</div></details>
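+
+A minimal sketch of passing the options listed above as keyword arguments when constructing the Taskflow; the parameter values below are illustrative assumptions rather than recommended settings, and the context/answer strings are placeholders for real input. With `output_scores=True`, the call is assumed to also return the decoding scores alongside the generated questions.
+
+```python
+>>> from paddlenlp import Taskflow
+# 示例取值仅作演示 (illustrative values only); "你的上下文"/"你的答案" 为占位符
+>>> question_generator = Taskflow("question_generation",
+                                  model="unimo-text-1.0-dureader_qg-template1",
+                                  batch_size=2,
+                                  num_beams=4,
+                                  max_length=30,
+                                  output_scores=True)
+>>> questions, scores = question_generator([
+  {"context": "你的上下文", "answer": "你的答案"}
+  ])
+```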
## PART Ⅱ   定制化训练

diff --git a/paddlenlp/taskflow/question_generation.py b/paddlenlp/taskflow/question_generation.py
new file mode 100644
index 000000000000..4a16571c4a0a
--- /dev/null
+++ b/paddlenlp/taskflow/question_generation.py
@@ -0,0 +1,490 @@
+# coding:utf-8
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import glob
+import json
+import math
+import os
+import copy
+import itertools
+import math
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from ..transformers import UNIMOLMHeadModel
+from ..transformers import UNIMOTokenizer
+
+from ..datasets import load_dataset
+from ..data import Stack, Pad, Tuple
+from .utils import download_file, add_docstrings, static_mode_guard, dygraph_mode_guard
+from .task import Task
+
+usage = r"""
+    from paddlenlp import Taskflow
+
+    question_generation = Taskflow("question_generation")
+    question_generation([{"context": "奇峰黄山千米以上的山峰有77座,整座黄山就是一座花岗岩的峰林,自古有36大峰,36小峰,最高峰莲花峰、最险峰天都峰和观日出的最佳点光明顶构成黄山的三大主峰。", "answer": "莲花峰"}])
+    '''
+      ['黄山最高峰是什么']
+    '''
+    """
+
+
+class QuestionGenerationTask(Task):
+    """
+    The question generation model to generate questions from a given context and answer.
+    Args:
+        task(string): The name of task.
+        model(string): The model name in the task.
+        kwargs (dict, optional): Additional keyword arguments passed along to the specific task.
+    """
+
+    def __init__(self, task, model, **kwargs):
+        super().__init__(task=task, model=model, **kwargs)
+        paddle.set_device(kwargs.get("device", 'gpu'))
+        self._batch_size = kwargs.get("batch_size", 16)
+        self._output_scores = kwargs.get("output_scores", False)
+        self._is_select_from_num_return_sequences = kwargs.get(
+            "is_select_from_num_return_sequences", True)
+        self._construct_tokenizer(model)
+        self._construct_model(model)
+        # Hyper-parameters used during generation.
+        self._max_length = kwargs.get("max_length", 50)
+        self._min_length = kwargs.get("min_length", 3)
+        self._decode_strategy = kwargs.get("decode_strategy", 'beam_search')
+        self._temperature = kwargs.get("temperature", 1.0)
+        self._top_k = kwargs.get("top_k", 0)
+        self._top_p = kwargs.get("top_p", 1.0)
+        self._num_beams = kwargs.get("num_beams", 6)
+        self._num_beam_groups = kwargs.get("num_beam_groups", 1)
+        self._diversity_rate = kwargs.get("diversity_rate", 0.0)
+        self._length_penalty = kwargs.get("length_penalty", 1.2)
+        self._num_return_sequences = kwargs.get("num_return_sequences", 1)
+        self._repetition_penalty = kwargs.get("repetition_penalty", 1)
+        self._use_faster = kwargs.get("use_faster", False)
+        self._use_fp16_decoding = kwargs.get("use_fp16_decoding", False)
+        self._template = kwargs.get("template", 1)
+
+    def _construct_model(self, model):
+        """
+        Construct the inference model for the predictor.
+ """ + # self._model = UNIMOLMHeadModel.from_pretrained(model) + self._model = UNIMOLMHeadModel.from_pretrained(self._task_path) + self._model.eval() + + def _construct_tokenizer(self, model): + """ + Construct the tokenizer for the predictor. + """ + self._tokenizer = UNIMOTokenizer.from_pretrained(self._task_path) + + def _preprocess(self, inputs): + """ + Transform the raw text to the model inputs, two steps involved: + 1) Transform the raw text to token ids. + 2) Generate the other model inputs from the raw text and token ids. + """ + inputs = self._check_input_text(inputs) + batches = self._batchify(inputs, self._batch_size) + outputs = {'batches': batches, 'text': inputs} + return outputs + + def _batchify(self, data, batch_size): + """ + Generate input batches. + """ + examples = [self._convert_example(i) for i in data] + # Seperates data into some batches. + one_batch = [] + for example in examples: + one_batch.append(example) + if len(one_batch) == batch_size: + yield self._parse_batch(one_batch, self._tokenizer.pad_token_id) + one_batch = [] + if one_batch: + yield self._parse_batch(one_batch, self._tokenizer.pad_token_id) + + def _check_input_text(self, inputs): + inputs = inputs[0] + if isinstance(inputs, str): + if len(inputs) == 0: + raise ValueError( + "Invalid inputs, input text should not be empty text, please check your input." + .format(type(inputs))) + inputs = [inputs] + elif isinstance(inputs, dict): + if not ('source' in inputs and 'title' in inputs) and not ( + 'context' in inputs and 'answer' in inputs): + raise TypeError( + "Invalid inputs, source and title are not in the input dictionary, nor are context and answer." + ) + elif isinstance(inputs, list): + if not (isinstance(inputs[0], dict)): + raise TypeError( + "Invalid inputs, input text should be list of dict.".format( + type(inputs[0]))) + else: + raise TypeError( + "Invalid inputs, input text should be str or list of str, but type of {} found!" + .format(type(inputs))) + return inputs + + def _convert_example(self, + example, + max_seq_len=512, + return_length=True, + template=1): + """ + Convert all examples into necessary features. + """ + if isinstance(example, dict): + target = None + if 'source' in example and 'title' in example: + source = example['source'] + title = None + if 'title' in example.keys(): + title = example['title'] + elif 'context' in example and 'answer' in example: + source = example['context'] + title = None + if 'answer' in example.keys(): + title = example['answer'] + else: + assert False, "Source and title are not in the input dictionary, nor are context and answer." 
+            if 'target' in example.keys():
+                target = example['target']
+        elif isinstance(example, list):
+            source = example[0]
+            title = example[1]
+
+        if self._template == 1:
+            ### use template 1
+            source = '答案:' + title + self._tokenizer.sep_token + '上下文:' + source
+            title = None
+            if target:
+                target = '问题:' + target
+        elif self._template == 2:
+            ### use template 2
+            source = '答案:' + title + self._tokenizer.sep_token + '上下文:' + source
+            title = None
+            if target:
+                target = '在已知答案的前提下,问题:' + target
+        elif self._template == 3:
+            ### use template 3
+            source = '这是一个问题生成任务,根据提供的答案和上下文,来生成问题。' + title + self._tokenizer.sep_token + '上下文:' + source
+            title = None
+            if target:
+                target = '问题:' + target
+
+        tokenized_example = self._tokenizer.gen_encode(
+            source,
+            title=title,
+            max_seq_len=max_seq_len,
+            max_title_len=30,
+            add_start_token_for_decoding=True,
+            return_position_ids=True,
+        )
+
+        if 'target' in example and example['target']:
+            tokenized_example['target'] = example['target']
+            # Used to gather the logits corresponding to the labels during training.
+        return tokenized_example
+
+    def _parse_batch(self, batch_examples, pad_val, pad_right=False):
+        """
+        Batchify a batch of examples.
+        """
+
+        def pad_mask(batch_attention_mask):
+            """Pad attention_mask."""
+            batch_size = len(batch_attention_mask)
+            max_len = max(map(len, batch_attention_mask))
+            attention_mask = np.ones(
+                (batch_size, max_len, max_len), dtype='float32') * -1e9
+            for i, mask_data in enumerate(attention_mask):
+                seq_len = len(batch_attention_mask[i])
+                if pad_right:
+                    mask_data[:seq_len, :seq_len] = np.array(
+                        batch_attention_mask[i], dtype='float32')
+                else:
+                    mask_data[-seq_len:,
+                              -seq_len:] = np.array(batch_attention_mask[i],
+                                                    dtype='float32')
+            # In order to ensure the correct broadcasting mechanism, expand one
+            # dimension to the second dimension (n_head of Transformer).
+            attention_mask = np.expand_dims(attention_mask, axis=1)
+            return attention_mask
+
+        pad_func = Pad(pad_val=pad_val, pad_right=pad_right, dtype='int64')
+        input_ids = pad_func(
+            [example['input_ids'] for example in batch_examples])
+        token_type_ids = pad_func(
+            [example['token_type_ids'] for example in batch_examples])
+        position_ids = pad_func(
+            [example['position_ids'] for example in batch_examples])
+        attention_mask = pad_mask(
+            [example['attention_mask'] for example in batch_examples])
+        # seq_len = np.asarray([example['seq_len'] for example in batch_examples],
+        #                      dtype='int32')
+        batch_dict = {}
+        batch_dict['input_ids'] = input_ids
+        batch_dict['token_type_ids'] = token_type_ids
+        batch_dict['position_ids'] = position_ids
+        batch_dict['attention_mask'] = attention_mask
+        # batch_dict['seq_len'] = seq_len
+        return batch_dict
+
+    def _run_model(self, inputs):
+        """
+        Run the task model from the outputs of the `_preprocess` function.
+ """ + all_ids = [] + all_scores = [] + + for batch in inputs["batches"]: + input_ids = paddle.to_tensor(batch['input_ids'], dtype='int64') + token_type_ids = paddle.to_tensor(batch['token_type_ids'], + dtype='int64') + position_ids = paddle.to_tensor(batch['position_ids'], + dtype='int64') + attention_mask = paddle.to_tensor(batch['attention_mask'], + dtype='float32') + # seq_len = paddle.to_tensor(batch['seq_len'], dtype='int64') + ids, scores = self._model.generate( + input_ids=input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + attention_mask=attention_mask, + max_length=self._max_length, + min_length=self._min_length, + decode_strategy=self._decode_strategy, + temperature=self._temperature, + top_k=self._top_k, + top_p=self._top_p, + num_beams=self._num_beams, + num_beam_groups=self._num_beam_groups, + diversity_rate=self._diversity_rate, + length_penalty=self._length_penalty, + num_return_sequences=self._num_return_sequences, + repetition_penalty=self._repetition_penalty, + bos_token_id=self._tokenizer.cls_token_id, + eos_token_id=self._tokenizer.mask_token_id, + use_faster=self._use_faster, + use_fp16_decoding=self._use_fp16_decoding) + all_ids.extend(ids) + all_scores.extend(scores) + inputs['ids'] = all_ids + inputs['scores'] = all_scores + return inputs + + def out_run_model(self, input_ids, token_type_ids, position_ids, + attention_mask): + """ + Debug used. + """ + all_ids = [] + all_scores = [] + # seq_len = paddle.to_tensor(batch['seq_len'], dtype='int64') + ids, scores = self._model.generate( + input_ids=input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + attention_mask=attention_mask, + max_length=self._max_length, + min_length=self._min_length, + decode_strategy=self._decode_strategy, + temperature=self._temperature, + top_k=self._top_k, + top_p=self._top_p, + num_beams=self._num_beams, + length_penalty=self._length_penalty, + num_return_sequences=self._num_return_sequences, + bos_token_id=self._tokenizer.cls_token_id, + eos_token_id=self._tokenizer.mask_token_id, + ) + all_ids.extend(ids) + all_scores.extend(scores) + + inputs = {} + inputs['ids'] = all_ids + inputs['scores'] = all_scores + return all_ids, all_scores + + def _postprocess(self, inputs): + """ + The model output is tag ids, this function will convert the model output to raw text. + """ + ids_list = inputs['ids'] + scores_list = inputs['scores'] + if self._is_select_from_num_return_sequences: + results = self._select_from_num_return_sequences( + ids_list, scores_list, self._max_length, + self._num_return_sequences) + else: + results = self._return_num_return_sequences( + ids_list, scores_list, self._max_length, + self._num_return_sequences) + output_tokens = [result[0] for result in results] + output_scores = [math.exp(result[1]) for result in results] + # output_scores = [[math.exp(s) for s in result[1]] if isinstance(result[1], list) else math.exp(result[1]) for result in results] + + if self._output_scores: + return output_tokens, output_scores + return output_tokens + + def _return_num_return_sequences(self, + ids, + scores, + max_dec_len=None, + num_return_sequences=1): + """ + Select generated sequence form several return sequences. 
+ """ + results = [] + group = [] + tmp = [] + if scores is not None: + ids = [i.numpy() for i in ids] + scores = [i.numpy() for i in scores] + + if len(ids) != len(scores) or (len(ids) % + num_return_sequences) != 0: + raise ValueError( + "the length of `ids` is {}, but the `num_return_sequences` is {}" + .format(len(ids), num_return_sequences)) + + for pred, score in zip(ids, scores): + pred_token_ids, pred_tokens = self._post_process_decoded_sequence( + pred) + num_token = len(pred_token_ids) + target = "".join(pred_tokens) + target = self._remove_template(target) + # not ending + if max_dec_len is not None and num_token >= max_dec_len: + score -= 1e3 + tmp.append([target, score]) + if len(tmp) == num_return_sequences: + group.append(tmp) + tmp = [] + for preds in group: + preds = sorted(preds, key=lambda x: -x[1]) + for pred in preds: + results.append(pred) + else: + ids = ids.numpy() + for pred in ids: + pred_token_ids, pred_tokens = self._post_process_decoded_sequence( + pred) + num_token = len(pred_token_ids) + response = "".join(pred_tokens) + response = self._remove_template(response) + # TODO: Support return scores in FT. + tmp.append([response]) + if len(tmp) == num_return_sequences: + group.append(tmp) + tmp = [] + + for preds in group: + for pred in preds: + results.append(pred) + return results + + def _select_from_num_return_sequences(self, + ids, + scores, + max_dec_len=None, + num_return_sequences=1): + """ + Select generated sequence form several return sequences. + """ + results = [] + group = [] + tmp = [] + if scores is not None: + ids = [i.numpy() for i in ids] + scores = [i.numpy() for i in scores] + + if len(ids) != len(scores) or (len(ids) % + num_return_sequences) != 0: + raise ValueError( + "the length of `ids` is {}, but the `num_return_sequences` is {}" + .format(len(ids), num_return_sequences)) + + for pred, score in zip(ids, scores): + pred_token_ids, pred_tokens = self._post_process_decoded_sequence( + pred) + num_token = len(pred_token_ids) + target = "".join(pred_tokens) + target = self._remove_template(target) + # not ending + if max_dec_len is not None and num_token >= max_dec_len: + score -= 1e3 + tmp.append([target, score]) + if len(tmp) == num_return_sequences: + group.append(tmp) + tmp = [] + for preds in group: + preds = sorted(preds, key=lambda x: -x[1]) + results.append(preds[0]) + else: + ids = ids.numpy() + for pred in ids: + pred_token_ids, pred_tokens = self._post_process_decoded_sequence( + pred) + num_token = len(pred_token_ids) + response = "".join(pred_tokens) + response = self._remove_template(response) + # TODO: Support return scores in FT. + tmp.append([response]) + if len(tmp) == num_return_sequences: + group.append(tmp) + tmp = [] + + for preds in group: + results.append(preds[0]) + return results + + def _post_process_decoded_sequence(self, token_ids): + """Post-process the decoded sequence. 
Truncate from the first .""" + eos_pos = len(token_ids) + for i, tok_id in enumerate(token_ids): + if tok_id == self._tokenizer.mask_token_id: + eos_pos = i + break + token_ids = token_ids[:eos_pos] + tokens = self._tokenizer.convert_ids_to_tokens(token_ids) + tokens = self._tokenizer.merge_subword(tokens) + special_tokens = ['[UNK]'] + tokens = [token for token in tokens if token not in special_tokens] + return token_ids, tokens + + def _remove_template(self, instr): + """Remove template prefix of decoded sequence.""" + outstr = instr.strip('问题:') + outstr = instr.strip('在已知答案的前提下,问题:') + return outstr + + def _construct_input_spec(self): + """ + Construct the input spec for the convert dygraph model to static model. + """ + self._input_spec = [ + paddle.static.InputSpec(shape=[None, None], + dtype="int64", + name='input_ids'), + ] diff --git a/paddlenlp/taskflow/taskflow.py b/paddlenlp/taskflow/taskflow.py index 9f016118a820..4a7b3272a1b4 100644 --- a/paddlenlp/taskflow/taskflow.py +++ b/paddlenlp/taskflow/taskflow.py @@ -40,6 +40,7 @@ from .text_to_image import TextToImageGenerationTask, TextToImageDiscoDiffusionTask, TextToImageStableDiffusionTask from .text_summarization import TextSummarizationTask from .document_intelligence import DocPromptTask +from .question_generation import QuestionGenerationTask warnings.simplefilter(action='ignore', category=Warning, lineno=0, append=False) @@ -450,6 +451,22 @@ "model": "docprompt" } }, + "question_generation": { + "models": { + "unimo-text-1.0": { + "task_class": QuestionGenerationTask, + "task_flag": "question-generation-unimo-text-1.0", + }, + "unimo-text-1.0-dureader_qg-template1": { + "task_class": QuestionGenerationTask, + "task_flag": + "question-generation-unimo-text-1.0-dureader_qg-template1", + }, + }, + "default": { + "model": "unimo-text-1.0-dureader_qg-template1" + } + }, } support_schema_list = [ diff --git a/paddlenlp/transformers/unimo/modeling.py b/paddlenlp/transformers/unimo/modeling.py index 5a95845b02c8..ab6b18fe9c5b 100644 --- a/paddlenlp/transformers/unimo/modeling.py +++ b/paddlenlp/transformers/unimo/modeling.py @@ -114,6 +114,25 @@ class UNIMOPretrainedModel(PretrainedModel): "eos_token_id": 3, "mask_token_id": 3, }, + "unimo-text-1.0-dureader_qg-template1": { + "vocab_size": 18000, + "hidden_size": 768, + "num_hidden_layers": 12, + "num_attention_heads": 12, + "intermediate_size": 3072, + "hidden_act": "relu", + "hidden_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.1, + "normalize_before": False, + "max_position_embeddings": 513, + "type_vocab_size": 4, + "initializer_range": 0.02, + "unk_token_id": 17963, + "pad_token_id": 0, + "bos_token_id": 1, + "eos_token_id": 3, + "mask_token_id": 3, + }, } pretrained_resource_files_map = { "model_state": { @@ -125,6 +144,8 @@ class UNIMOPretrainedModel(PretrainedModel): "https://bj.bcebos.com/paddlenlp/models/transformers/unimo/unimo-text-1.0-large.pdparams", "unimo-text-1.0-summary": "https://bj.bcebos.com/paddlenlp/models/transformers/unimo/unimo-text-1.0-summary.pdparams", + "unimo-text-1.0-dureader_qg-template1": + "https://bj.bcebos.com/paddlenlp/models/transformers/unimo/unimo-text-1.0-dureader_qg-template1.pdparams" } } base_model_prefix = "unimo" diff --git a/paddlenlp/transformers/unimo/tokenizer.py b/paddlenlp/transformers/unimo/tokenizer.py index 2529dd5bcfc3..b9fc30bb9640 100644 --- a/paddlenlp/transformers/unimo/tokenizer.py +++ b/paddlenlp/transformers/unimo/tokenizer.py @@ -93,6 +93,8 @@ class UNIMOTokenizer(PretrainedTokenizer): 
"https://bj.bcebos.com/paddlenlp/models/transformers/unimo/unimo-text-1.0-large-vocab.txt", "unimo-text-1.0-summary": "https://bj.bcebos.com/paddlenlp/models/transformers/unimo/unimo-text-1.0-vocab.txt", + "unimo-text-1.0-dureader_qg-template1": + "https://bj.bcebos.com/paddlenlp/models/transformers/unimo/unimo-text-1.0-vocab.txt", } } pretrained_init_configuration = { @@ -107,6 +109,9 @@ class UNIMOTokenizer(PretrainedTokenizer): }, "unimo-text-1.0-summary": { "do_lower_case": True + }, + "unimo-text-1.0-dureader_qg-template1": { + "do_lower_case": True } } max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES From aa1bef48909724c8ce9147c21f09951bb644b111 Mon Sep 17 00:00:00 2001 From: westfish Date: Tue, 11 Oct 2022 11:42:40 +0000 Subject: [PATCH 2/4] fix code style --- docs/model_zoo/taskflow.md | 1 + paddlenlp/taskflow/taskflow.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/model_zoo/taskflow.md b/docs/model_zoo/taskflow.md index 299dd703b2b9..e4add67c2fa2 100644 --- a/docs/model_zoo/taskflow.md +++ b/docs/model_zoo/taskflow.md @@ -1,6 +1,7 @@ # PaddleNLP一键预测功能:Taskflow API +

diff --git a/paddlenlp/taskflow/taskflow.py b/paddlenlp/taskflow/taskflow.py index 4a7b3272a1b4..86deee976ce2 100644 --- a/paddlenlp/taskflow/taskflow.py +++ b/paddlenlp/taskflow/taskflow.py @@ -458,7 +458,8 @@ "task_flag": "question-generation-unimo-text-1.0", }, "unimo-text-1.0-dureader_qg-template1": { - "task_class": QuestionGenerationTask, + "task_class": + QuestionGenerationTask, "task_flag": "question-generation-unimo-text-1.0-dureader_qg-template1", }, From 7fbd726ccf96f1419bbb659f62811ad082995272 Mon Sep 17 00:00:00 2001 From: westfish Date: Wed, 12 Oct 2022 13:29:49 +0000 Subject: [PATCH 3/4] modified according to zeyang's comments --- docs/model_zoo/taskflow.md | 2 +- paddlenlp/taskflow/question_generation.py | 2 +- paddlenlp/taskflow/taskflow.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/model_zoo/taskflow.md b/docs/model_zoo/taskflow.md index e4add67c2fa2..557ff32ce4a6 100644 --- a/docs/model_zoo/taskflow.md +++ b/docs/model_zoo/taskflow.md @@ -1667,7 +1667,7 @@ from paddlenlp import Taskflow * `num_return_sequences`:解码返回序列数,默认为1。 * `repetition_penalty`:解码重复惩罚值,默认为1。 * `use_faster`:表示是否开启基于FasterTransformer的高性能预测,注意FasterTransformer的高性能预测仅支持gpu,默认为False。 -* `use_fp16_decoding`: 表示在开启高性能预测的时候是否使用fp16来完成预测过程,若不使用则使用fp32,默认为True。 +* `use_fp16_decoding`: 表示在开启高性能预测的时候是否使用fp16来完成预测过程,若不使用则使用fp32,默认为False。 diff --git a/paddlenlp/taskflow/question_generation.py b/paddlenlp/taskflow/question_generation.py index 4a16571c4a0a..ee6ea6ad8ee1 100644 --- a/paddlenlp/taskflow/question_generation.py +++ b/paddlenlp/taskflow/question_generation.py @@ -1,5 +1,5 @@ # coding:utf-8 -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License" # you may not use this file except in compliance with the License. 
diff --git a/paddlenlp/taskflow/taskflow.py b/paddlenlp/taskflow/taskflow.py index 86deee976ce2..4e97a13d9bba 100644 --- a/paddlenlp/taskflow/taskflow.py +++ b/paddlenlp/taskflow/taskflow.py @@ -455,13 +455,13 @@ "models": { "unimo-text-1.0": { "task_class": QuestionGenerationTask, - "task_flag": "question-generation-unimo-text-1.0", + "task_flag": "question_generation-unimo-text-1.0", }, "unimo-text-1.0-dureader_qg-template1": { "task_class": QuestionGenerationTask, "task_flag": - "question-generation-unimo-text-1.0-dureader_qg-template1", + "question_generation-unimo-text-1.0-dureader_qg-template1", }, }, "default": { From e83abbb47fee46bdb381396046362eac27b2ef35 Mon Sep 17 00:00:00 2001 From: westfish Date: Thu, 13 Oct 2022 07:27:48 +0000 Subject: [PATCH 4/4] fix some typos in qg-example readme --- examples/question_generation/unimo-text/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/question_generation/unimo-text/README.md b/examples/question_generation/unimo-text/README.md index dba602072473..1f5ee5d73307 100644 --- a/examples/question_generation/unimo-text/README.md +++ b/examples/question_generation/unimo-text/README.md @@ -14,7 +14,7 @@ - [数据准备](#数据准备) - [数据加载](#数据加载) - [数据处理](#数据处理) - - [从本地文件创建数据集(可选)](#从本地文件创建数据集(可选)) + - [从本地文件创建数据集-可选](#从本地文件创建数据集-可选) - [模型训练](#模型训练) - [模型预测](#模型预测) - [模型转换部署](#模型转换部署) @@ -117,8 +117,8 @@ train_ds, dev_ds = load_dataset('dureader_qg', splits=('train', 'dev')) 问题: ``` -#### 从本地文件创建数据集(可选) -在许多情况下,我们需要使用本地数据集来训练我们的文本分类模型,本项目支持使用固定格式本地数据集文件进行训练。 +#### 从本地文件创建数据集-可选 +在许多情况下,我们需要使用本地数据集来训练我们的问题生成模型,本项目支持使用固定格式本地数据集文件进行训练。 使用本地文件,只需要在模型训练时指定`train_file` 为本地训练数据地址,`predict_file` 为本地测试数据地址即可。 本地数据集目录结构如下: