From 0a111bb681e725e2bb9f7db1e49dedcae1d540ca Mon Sep 17 00:00:00 2001
From: liu zhengxi <380185688@qq.com>
Date: Thu, 11 Mar 2021 11:31:11 +0800
Subject: [PATCH] Update wmt14ende dataset (#98)

* update datasets

* update

* update

* update dataset

* delete data_files for example transformer
---
 benchmark/transformer/README.md                     |   5 +-
 .../transformer/configs/transformer.base.yaml       |   2 +-
 benchmark/transformer/reader.py                     |  86 +++++++-----
 .../machine_translation/transformer/README.md       |   5 +-
 .../machine_translation/transformer/reader.py       |  48 ++++---
 paddlenlp/datasets/experimental/__init__.py         |   1 +
 paddlenlp/datasets/experimental/wmt14ende.py        | 124 ++++++++++++++++++
 7 files changed, 213 insertions(+), 58 deletions(-)
 create mode 100644 paddlenlp/datasets/experimental/wmt14ende.py

diff --git a/benchmark/transformer/README.md b/benchmark/transformer/README.md
index af4769e55fc18a..058de053c1cef2 100644
--- a/benchmark/transformer/README.md
+++ b/benchmark/transformer/README.md
@@ -38,10 +38,7 @@
 In addition, we provide a preprocessed copy of the dataset. With the code below, it will be downloaded and extracted to `~/.paddlenlp/datasets/machine_translation/WMT14ende/` automatically. This is already implemented in reader.py, so unless you have made custom modifications, no extra code is needed.
 
 ``` python
-# Get the default data transform function
-transform_func = WMT14ende.get_default_transform_func(root=root)
-# Download and process the WMT14 en-de translation dataset
-dataset = WMT14ende.get_datasets(mode="train", transform_func=transform_func)
+datasets = load_dataset('wmt14ende', data_files=data_files, splits=('train', 'dev'))
 ```
 
 ### Single-Machine Training

diff --git a/benchmark/transformer/configs/transformer.base.yaml b/benchmark/transformer/configs/transformer.base.yaml
index 5f162be873002f..a30eae6f218e86 100644
--- a/benchmark/transformer/configs/transformer.base.yaml
+++ b/benchmark/transformer/configs/transformer.base.yaml
@@ -26,7 +26,7 @@ use_gpu: True
 pool_size: 200000
 sort_type: "global"
 batch_size: 4096
-infer_batch_size: 64
+infer_batch_size: 8
 shuffle_batch: True
 # Data shuffle only works when sort_type is pool or none
 shuffle: True
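For context, here is a minimal, self-contained sketch of the call shown in the README hunk above. The `data_files` layout mirrors what `create_data_loader` builds in the reader.py diff below; the local `root` directory is a hypothetical override and can be omitted (pass `data_files=None`) to use the automatically downloaded copy.

```python
import os

from paddlenlp.datasets import load_dataset

# Hypothetical local copy of the preprocessed BPE data; only needed when
# overriding the default download location.
root = "./WMT14.en-de/wmt14_ende_data_bpe"
data_files = {
    'train': (os.path.join(root, "train.tok.clean.bpe.33708.en"),
              os.path.join(root, "train.tok.clean.bpe.33708.de")),
    'dev': (os.path.join(root, "newstest2013.tok.bpe.33708.en"),
            os.path.join(root, "newstest2013.tok.bpe.33708.de"))
}

# With data_files=None, the dataset is downloaded and extracted to
# ~/.paddlenlp/datasets/machine_translation/WMT14ende/ on first use.
train_ds, dev_ds = load_dataset(
    'wmt14ende', data_files=data_files, splits=('train', 'dev'))
print(train_ds[0])  # {'src': '<BPE English line>', 'trg': '<BPE German line>'}
```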
diff --git a/benchmark/transformer/reader.py b/benchmark/transformer/reader.py
index fb651f27a17b7e..07195635651fd3 100644
--- a/benchmark/transformer/reader.py
+++ b/benchmark/transformer/reader.py
@@ -21,8 +21,8 @@
 import numpy as np
 from paddle.io import BatchSampler, DataLoader, Dataset
 import paddle.distributed as dist
-from paddlenlp.data import Pad
-from paddlenlp.datasets import WMT14ende
+from paddlenlp.data import Pad, Vocab
+from paddlenlp.datasets import load_dataset
 from paddlenlp.data.sampler import SamplerHelper
 
 
@@ -34,30 +34,41 @@ def min_max_filer(data, max_len, min_len=0):
 
 
 def create_data_loader(args, places=None, use_all_vocab=False):
-    root = None if args.root == "None" else args.root
-    if not use_all_vocab:
-        WMT14ende.VOCAB_INFO = (os.path.join(
-            "WMT14.en-de", "wmt14_ende_data_bpe",
-            "vocab_all.bpe.33712"), os.path.join(
-                "WMT14.en-de", "wmt14_ende_data_bpe", "vocab_all.bpe.33712"),
-                                "de485e3c2e17e23acf4b4b70b54682dd",
-                                "de485e3c2e17e23acf4b4b70b54682dd")
-    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
+    data_files = None
+    if args.root != "None" and os.path.exists(args.root):
+        data_files = {
+            'train': (os.path.join(args.root, "train.tok.clean.bpe.33708.en"),
+                      os.path.join(args.root, "train.tok.clean.bpe.33708.de")),
+            'dev': (os.path.join(args.root, "newstest2013.tok.bpe.33708.en"),
+                    os.path.join(args.root, "newstest2013.tok.bpe.33708.de"))
+        }
+
+    datasets = load_dataset(
+        'wmt14ende', data_files=data_files, splits=('train', 'dev'))
+    if use_all_vocab:
+        src_vocab = Vocab.load_vocabulary(**datasets[0].vocab_info["all"])
+    else:
+        src_vocab = Vocab.load_vocabulary(
+            **datasets[0].vocab_info["benchmark"])
+    trg_vocab = src_vocab
+
     padding_vocab = (
         lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor
     )
     args.src_vocab_size = padding_vocab(len(src_vocab))
     args.trg_vocab_size = padding_vocab(len(trg_vocab))
-    transform_func = WMT14ende.get_default_transform_func(root=root)
-    datasets = [
-        WMT14ende.get_datasets(
-            mode=m, root=root, transform_func=transform_func)
-        for m in ["train", "dev"]
-    ]
+
+    def convert_samples(sample):
+        source = sample['src'].split()
+        target = sample['trg'].split()
+
+        source = src_vocab.to_indices(source)
+        target = trg_vocab.to_indices(target)
+
+        return source, target
 
     data_loaders = [(None)] * 2
     for i, dataset in enumerate(datasets):
-        dataset = dataset.filter(
+        dataset = dataset.map(convert_samples, lazy=False).filter(
             partial(
                 min_max_filer, max_len=args.max_length))
         batch_sampler = TransformerBatchSampler(
@@ -91,25 +102,36 @@ def create_data_loader(args, places=None, use_all_vocab=False):
 
 
 def create_infer_loader(args, use_all_vocab=False):
-    root = None if args.root == "None" else args.root
-    if not use_all_vocab:
-        WMT14ende.VOCAB_INFO = (os.path.join(
-            "WMT14.en-de", "wmt14_ende_data_bpe",
-            "vocab_all.bpe.33712"), os.path.join(
-                "WMT14.en-de", "wmt14_ende_data_bpe", "vocab_all.bpe.33712"),
-                                "de485e3c2e17e23acf4b4b70b54682dd",
-                                "de485e3c2e17e23acf4b4b70b54682dd")
-    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
+    data_files = None
+    if args.root != "None" and os.path.exists(args.root):
+        data_files = {
+            'test': (os.path.join(args.root, "newstest2014.tok.bpe.33708.en"),
+                     os.path.join(args.root, "newstest2014.tok.bpe.33708.de"))
+        }
+
+    dataset = load_dataset('wmt14ende', data_files=data_files, splits=('test'))
+    if use_all_vocab:
+        src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["all"])
+    else:
+        src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["benchmark"])
+    trg_vocab = src_vocab
+
     padding_vocab = (
         lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor
     )
     args.src_vocab_size = padding_vocab(len(src_vocab))
     args.trg_vocab_size = padding_vocab(len(trg_vocab))
-    transform_func = WMT14ende.get_default_transform_func(root=root)
-    dataset = WMT14ende.get_datasets(
-        mode="test", root=root, transform_func=transform_func).filter(
-            partial(
-                min_max_filer, max_len=args.max_length))
+
+    def convert_samples(sample):
+        source = sample['src'].split()
+        target = sample['trg'].split()
+
+        source = src_vocab.to_indices(source)
+        target = trg_vocab.to_indices(target)
+
+        return source, target
+
+    dataset = dataset.map(convert_samples, lazy=False)
 
     batch_sampler = SamplerHelper(dataset).batch(
         batch_size=args.infer_batch_size, drop_last=False)

diff --git a/examples/machine_translation/transformer/README.md b/examples/machine_translation/transformer/README.md
index dcc933b6add9ec..5fc0ccb5962aa7 100644
--- a/examples/machine_translation/transformer/README.md
+++ b/examples/machine_translation/transformer/README.md
@@ -65,10 +65,7 @@ pip install attrdict pyyaml
 In addition, we provide a preprocessed copy of the dataset. With the code below, it will be downloaded and extracted to `~/.paddlenlp/datasets/machine_translation/WMT14ende/` automatically.
 
 ``` python
-# Get the default data transform function
-transform_func = WMT14ende.get_default_transform_func(root=root)
-# Download and process the WMT14 en-de translation dataset
-dataset = WMT14ende.get_datasets(mode="train", transform_func=transform_func)
+datasets = load_dataset('wmt14ende', data_files=data_files, splits=('train', 'dev'))
 ```
 
 ## Single-Machine Training
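A note on the `padding_vocab` helper that both readers keep: it rounds the vocabulary size up to the next multiple of `args.pad_factor` before setting `src_vocab_size` and `trg_vocab_size`. A quick check of the arithmetic (`pad_factor=8` is an assumed example value; 33708 is the size suggested by the `vocab_all.bpe.33708` file name):

```python
def padded_size(vocab_size, pad_factor):
    # Same rounding as the padding_vocab lambda in the readers:
    # round vocab_size up to the nearest multiple of pad_factor.
    return (vocab_size + pad_factor - 1) // pad_factor * pad_factor

assert padded_size(33708, 8) == 33712
assert padded_size(33712, 8) == 33712  # already a multiple, unchanged
```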
diff --git a/examples/machine_translation/transformer/reader.py b/examples/machine_translation/transformer/reader.py
index dc820165cd2fc6..21d8001fdb0724 100644
--- a/examples/machine_translation/transformer/reader.py
+++ b/examples/machine_translation/transformer/reader.py
@@ -21,7 +21,8 @@
 import numpy as np
 from paddle.io import BatchSampler, DataLoader, Dataset
 import paddle.distributed as dist
-from paddlenlp.data import Pad
+from paddlenlp.data import Pad, Vocab
+from paddlenlp.datasets import load_dataset
 from paddlenlp.datasets import WMT14ende
 from paddlenlp.data.sampler import SamplerHelper
 
@@ -34,23 +35,28 @@ def min_max_filer(data, max_len, min_len=0):
 
 
 def create_data_loader(args):
-    root = None if args.root == "None" else args.root
-    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
+    datasets = load_dataset('wmt14ende', splits=('train', 'dev'))
+    src_vocab = Vocab.load_vocabulary(**datasets[0].vocab_info["all"])
+    trg_vocab = src_vocab
+
     padding_vocab = (
         lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor
     )
     args.src_vocab_size = padding_vocab(len(src_vocab))
     args.trg_vocab_size = padding_vocab(len(trg_vocab))
-    transform_func = WMT14ende.get_default_transform_func(root=root)
-    datasets = [
-        WMT14ende.get_datasets(
-            mode=m, root=root, transform_func=transform_func)
-        for m in ["train", "dev"]
-    ]
+
+    def convert_samples(sample):
+        source = sample['src'].split()
+        target = sample['trg'].split()
+
+        source = src_vocab.to_indices(source)
+        target = trg_vocab.to_indices(target)
+
+        return source, target
 
     data_loaders = [(None)] * 2
     for i, dataset in enumerate(datasets):
-        dataset = dataset.filter(
+        dataset = dataset.map(convert_samples, lazy=False).filter(
             partial(
                 min_max_filer, max_len=args.max_length))
         batch_sampler = TransformerBatchSampler(
@@ -81,18 +87,26 @@ def create_data_loader(args):
 
 
 def create_infer_loader(args):
-    root = None if args.root == "None" else args.root
-    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
+    dataset = load_dataset('wmt14ende', splits=('test'))
+    src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["all"])
+    trg_vocab = src_vocab
+
     padding_vocab = (
         lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor
     )
     args.src_vocab_size = padding_vocab(len(src_vocab))
     args.trg_vocab_size = padding_vocab(len(trg_vocab))
-    transform_func = WMT14ende.get_default_transform_func(root=root)
-    dataset = WMT14ende.get_datasets(
-        mode="test", root=root, transform_func=transform_func).filter(
-            partial(
-                min_max_filer, max_len=args.max_length))
+
+    def convert_samples(sample):
+        source = sample['src'].split()
+        target = sample['trg'].split()
+
+        source = src_vocab.to_indices(source)
+        target = trg_vocab.to_indices(target)
+
+        return source, target
+
+    dataset = dataset.map(convert_samples, lazy=False)
 
     batch_sampler = SamplerHelper(dataset).batch(
         batch_size=args.infer_batch_size, drop_last=False)

diff --git a/paddlenlp/datasets/experimental/__init__.py b/paddlenlp/datasets/experimental/__init__.py
index b6c289071959c6..792bf11d3ae0ad 100644
--- a/paddlenlp/datasets/experimental/__init__.py
+++ b/paddlenlp/datasets/experimental/__init__.py
@@ -28,5 +28,6 @@
 from .drcd import *
 from .dureader_robust import *
 from .glue import *
+from .wmt14ende import *
 from .cnndm import *
 from .couplet import *
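Both readers now share the same preprocessing shape: load a split, build a `Vocab` from the dataset's `vocab_info`, map raw BPE text to token ids, then filter by length. A condensed sketch of that pipeline follows (the body of `min_max_filer` is not shown in this patch, so the length check below is an assumed stand-in consistent with how it is called; `max_length=256` is an example value):

```python
from functools import partial

from paddlenlp.data import Vocab
from paddlenlp.datasets import load_dataset

train_ds = load_dataset('wmt14ende', splits=('train'))
src_vocab = Vocab.load_vocabulary(**train_ds.vocab_info["all"])
trg_vocab = src_vocab  # source and target share a joint BPE vocabulary

def convert_samples(sample):
    # The text is already BPE-tokenized, so whitespace-split then map to ids.
    return (src_vocab.to_indices(sample['src'].split()),
            trg_vocab.to_indices(sample['trg'].split()))

def length_filter(data, max_len, min_len=0):
    # Assumed stand-in for min_max_filer: keep pairs whose lengths fit.
    return (min_len <= len(data[0]) <= max_len and
            min_len <= len(data[1]) <= max_len)

train_ds = train_ds.map(convert_samples, lazy=False).filter(
    partial(length_filter, max_len=256))
```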
diff --git a/paddlenlp/datasets/experimental/wmt14ende.py b/paddlenlp/datasets/experimental/wmt14ende.py
new file mode 100644
index 00000000000000..f29aaebbd5e4fe
--- /dev/null
+++ b/paddlenlp/datasets/experimental/wmt14ende.py
@@ -0,0 +1,124 @@
+import collections
+import os
+import warnings
+
+from paddle.io import Dataset
+from paddle.dataset.common import md5file
+from paddle.utils.download import get_path_from_url
+from paddlenlp.utils.env import DATA_HOME
+from . import DatasetBuilder
+
+__all__ = ['WMT14ende']
+
+
+class WMT14ende(DatasetBuilder):
+    URL = "https://paddlenlp.bj.bcebos.com/datasets/WMT14.en-de.tar.gz"
+    META_INFO = collections.namedtuple('META_INFO', ('src_file', 'tgt_file',
+                                                     'src_md5', 'tgt_md5'))
+    SPLITS = {
+        'train': META_INFO(
+            os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
+                         "train.tok.clean.bpe.33708.en"),
+            os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
+                         "train.tok.clean.bpe.33708.de"),
+            "c7c0b77e672fc69f20be182ae37ff62c",
+            "1865ece46948fda1209d3b7794770a0a"),
+        'dev': META_INFO(
+            os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
+                         "newstest2013.tok.bpe.33708.en"),
+            os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
+                         "newstest2013.tok.bpe.33708.de"),
+            "aa4228a4bedb6c45d67525fbfbcee75e",
+            "9b1eeaff43a6d5e78a381a9b03170501"),
+        'test': META_INFO(
+            os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
+                         "newstest2014.tok.bpe.33708.en"),
+            os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
+                         "newstest2014.tok.bpe.33708.de"),
+            "c9403eacf623c6e2d9e5a1155bdff0b5",
+            "0058855b55e37c4acfcb8cffecba1050"),
+        'dev-eval': META_INFO(
+            os.path.join("WMT14.en-de", "wmt14_ende_data",
+                         "newstest2013.tok.en"),
+            os.path.join("WMT14.en-de", "wmt14_ende_data",
+                         "newstest2013.tok.de"),
+            "d74712eb35578aec022265c439831b0e",
+            "6ff76ced35b70e63a61ecec77a1c418f"),
+        'test-eval': META_INFO(
+            os.path.join("WMT14.en-de", "wmt14_ende_data",
+                         "newstest2014.tok.en"),
+            os.path.join("WMT14.en-de", "wmt14_ende_data",
+                         "newstest2014.tok.de"),
+            "8cce2028e4ca3d4cc039dfd33adbfb43",
+            "a1b1f4c47f487253e1ac88947b68b3b8")
+    }
+    VOCAB_INFO = [(os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
+                                "vocab_all.bpe.33708"),
+                   "2fc775b7df37368e936a8e1f63846bb0"),
+                  (os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
+                                "vocab_all.bpe.33712"),
+                   "de485e3c2e17e23acf4b4b70b54682dd")]
+    UNK_TOKEN = "<unk>"
+    BOS_TOKEN = "<s>"
+    EOS_TOKEN = "<e>"
+
+    MD5 = "a2b8410709ff760a3b40b84bd62dfbd8"
+
+    def _get_data(self, mode, **kwargs):
+        default_root = os.path.join(DATA_HOME, self.__class__.__name__)
+        src_filename, tgt_filename, src_data_hash, tgt_data_hash = self.SPLITS[
+            mode]
+        src_fullname = os.path.join(default_root, src_filename)
+        tgt_fullname = os.path.join(default_root, tgt_filename)
+
+        (all_vocab_filename, all_vocab_hash), (
+            sub_vocab_filename, sub_vocab_hash) = self.VOCAB_INFO
+        all_vocab_fullname = os.path.join(default_root, all_vocab_filename)
+        sub_vocab_fullname = os.path.join(default_root, sub_vocab_filename)
+
+        if (not os.path.exists(src_fullname) or
+            (src_data_hash and not md5file(src_fullname) == src_data_hash)) or (
+                not os.path.exists(tgt_fullname) or
+                (tgt_data_hash and
+                 not md5file(tgt_fullname) == tgt_data_hash)) or (
+                     not os.path.exists(all_vocab_fullname) or
+                     (all_vocab_hash and
+                      not md5file(all_vocab_fullname) == all_vocab_hash)) or (
+                          not os.path.exists(sub_vocab_fullname) or
+                          (sub_vocab_hash and
+                           not md5file(sub_vocab_fullname) == sub_vocab_hash)):
+            get_path_from_url(self.URL, default_root, self.MD5)
+
+        return src_fullname, tgt_fullname
+
+    def _read(self, filename, *args):
+        src_filename, tgt_filename = filename
+        with open(src_filename, 'r', encoding='utf-8') as src_f:
+            with open(tgt_filename, 'r', encoding='utf-8') as tgt_f:
+                for src_line, tgt_line in zip(src_f, tgt_f):
+                    src_line = src_line.strip()
+                    tgt_line = tgt_line.strip()
+                    if not src_line and not tgt_line:
+                        continue
+                    yield {"src": src_line, "trg": tgt_line}
+
+    def get_vocab(self):
+        all_vocab_fullname = os.path.join(DATA_HOME, self.__class__.__name__,
+                                          self.VOCAB_INFO[0][0])
+        sub_vocab_fullname = os.path.join(DATA_HOME, self.__class__.__name__,
+                                          self.VOCAB_INFO[1][0])
+        vocab_info = {
+            'all': {
+                'filepath': all_vocab_fullname,
+                'unk_token': self.UNK_TOKEN,
+                'bos_token': self.BOS_TOKEN,
+                'eos_token': self.EOS_TOKEN
+            },
+            'benchmark': {
+                'filepath': sub_vocab_fullname,
+                'unk_token': self.UNK_TOKEN,
+                'bos_token': self.BOS_TOKEN,
+                'eos_token': self.EOS_TOKEN
+            }
+        }
+        return vocab_info
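Finally, a quick smoke test for the new builder (an assumed usage sketch, not part of the patch): load the test split and build both vocabularies exposed through `vocab_info`. The `dev-eval` and `test-eval` splits point at the tokenized, non-BPE files under `wmt14_ende_data`, which suggests they are intended as references for BLEU evaluation.

```python
from paddlenlp.data import Vocab
from paddlenlp.datasets import load_dataset

test_ds = load_dataset('wmt14ende', splits=('test'))
print(test_ds[0])  # {'src': '<BPE English line>', 'trg': '<BPE German line>'}

# 'all' is the full joint BPE vocabulary; 'benchmark' is the variant used by
# benchmark/transformer (file names suggest 33708 vs. 33712 entries).
full_vocab = Vocab.load_vocabulary(**test_ds.vocab_info["all"])
bench_vocab = Vocab.load_vocabulary(**test_ds.vocab_info["benchmark"])
print(len(full_vocab), len(bench_vocab))

# Raw tokenized references, e.g. for BLEU scoring:
test_eval_ds = load_dataset('wmt14ende', splits=('test-eval'))
```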