-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sequence tagging demo #225
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
#!/bin/bash | ||
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
set -e

# Resolve the directory that contains this script so the downloads land next
# to it regardless of the caller's working directory. `&&` makes a failed
# `cd` fail the substitution instead of silently reporting the wrong path.
DIR="$( cd "$(dirname "$0")" && pwd -P )"
# Quoted so that a path containing spaces is not word-split.
cd "$DIR"

# CoNLL-2000 chunking corpus: training and test splits.
wget http://www.cnts.ua.ac.be/conll2000/chunking/train.txt.gz
wget http://www.cnts.ua.ac.be/conll2000/chunking/test.txt.gz
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
data/test.txt.gz |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
data/train.txt.gz |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,258 @@ | ||
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from paddle.trainer.PyDataProvider2 import * | ||
import gzip | ||
import logging | ||
|
||
# Log records are rendered as "[LEVEL time file:line] message".
logging.basicConfig(
    format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s',
)
logger = logging.getLogger('paddle')
logger.setLevel(logging.INFO)

# What to do with a feature that is missing from its dictionary:
#   IGNORE - no real id is produced for it,
#   USE    - it is mapped to the reserved id 0,
#   ERROR  - it is reported through the logger.
OOV_POLICY_IGNORE = 0
OOV_POLICY_USE = 1
OOV_POLICY_ERROR = 2

# Number of columns read from each input line; pattern features are appended
# after these columns.
num_original_columns = 3

# Feature combination patterns. Every entry is a list of
# [relative position, column] pairs whose tokens are combined into one feature.
# [[-1, 0], [0, 0]] means previous token at column 0 and current token at
# column 0 are combined as one feature.
patterns = [
    # Single tokens of column 0 in a window of two around the current token.
    [[-2, 0]],
    [[-1, 0]],
    [[0, 0]],
    [[1, 0]],
    [[2, 0]],

    # Adjacent pairs of column 0.
    [[-1, 0], [0, 0]],
    [[0, 0], [1, 0]],

    # Single tokens and adjacent pairs of column 1.
    [[-2, 1]],
    [[-1, 1]],
    [[0, 1]],
    [[1, 1]],
    [[2, 1]],
    [[-2, 1], [-1, 1]],
    [[-1, 1], [0, 1]],
    [[0, 1], [1, 1]],
    [[1, 1], [2, 1]],

    # Adjacent triples of column 1.
    [[-2, 1], [-1, 1], [0, 1]],
    [[-1, 1], [0, 1], [1, 1]],
    [[0, 1], [1, 1], [2, 1]],
]

# Label ids: chunk type number k gets 2*k for its "B-" tag and 2*k + 1 for its
# "I-" tag; the outside tag 'O' takes the next free id (22).
dict_label = {
    prefix + chunk_type: 2 * type_index + tag_offset
    for type_index, chunk_type in enumerate((
        'ADJP', 'ADVP', 'CONJP', 'INTJ', 'LST', 'NP',
        'PP', 'PRT', 'SBAR', 'UCP', 'VP',
    ))
    for tag_offset, prefix in enumerate(('B-', 'I-'))
}
dict_label['O'] = len(dict_label)
|
||
def make_features(sequence):
    """Append one combined feature per pattern to every token of ``sequence``.

    ``sequence`` is a list of tokens, each token being a list of column
    values. For every token and every entry of the module-level ``patterns``,
    the values found at the pattern's (relative position, column) pairs are
    joined with '/' and appended to the token's own list. Positions before the
    start of the sequence read as '#B<k>' and positions past the end as
    '#E<k>', where k is the distance outside the sequence.

    The sequence is modified in place and nothing is returned. An empty
    sequence is left untouched.
    """
    length = len(sequence)
    if not length:
        # No token to extend; also avoids indexing sequence[0] below.
        return
    # Captured before any feature is appended, so boundary placeholders have
    # the width of the original columns.
    num_features = len(sequence[0])

    def get_features(pos):
        """Return the column values at ``pos``, or boundary placeholders."""
        if pos < 0:
            return ['#B%s' % -pos] * num_features
        if pos >= length:
            return ['#E%s' % (pos - length + 1)] * num_features
        return sequence[pos]

    for i, token in enumerate(sequence):
        for pattern in patterns:
            fname = '/'.join([get_features(i + pos)[f] for pos, f in pattern])
            token.append(fname)
