[Model]Transformer (#186)
* change the signature of node/edge filter

* upd filter

* Support multi-dimension node feature in SPMV

* push transformer

* remove some experimental settings

* stable version

* hotfix

* upd tutorial

* upd README

* merge

* remove redundancy

* remove tqdm

* several changes

* Refactor

* Refactor

* tutorial train

* fixed a bug

* fixed perf issue

* upd

* change dir

* move un-related to contrib

* tutorial code

* remove redundancy

* upd

* upd

* upd

* upd

* improve viz

* universal done

* halt norm

* fixed a bug

* add draw graph

* fixed several bugs

* remove dependency on core

* upd format of README

* trigger

* trigger

* upd viz

* trigger

* add transformer tutorial

* fix tutorial

* fix readme

* small fix on tutorials

* url fix in readme

* fixed func link

* upd
yzh119 committed Dec 7, 2018
1 parent 37feb47 commit 9f32554
Showing 25 changed files with 2,849 additions and 4 deletions.
12 changes: 12 additions & 0 deletions examples/pytorch/transformer/.gitignore
@@ -0,0 +1,12 @@
*~
data/
scripts/
checkpoints/
log/
*__pycache__*
*.pdf
*.tar.gz
*.zip
*.pyc
*.lprof
*.swp
Empty file.
46 changes: 46 additions & 0 deletions examples/pytorch/transformer/README.md
@@ -0,0 +1,46 @@
# Transformer in DGL
In this example, we implement the [Transformer](https://arxiv.org/pdf/1706.03762.pdf) and the [Universal Transformer](https://arxiv.org/abs/1807.03819) with ACT (Adaptive Computation Time) in DGL.

This folder contains the training and inference (beam decoder) modules for the Transformer, and the training module for the Universal Transformer.

## Requirements

- PyTorch 0.4.1+
- networkx
- tqdm

## Usage

- For training:

```
python translation_train.py [--gpus id1,id2,...] [--N #layers] [--dataset DATASET] [--batch BATCHSIZE] [--universal]
```

- For evaluating the BLEU score on the test set (enable `--print` to print the translated text):

```
python translation_test.py [--gpu id] [--N #layers] [--dataset DATASET] [--batch BATCHSIZE] [--checkpoint CHECKPOINT] [--print] [--universal]
```

Available datasets: `copy`, `sort`, `wmt14`, `multi30k` (default).

## Test Results

### Transformer

- Multi30k: we achieve a BLEU score of 35.41 with the default settings on the Multi30k dataset, without using pre-trained embeddings (with the number of layers set to 2, the BLEU score reaches 36.45).
- WMT14: work in progress

### Universal Transformer

- work in progress

## Notes

- Currently we do not support multi-GPU training (this will be fixed soon); specify only one `gpu_id` when running the training script.

## Reference

- [The Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html)
- [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/)
178 changes: 178 additions & 0 deletions examples/pytorch/transformer/dataset/__init__.py
@@ -0,0 +1,178 @@
from .graph import *
from .fields import *
from .utils import prepare_dataset
import os
import numpy as np

class ClassificationDataset:
    "Dataset class for classification task."
    def __init__(self):
        raise NotImplementedError

class TranslationDataset:
    '''
    Dataset class for translation task.
    By default, the source language shares the same vocabulary with the target language.
    '''
    INIT_TOKEN = '<sos>'
    EOS_TOKEN = '<eos>'
    PAD_TOKEN = '<pad>'
    MAX_LENGTH = 50

    def __init__(self, path, exts, train='train', valid='valid', test='test', vocab='vocab.txt', replace_oov=None):
        vocab_path = os.path.join(path, vocab)
        self.src = {}
        self.tgt = {}
        with open(os.path.join(path, train + '.' + exts[0]), 'r') as f:
            self.src['train'] = f.readlines()
        with open(os.path.join(path, train + '.' + exts[1]), 'r') as f:
            self.tgt['train'] = f.readlines()
        with open(os.path.join(path, valid + '.' + exts[0]), 'r') as f:
            self.src['valid'] = f.readlines()
        with open(os.path.join(path, valid + '.' + exts[1]), 'r') as f:
            self.tgt['valid'] = f.readlines()
        with open(os.path.join(path, test + '.' + exts[0]), 'r') as f:
            self.src['test'] = f.readlines()
        with open(os.path.join(path, test + '.' + exts[1]), 'r') as f:
            self.tgt['test'] = f.readlines()

        if not os.path.exists(vocab_path):
            self._make_vocab(vocab_path)

        vocab = Vocab(init_token=self.INIT_TOKEN,
                      eos_token=self.EOS_TOKEN,
                      pad_token=self.PAD_TOKEN,
                      unk_token=replace_oov)
        vocab.load(vocab_path)
        self.vocab = vocab
        strip_func = lambda x: x[:self.MAX_LENGTH]
        self.src_field = Field(vocab,
                               preprocessing=None,
                               postprocessing=strip_func)
        self.tgt_field = Field(vocab,
                               preprocessing=lambda seq: [self.INIT_TOKEN] + seq + [self.EOS_TOKEN],
                               postprocessing=strip_func)

    def get_seq_by_id(self, idx, mode='train', field='src'):
        "get raw sequence in dataset by specifying index, mode (train/valid/test), field (src/tgt)"
        if field == 'src':
            return self.src[mode][idx].strip().split()
        else:
            return [self.INIT_TOKEN] + self.tgt[mode][idx].strip().split() + [self.EOS_TOKEN]

    def _make_vocab(self, path, thres=2):
        # Count token frequencies over all splits and keep tokens that
        # appear more than `thres` times.
        word_dict = {}
        for mode in ['train', 'valid', 'test']:
            for line in self.src[mode] + self.tgt[mode]:
                for token in line.strip().split():
                    if token not in word_dict:
                        word_dict[token] = 1
                    else:
                        word_dict[token] += 1

        with open(path, 'w') as f:
            for k, v in word_dict.items():
                if v > thres:
                    print(k, file=f)

    @property
    def vocab_size(self):
        return len(self.vocab)

    @property
    def pad_id(self):
        return self.vocab[self.PAD_TOKEN]

    @property
    def sos_id(self):
        return self.vocab[self.INIT_TOKEN]

    @property
    def eos_id(self):
        return self.vocab[self.EOS_TOKEN]

    def __call__(self, graph_pool, mode='train', batch_size=32, k=1, devices=['cpu']):
        '''
        Create batched graphs corresponding to mini-batches of the dataset.
        args:
            graph_pool: a GraphPool object for accelerating graph construction.
            mode: train/valid/test
            batch_size: batch size
            devices: ['cpu'] or a list of gpu ids.
            k: beam size (only used in test mode)
        '''
        dev_id, gs = 0, []
        src_data, tgt_data = self.src[mode], self.tgt[mode]
        n = len(src_data)
        # Shuffle the sample order for training; keep the original order otherwise.
        order = np.random.permutation(n) if mode == 'train' else range(n)
        src_buf, tgt_buf = [], []

        for idx in order:
            src_sample = self.src_field(
                src_data[idx].strip().split())
            tgt_sample = self.tgt_field(
                tgt_data[idx].strip().split())
            src_buf.append(src_sample)
            tgt_buf.append(tgt_sample)
            if len(src_buf) == batch_size:
                if mode == 'test':
                    assert len(devices) == 1  # we only allow a single GPU for inference
                    yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=devices[0])
                else:
                    gs.append(graph_pool(src_buf, tgt_buf, device=devices[dev_id]))
                    dev_id += 1
                    if dev_id == len(devices):
                        yield gs if len(devices) > 1 else gs[0]
                        dev_id, gs = 0, []
                src_buf, tgt_buf = [], []

        if len(src_buf) != 0:
            if mode == 'test':
                yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=devices[0])
            else:
                gs.append(graph_pool(src_buf, tgt_buf, device=devices[dev_id]))
                yield gs if len(devices) > 1 else gs[0]

    def get_sequence(self, batch):
        "Return a list of sequences from a list of index arrays."
        ret = []
        filter_list = set([self.pad_id, self.sos_id, self.eos_id])
        for seq in batch:
            try:
                l = seq.index(self.eos_id)
            except ValueError:
                l = len(seq)
            ret.append(' '.join(self.vocab[token] for token in seq[:l] if token not in filter_list))
        return ret

def get_dataset(dataset):
    "Return one of the wrapped example datasets by name."
    prepare_dataset(dataset)
    if dataset == 'babi':
        raise NotImplementedError
    elif dataset == 'copy' or dataset == 'sort':
        return TranslationDataset(
            'data/{}'.format(dataset),
            ('in', 'out'),
            train='train',
            valid='valid',
            test='test',
        )
    elif dataset == 'multi30k':
        return TranslationDataset(
            'data/multi30k',
            ('en.atok', 'de.atok'),
            train='train',
            valid='val',
            test='test2016',
            replace_oov='<unk>'
        )
    elif dataset == 'wmt14':
        return TranslationDataset(
            'data/wmt14',
            ('en', 'de'),
            train='train.tok.clean.bpe.32000',
            valid='newstest2013.tok.bpe.32000',
            test='newstest2014.tok.bpe.32000',
            vocab='vocab.bpe.32000')
    else:
        raise KeyError()
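
As a rough sketch of how this dataset class is meant to be driven (not part of the commit): the `GraphPool` name and its constructor arguments below are assumptions, since it is defined in this example's `dataset/graph.py`, which is not shown here.

```python
# Illustrative sketch only. `GraphPool` is assumed to be re-exported by the
# dataset package via `from .graph import *`; its constructor defaults are assumed.
from dataset import get_dataset, GraphPool  # assumed import path from the example root

dataset = get_dataset('copy')   # or 'sort', 'multi30k', 'wmt14'
graph_pool = GraphPool()        # hypothetical default construction

# Iterate over one epoch of batched graphs and feed each to the model.
for g in dataset(graph_pool, mode='train', batch_size=32, devices=['cpu']):
    # With a single device, g is one batched graph; with multiple devices a list is yielded.
    pass  # model forward/backward would go here
```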
63 changes: 63 additions & 0 deletions examples/pytorch/transformer/dataset/fields.py
@@ -0,0 +1,63 @@
class Vocab:
    def __init__(self, init_token=None, eos_token=None, pad_token=None, unk_token=None):
        self.init_token = init_token
        self.eos_token = eos_token
        self.pad_token = pad_token
        self.unk_token = unk_token
        self.vocab_lst = []
        self.vocab_dict = None

    def load(self, path):
        if self.init_token is not None:
            self.vocab_lst.append(self.init_token)
        if self.eos_token is not None:
            self.vocab_lst.append(self.eos_token)
        if self.pad_token is not None:
            self.vocab_lst.append(self.pad_token)
        if self.unk_token is not None:
            self.vocab_lst.append(self.unk_token)
        with open(path, 'r') as f:
            for token in f.readlines():
                token = token.strip()
                self.vocab_lst.append(token)
        self.vocab_dict = {
            v: k for k, v in enumerate(self.vocab_lst)
        }

    def __len__(self):
        return len(self.vocab_lst)

    def __getitem__(self, key):
        if isinstance(key, str):
            if key in self.vocab_dict:
                return self.vocab_dict[key]
            else:
                return self.vocab_dict[self.unk_token]
        else:
            return self.vocab_lst[key]

class Field:
    def __init__(self, vocab, preprocessing=None, postprocessing=None):
        self.vocab = vocab
        self.preprocessing = preprocessing
        self.postprocessing = postprocessing

    def preprocess(self, x):
        if self.preprocessing is not None:
            return self.preprocessing(x)
        return x

    def postprocess(self, x):
        if self.postprocessing is not None:
            return self.postprocessing(x)
        return x

    def numericalize(self, x):
        return [self.vocab[token] for token in x]

    def __call__(self, x):
        return self.postprocess(
            self.numericalize(
                self.preprocess(x)
            )
        )
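
For orientation, a minimal sketch of how `Vocab` and `Field` compose (not part of the commit): the import path, the vocabulary file path, and the example tokens below are made up for illustration.

```python
from dataset.fields import Vocab, Field  # assumed import path from the example root

# Assumes a plain-text vocabulary file with one token per line,
# such as the one written by TranslationDataset._make_vocab.
vocab = Vocab(init_token='<sos>', eos_token='<eos>',
              pad_token='<pad>', unk_token='<unk>')
vocab.load('data/copy/vocab.txt')  # hypothetical path

# A target-side field: add special tokens, map tokens to ids, then truncate to 50 ids.
tgt_field = Field(vocab,
                  preprocessing=lambda seq: ['<sos>'] + seq + ['<eos>'],
                  postprocessing=lambda ids: ids[:50])

ids = tgt_field('two men are playing frisbee'.split())  # list of integer ids
tokens = [vocab[i] for i in ids]                         # map ids back to tokens
```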
(The diffs of the remaining changed files are not shown.)
