Update wmt14ende dataset (PaddlePaddle#98)
* update datasets

* update

* update

* update dataset

* delete data_files for example transformer
FrostML committed Mar 11, 2021
1 parent bc24d2f commit 0a111bb
Showing 7 changed files with 213 additions and 58 deletions.
5 changes: 1 addition & 4 deletions benchmark/transformer/README.md
@@ -38,10 +38,7 @@
We also provide a preprocessed version of the dataset. With the following code, the corresponding data will be downloaded automatically and extracted to `~/.paddlenlp/datasets/machine_translation/WMT14ende/`. This is already handled in reader.py, so you do not need to write this code yourself unless you make custom modifications.

``` python
# Get the default data transform
transform_func = WMT14ende.get_default_transform_func(root=root)
# Download and process the WMT14 en-de translation dataset
dataset = WMT14ende.get_datasets(mode="train", transform_func=transform_func)
datasets = load_dataset('wmt14ende', data_files=data_files, splits=('train', 'dev'))
```
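
The `data_files` argument above points at local copies of the preprocessed BPE files; when it is `None`, the dataset is downloaded automatically. As a hypothetical sketch (the `root` path is an assumption; the file names follow the benchmark reader.py in this commit), `data_files` could be assembled like this:

``` python
import os

# Assumed local directory holding the preprocessed BPE files (hypothetical path).
root = "/path/to/WMT14.en-de/wmt14_ende_data_bpe"

# Map each split to a (source, target) pair of files, mirroring reader.py in
# this commit; leave data_files as None to trigger the automatic download.
data_files = {
    'train': (os.path.join(root, "train.tok.clean.bpe.33708.en"),
              os.path.join(root, "train.tok.clean.bpe.33708.de")),
    'dev': (os.path.join(root, "newstest2013.tok.bpe.33708.en"),
            os.path.join(root, "newstest2013.tok.bpe.33708.de")),
}
```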

### Single-machine training
2 changes: 1 addition & 1 deletion benchmark/transformer/configs/transformer.base.yaml
@@ -26,7 +26,7 @@ use_gpu: True
pool_size: 200000
sort_type: "global"
batch_size: 4096
infer_batch_size: 64
infer_batch_size: 8
shuffle_batch: True
# Data shuffle only works when sort_type is pool or none
shuffle: True
86 changes: 54 additions & 32 deletions benchmark/transformer/reader.py
@@ -21,8 +21,8 @@
import numpy as np
from paddle.io import BatchSampler, DataLoader, Dataset
import paddle.distributed as dist
from paddlenlp.data import Pad
from paddlenlp.datasets import WMT14ende
from paddlenlp.data import Pad, Vocab
from paddlenlp.datasets import load_dataset
from paddlenlp.data.sampler import SamplerHelper


@@ -34,30 +34,41 @@ def min_max_filer(data, max_len, min_len=0):


def create_data_loader(args, places=None, use_all_vocab=False):
root = None if args.root == "None" else args.root
if not use_all_vocab:
WMT14ende.VOCAB_INFO = (os.path.join(
"WMT14.en-de", "wmt14_ende_data_bpe",
"vocab_all.bpe.33712"), os.path.join(
"WMT14.en-de", "wmt14_ende_data_bpe", "vocab_all.bpe.33712"),
"de485e3c2e17e23acf4b4b70b54682dd",
"de485e3c2e17e23acf4b4b70b54682dd")
(src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
data_files = None
if args.root != "None" and os.path.exists(args.root):
data_files = {
'train': (os.path.join(args.root, "train.tok.clean.bpe.33708.en"),
os.path.join(args.root, "train.tok.clean.bpe.33708.de")),
'dev': (os.path.join(args.root, "newstest2013.tok.bpe.33708.en"),
os.path.join(args.root, "newstest2013.tok.bpe.33708.de"))
}

datasets = load_dataset(
'wmt14ende', data_files=data_files, splits=('train', 'dev'))
if use_all_vocab:
src_vocab = Vocab.load_vocabulary(**datasets[0].vocab_info["all"])
else:
src_vocab = Vocab.load_vocabulary(**datasets[0].vocab_info["benchmark"])
trg_vocab = src_vocab

padding_vocab = (
lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor
)
args.src_vocab_size = padding_vocab(len(src_vocab))
args.trg_vocab_size = padding_vocab(len(trg_vocab))
transform_func = WMT14ende.get_default_transform_func(root=root)
datasets = [
WMT14ende.get_datasets(
mode=m, root=root, transform_func=transform_func)
for m in ["train", "dev"]
]

def convert_samples(sample):
source = sample['src'].split()
target = sample['trg'].split()

source = src_vocab.to_indices(source)
target = trg_vocab.to_indices(target)

return source, target

data_loaders = [(None)] * 2
for i, dataset in enumerate(datasets):
dataset = dataset.filter(
dataset = dataset.map(convert_samples, lazy=False).filter(
partial(
min_max_filer, max_len=args.max_length))
batch_sampler = TransformerBatchSampler(
@@ -91,25 +102,36 @@ def create_data_loader(args, places=None, use_all_vocab=False):


def create_infer_loader(args, use_all_vocab=False):
root = None if args.root == "None" else args.root
if not use_all_vocab:
WMT14ende.VOCAB_INFO = (os.path.join(
"WMT14.en-de", "wmt14_ende_data_bpe",
"vocab_all.bpe.33712"), os.path.join(
"WMT14.en-de", "wmt14_ende_data_bpe", "vocab_all.bpe.33712"),
"de485e3c2e17e23acf4b4b70b54682dd",
"de485e3c2e17e23acf4b4b70b54682dd")
(src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
data_files = None
if args.root != "None" and os.path.exists(args.root):
data_files = {
'test': (os.path.join(args.root, "newstest2014.tok.bpe.33708.en"),
os.path.join(args.root, "newstest2014.tok.bpe.33708.de"))
}

dataset = load_dataset('wmt14ende', data_files=data_files, splits=('test'))
if use_all_vocab:
src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["all"])
else:
src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["benchmark"])
trg_vocab = src_vocab

padding_vocab = (
lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor
)
args.src_vocab_size = padding_vocab(len(src_vocab))
args.trg_vocab_size = padding_vocab(len(trg_vocab))
transform_func = WMT14ende.get_default_transform_func(root=root)
dataset = WMT14ende.get_datasets(
mode="test", root=root, transform_func=transform_func).filter(
partial(
min_max_filer, max_len=args.max_length))

def convert_samples(sample):
source = sample['src'].split()
target = sample['trg'].split()

source = src_vocab.to_indices(source)
target = trg_vocab.to_indices(target)

return source, target

dataset = dataset.map(convert_samples, lazy=False)

batch_sampler = SamplerHelper(dataset).batch(
batch_size=args.infer_batch_size, drop_last=False)
5 changes: 1 addition & 4 deletions examples/machine_translation/transformer/README.md
@@ -65,10 +65,7 @@ pip install attrdict pyyaml
We also provide a preprocessed version of the dataset. With the following code, the corresponding data will be downloaded automatically and extracted to `~/.paddlenlp/datasets/machine_translation/WMT14ende/`.

``` python
# Get the default data transform
transform_func = WMT14ende.get_default_transform_func(root=root)
# Download and process the WMT14 en-de translation dataset
dataset = WMT14ende.get_datasets(mode="train", transform_func=transform_func)
datasets = load_dataset('wmt14ende', data_files=data_files, splits=('train', 'dev'))
```
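
The returned datasets expose `vocab_info`, which the example reader in this commit feeds to `Vocab.load_vocabulary` to build the shared source/target vocabulary. A minimal sketch of that step, following the reader.py changes in this commit, might look like this:

``` python
from paddlenlp.data import Vocab
from paddlenlp.datasets import load_dataset

# Download (if needed) and load the train/dev splits of WMT14 en-de;
# data_files is omitted here, so the preprocessed dataset is fetched automatically.
datasets = load_dataset('wmt14ende', splits=('train', 'dev'))

# Build the shared source/target BPE vocabulary from the dataset's vocab_info.
src_vocab = Vocab.load_vocabulary(**datasets[0].vocab_info["all"])
trg_vocab = src_vocab  # English and German share one vocabulary
```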

## Single-machine training
48 changes: 31 additions & 17 deletions examples/machine_translation/transformer/reader.py
@@ -21,7 +21,8 @@
import numpy as np
from paddle.io import BatchSampler, DataLoader, Dataset
import paddle.distributed as dist
from paddlenlp.data import Pad
from paddlenlp.data import Pad, Vocab
from paddlenlp.datasets import load_dataset
from paddlenlp.datasets import WMT14ende
from paddlenlp.data.sampler import SamplerHelper

@@ -34,23 +34,28 @@ def min_max_filer(data, max_len, min_len=0):


def create_data_loader(args):
root = None if args.root == "None" else args.root
(src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
datasets = load_dataset('wmt14ende', splits=('train', 'dev'))
src_vocab = Vocab.load_vocabulary(**datasets[0].vocab_info["all"])
trg_vocab = src_vocab

padding_vocab = (
lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor
)
args.src_vocab_size = padding_vocab(len(src_vocab))
args.trg_vocab_size = padding_vocab(len(trg_vocab))
transform_func = WMT14ende.get_default_transform_func(root=root)
datasets = [
WMT14ende.get_datasets(
mode=m, root=root, transform_func=transform_func)
for m in ["train", "dev"]
]

def convert_samples(sample):
source = sample['src'].split()
target = sample['trg'].split()

source = src_vocab.to_indices(source)
target = trg_vocab.to_indices(target)

return source, target

data_loaders = [(None)] * 2
for i, dataset in enumerate(datasets):
dataset = dataset.filter(
dataset = dataset.map(convert_samples, lazy=False).filter(
partial(
min_max_filer, max_len=args.max_length))
batch_sampler = TransformerBatchSampler(
@@ -81,18 +87,26 @@ def create_data_loader(args):


def create_infer_loader(args):
root = None if args.root == "None" else args.root
(src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
dataset = load_dataset('wmt14ende', splits=('test'))
src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["all"])
trg_vocab = src_vocab

padding_vocab = (
lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor
)
args.src_vocab_size = padding_vocab(len(src_vocab))
args.trg_vocab_size = padding_vocab(len(trg_vocab))
transform_func = WMT14ende.get_default_transform_func(root=root)
dataset = WMT14ende.get_datasets(
mode="test", root=root, transform_func=transform_func).filter(
partial(
min_max_filer, max_len=args.max_length))

def convert_samples(sample):
source = sample['src'].split()
target = sample['trg'].split()

source = src_vocab.to_indices(source)
target = trg_vocab.to_indices(target)

return source, target

dataset = dataset.map(convert_samples, lazy=False)

batch_sampler = SamplerHelper(dataset).batch(
batch_size=args.infer_batch_size, drop_last=False)
1 change: 1 addition & 0 deletions paddlenlp/datasets/experimental/__init__.py
@@ -28,5 +28,6 @@
from .drcd import *
from .dureader_robust import *
from .glue import *
from .wmt14ende import *
from .cnndm import *
from .couplet import *
124 changes: 124 additions & 0 deletions paddlenlp/datasets/experimental/wmt14ende.py
@@ -0,0 +1,124 @@
import collections
import os
import warnings

from paddle.io import Dataset
from paddle.dataset.common import md5file
from paddle.utils.download import get_path_from_url
from paddlenlp.utils.env import DATA_HOME
from . import DatasetBuilder

__all__ = ['WMT14ende']


class WMT14ende(DatasetBuilder):
URL = "https://paddlenlp.bj.bcebos.com/datasets/WMT14.en-de.tar.gz"
META_INFO = collections.namedtuple('META_INFO', ('src_file', 'tgt_file',
'src_md5', 'tgt_md5'))
SPLITS = {
'train': META_INFO(
os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"train.tok.clean.bpe.33708.en"),
os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"train.tok.clean.bpe.33708.de"),
"c7c0b77e672fc69f20be182ae37ff62c",
"1865ece46948fda1209d3b7794770a0a"),
'dev': META_INFO(
os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"newstest2013.tok.bpe.33708.en"),
os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"newstest2013.tok.bpe.33708.de"),
"aa4228a4bedb6c45d67525fbfbcee75e",
"9b1eeaff43a6d5e78a381a9b03170501"),
'test': META_INFO(
os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"newstest2014.tok.bpe.33708.en"),
os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"newstest2014.tok.bpe.33708.de"),
"c9403eacf623c6e2d9e5a1155bdff0b5",
"0058855b55e37c4acfcb8cffecba1050"),
'dev-eval': META_INFO(
os.path.join("WMT14.en-de", "wmt14_ende_data",
"newstest2013.tok.en"),
os.path.join("WMT14.en-de", "wmt14_ende_data",
"newstest2013.tok.de"),
"d74712eb35578aec022265c439831b0e",
"6ff76ced35b70e63a61ecec77a1c418f"),
'test-eval': META_INFO(
os.path.join("WMT14.en-de", "wmt14_ende_data",
"newstest2014.tok.en"),
os.path.join("WMT14.en-de", "wmt14_ende_data",
"newstest2014.tok.de"),
"8cce2028e4ca3d4cc039dfd33adbfb43",
"a1b1f4c47f487253e1ac88947b68b3b8")
}
VOCAB_INFO = [(os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"vocab_all.bpe.33708"),
"2fc775b7df37368e936a8e1f63846bb0"),
(os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
"vocab_all.bpe.33712"),
"de485e3c2e17e23acf4b4b70b54682dd")]
UNK_TOKEN = "<unk>"
BOS_TOKEN = "<s>"
EOS_TOKEN = "<e>"

MD5 = "a2b8410709ff760a3b40b84bd62dfbd8"

def _get_data(self, mode, **kwargs):
default_root = os.path.join(DATA_HOME, self.__class__.__name__)
src_filename, tgt_filename, src_data_hash, tgt_data_hash = self.SPLITS[
mode]
src_fullname = os.path.join(default_root, src_filename)
tgt_fullname = os.path.join(default_root, tgt_filename)

(all_vocab_filename, all_vocab_hash), (sub_vocab_filename,
sub_vocab_hash) = self.VOCAB_INFO
all_vocab_fullname = os.path.join(default_root, all_vocab_filename)
sub_vocab_fullname = os.path.join(default_root, sub_vocab_filename)

if (not os.path.exists(src_fullname) or
(src_data_hash and not md5file(src_fullname) == src_data_hash)) or (
not os.path.exists(tgt_fullname) or
(tgt_data_hash and
not md5file(tgt_fullname) == tgt_data_hash)) or (
not os.path.exists(all_vocab_fullname) or
(all_vocab_hash and
not md5file(all_vocab_fullname) == all_vocab_hash)) or (
not os.path.exists(sub_vocab_fullname) or
(sub_vocab_hash and
not md5file(sub_vocab_fullname) == sub_vocab_hash)):
get_path_from_url(self.URL, default_root, self.MD5)

return src_fullname, tgt_fullname

def _read(self, filename, *args):
src_filename, tgt_filename = filename
with open(src_filename, 'r', encoding='utf-8') as src_f:
with open(tgt_filename, 'r', encoding='utf-8') as tgt_f:
for src_line, tgt_line in zip(src_f, tgt_f):
src_line = src_line.strip()
tgt_line = tgt_line.strip()
if not src_line and not tgt_line:
continue
yield {"src": src_line, "trg": tgt_line}

def get_vocab(self):
all_vocab_fullname = os.path.join(DATA_HOME, self.__class__.__name__,
self.VOCAB_INFO[0][0])
sub_vocab_fullname = os.path.join(DATA_HOME, self.__class__.__name__,
self.VOCAB_INFO[1][0])
vocab_info = {
'all': {
'filepath': all_vocab_fullname,
'unk_token': self.UNK_TOKEN,
'bos_token': self.BOS_TOKEN,
'eos_token': self.EOS_TOKEN
},
'benchmark': {
'filepath': sub_vocab_fullname,
'unk_token': self.UNK_TOKEN,
'bos_token': self.BOS_TOKEN,
'eos_token': self.EOS_TOKEN
}
}
return vocab_info
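
The new builder yields `{"src": ..., "trg": ...}` dicts and exposes both vocabularies through `vocab_info`. A usage sketch drawn from the reader.py changes in this same commit could look as follows:

``` python
from paddlenlp.data import Vocab
from paddlenlp.datasets import load_dataset

# Load the test split; the files are downloaded and MD5-checked on first use.
dataset = load_dataset('wmt14ende', splits=('test'))

# "benchmark" points at vocab_all.bpe.33712, "all" at vocab_all.bpe.33708.
vocab = Vocab.load_vocabulary(**dataset.vocab_info["benchmark"])

def convert_samples(sample):
    # Each example holds whitespace-separated BPE tokens for source and target.
    source = vocab.to_indices(sample['src'].split())
    target = vocab.to_indices(sample['trg'].split())
    return source, target

dataset = dataset.map(convert_samples, lazy=False)
```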