|
||
def create_dictionaries(filename, cutoff, oov_policy):
    """Build one feature-to-id dictionary per column from a gzipped file.

    Source file format:
    Each line is for one timestep. The features are separated by space.
    An empty line indicates end of a sequence. A final sequence that is not
    followed by an empty line is processed as well.

    Args:
        filename: path of the gzip-compressed source file.
        cutoff: a list of numbers, one per column (pattern columns included).
            If the count of a feature is smaller than cutoff[i], the feature
            is left out of the i-th dictionary.
        oov_policy: a list with one OOV_POLICY_* value per column. If
            oov_policy[i] is OOV_POLICY_USE, id 0 is reserved for OOV
            features of the i-th column and stored under the key '#OOV#'.

    Returns:
        A list of dicts, one per column, mapping each kept feature to its id.

    Raises:
        AssertionError: if a timestep does not have len(cutoff) features
            after the pattern features have been appended.
    """
    def add_to_dict(sequence, dicts):
        """Add every feature of ``sequence`` to its column's counter."""
        num_features = len(dicts)
        for features in sequence:
            # Report the offending row itself; the enclosing loop's `line`
            # is always the empty separator line at this point.
            assert len(features) == num_features, \
                "Wrong number of features " + ' '.join(features)
            for i, feature in enumerate(features):
                dicts[i][feature] = dicts[i].get(feature, 0) + 1

    num_features = len(cutoff)
    dicts = [dict() for _ in range(num_features)]

    def count_sequence(sequence):
        """Expand and count one sequence; empty sequences are skipped."""
        if sequence:
            make_features(sequence)
            add_to_dict(sequence, dicts)

    f = gzip.open(filename, 'rb')
    try:
        sequence = []
        for line in f:
            line = line.strip()
            if not line:
                count_sequence(sequence)
                sequence = []
                continue
            if not isinstance(line, str):
                # Binary mode yields bytes on Python 3.
                line = line.decode('utf-8')
            sequence.append(line.split(' '))
        # The file may end without a trailing empty line.
        count_sequence(sequence)
    finally:
        f.close()

    # Replace the counts by ids, dropping features below the cutoff.
    for i in range(num_features):
        dct = dicts[i]
        # Ids start at 1 when id 0 is reserved for OOV features.
        n = 1 if oov_policy[i] == OOV_POLICY_USE else 0
        todo = []
        for k, v in list(dct.items()):
            if v < cutoff[i]:
                todo.append(k)
            else:
                dct[k] = n
                n += 1

        if oov_policy[i] == OOV_POLICY_USE:
            # placeholder so that len(dct) will be the number of features
            # including OOV
            dct['#OOV#'] = 0

        logger.info('column %d dict size=%d, ignored %d' % (i, n, len(todo)))
        for k in todo:
            del dct[k]

    return dicts
|
||
|
||
def initializer(settings, **xargs):
    """Init hook: build the dictionaries and declare the input slots.

    Creates the per-column dictionaries from 'data/train.txt.gz', replaces the
    dictionary of column 2 with the fixed ``dict_label`` table, and stores
    ``dicts``, ``oov_policy`` and ``input_types`` on ``settings``. Extra
    keyword arguments are accepted and ignored.
    """
    pattern_count = len(patterns)
    # Thresholds / policies for the original columns, followed by one entry
    # per pattern feature column.
    cutoff = [3, 1, 0] + [3] * pattern_count
    oov_policy = (
        [OOV_POLICY_IGNORE, OOV_POLICY_ERROR, OOV_POLICY_ERROR] +
        [OOV_POLICY_IGNORE] * pattern_count
    )

    dicts = create_dictionaries('data/train.txt.gz', cutoff, oov_policy)
    dicts[2] = dict_label
    settings.dicts = dicts
    settings.oov_policy = oov_policy

    # One integer-sequence slot per original column.
    input_types = []
    for column, column_dict in enumerate(dicts[:num_original_columns]):
        slot_size = len(column_dict)
        input_types.append(integer_sequence(slot_size))
        logger.info("slot %s size=%s" % (column, slot_size))

    # All pattern columns share a single sparse slot whose dimension is the
    # sum of their dictionary sizes.
    if patterns:
        dim = sum(len(pattern_dict)
                  for pattern_dict in dicts[num_original_columns:])
        input_types.append(sparse_binary_vector_sequence(dim))
        logger.info("feature size=%s" % dim)

    settings.input_types = input_types
|
||
@provider(init_hook=initializer, cache=CacheType.CACHE_PASS_IN_MEM)
def process(settings, filename):
    """Yield one sample per sequence read from the gzipped file ``filename``.

    Each line of the file is one timestep whose features are separated by a
    space; an empty line ends a sequence. A final sequence that is not
    followed by an empty line is yielded as well, and empty sequences are
    skipped.

    A sample holds one list of ids per original column and, when ``patterns``
    is non-empty, one extra list with the sparse feature indices of every
    timestep.

    if oov_policy[i] == OOV_POLICY_USE, features in the i-th column which do
    not exist in dicts[i] will be assigned to id 0.
    if oov_policy[i] == OOV_POLICY_ERROR, all features in the i-th column MUST
    exist in dicts[i].

    Args:
        settings: object carrying ``dicts`` and ``oov_policy`` as stored by
            ``initializer``.
        filename: path of the gzip-compressed input file.
    """
    dicts = settings.dicts
    oov_policy = settings.oov_policy
    num_features = len(dicts)

    def gen_sample(sequence):
        """Convert a feature-expanded sequence into lists of ids."""
        sample = [list() for _ in range(num_original_columns)]
        if patterns:
            sample.append([])
        for features in sequence:
            # Report the offending row itself; the enclosing loop's `line`
            # is always the empty separator line at this point.
            assert len(features) == num_features, \
                "Wrong number of features: " + ' '.join(features)
            for i in range(num_original_columns):
                index = dicts[i].get(features[i], -1)
                if index != -1:
                    sample[i].append(index)
                elif oov_policy[i] == OOV_POLICY_IGNORE:
                    sample[i].append(0xffffffff)
                elif oov_policy[i] == OOV_POLICY_ERROR:
                    # NOTE(review): nothing is appended here, so this column
                    # ends up shorter than the others unless logger.fatal
                    # aborts - confirm the intended behavior.
                    logger.fatal("Unknown token: %s" % features[i])
                else:
                    sample[i].append(0)

            if patterns:
                # Offset of the current column inside the shared sparse slot.
                dim = 0
                vec = []
                for i in range(num_original_columns, num_features):
                    index = dicts[i].get(features[i], -1)
                    if index != -1:
                        vec.append(dim + index)
                    elif oov_policy[i] == OOV_POLICY_IGNORE:
                        pass
                    elif oov_policy[i] == OOV_POLICY_ERROR:
                        logger.fatal("Unknown token: %s" % features[i])
                    else:
                        # vec is a plain list; the reserved OOV id is 0.
                        vec.append(dim + 0)

                    dim += len(dicts[i])
                sample[-1].append(vec)
        return sample

    num_sequences = 0
    f = gzip.open(filename, 'rb')
    try:
        sequence = []
        for line in f:
            line = line.strip()
            if not line:
                if sequence:
                    make_features(sequence)
                    yield gen_sample(sequence)
                    num_sequences += 1
                sequence = []
                continue
            if not isinstance(line, str):
                # Binary mode yields bytes on Python 3.
                line = line.decode('utf-8')
            sequence.append(line.split(' '))

        # The file may end without a trailing empty line.
        if sequence:
            make_features(sequence)
            yield gen_sample(sequence)
            num_sequences += 1
    finally:
        # Also runs when the consumer stops iterating early.
        f.close()

    logger.info("num_sequences=%s" % num_sequences)
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from paddle.trainer_config_helpers import * | ||
|
||
import math | ||
|
||
# Training and test samples are produced by dataprovider.process from the
# files named in the two list files.
define_py_data_sources2(
    train_list="data/train.list",
    test_list="data/test.list",
    module="dataprovider",
    obj="process",
)

# Optimizer configuration; the L2 regularization rate is scaled by the batch
# size.
batch_size = 1
settings(
    learning_method=MomentumOptimizer(),
    batch_size=batch_size,
    regularization=L2Regularization(batch_size * 1e-4),
    average_window=0.5,
    learning_rate=1e-1,
    learning_rate_decay_a=1e-5,
    learning_rate_decay_b=0.25,
)

# Number of distinct chunk labels before any size alignment.
num_label_types = 23
|
||
def get_simd_size(size):
    """Return ``size`` rounded up to the nearest multiple of 8."""
    groups_of_eight = math.ceil(float(size) / 8)
    return int(groups_of_eight) * 8
|
||
# Currently, in order to use sparse_update=True,
# the size has to be aligned.
num_label_types = get_simd_size(num_label_types)

# Input slots. The sizes are fixed numbers that have to agree with what the
# data provider declares for each slot.
features = data_layer(name="features", size=76328)
word = data_layer(name="word", size=6778)
pos = data_layer(name="pos", size=44)
chunk = data_layer(name="chunk", size=num_label_types)

# Linear, bias-free projection of the sparse features onto the label space;
# its output is the input of both CRF layers below.
crf_input = fc_layer(
    input=features,
    size=num_label_types,
    act=LinearActivation(),
    bias_attr=False,
    param_attr=ParamAttr(initial_std=0, sparse_update=True),
)

# CRF layer over the projected features, with `chunk` as label.
crf = crf_layer(
    input=crf_input,
    label=chunk,
    param_attr=ParamAttr(name="crfw", initial_std=0),
)

# Decoding layer; it refers to the same parameter name "crfw" as the CRF
# layer above.
crf_decoding = crf_decoding_layer(
    size=num_label_types,
    input=crf_input,
    label=chunk,
    param_attr=ParamAttr(name="crfw"),
)

sum_evaluator(
    name="error",
    input=crf_decoding,
)

# 11 chunk types, each with a begin and an inside tag, in the IOB scheme.
chunk_evaluator(
    name="chunk_f1",
    input=[crf_decoding, chunk],
    chunk_scheme="IOB",
    num_chunk_types=11,
)

inputs(word, pos, chunk, features)
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. inputs() accepts *args. |
||
# Declare the crf layer as the output of the network.
outputs(crf) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
frame?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`fname` is an abbreviation of "feature name".