Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

基于Fluid的多线程文本分类程序 #8267

Closed
guru4elephant opened this issue Feb 8, 2018 · 3 comments
Closed

基于Fluid的多线程文本分类程序 #8267

guru4elephant opened this issue Feb 8, 2018 · 3 comments
Assignees

Comments

@guru4elephant
Copy link
Member

guru4elephant commented Feb 8, 2018

尝试编写的多线程程序进行文本分类的例子,无法跑通,错误日志的信息难以帮助定位问题,请fluid内核程序相关同学看看。另外,多线程程序的书写规则是否有对应文档支持?

import sys
import numpy as np
import paddle.v2 as paddle
import paddle.fluid as fluid
def to_lodtensor(data, place):
    """Pack a batch of variable-length int64 id sequences into a LoDTensor.

    Args:
        data: iterable of per-example sequences of word ids.
        place: fluid device place (CPUPlace/CUDAPlace) the tensor is set on.

    Returns:
        fluid.LoDTensor of shape (total_len, 1) with a level-0 LoD whose
        offsets mark the sequence boundaries.
    """
    # Cumulative offsets: [0, len(s0), len(s0)+len(s1), ...]
    offsets = [0]
    for seq in data:
        offsets.append(offsets[-1] + len(seq))
    # Flatten all sequences into one (total_len, 1) int64 column vector.
    flattened = np.concatenate(data, axis=0).astype("int64").reshape([-1, 1])
    tensor = fluid.LoDTensor()
    tensor.set(flattened, place)
    tensor.set_lod([offsets])
    return tensor

def load_vocab(filename):
    """Build a token -> id mapping from a word-per-line vocabulary file.

    Ids are assigned in line order starting at 0; each line is stripped of
    surrounding whitespace (including the trailing newline). If a token
    appears more than once, the last occurrence's id wins (same as the
    original hand-rolled loop).

    Args:
        filename: path to the vocabulary file.

    Returns:
        dict mapping token (str) to its integer id.
    """
    with open(filename) as f:
        # enumerate replaces the manually-incremented `wid` counter.
        return {line.strip(): wid for wid, line in enumerate(f)}

# ---- Failing repro attached to the issue (Python 2, early paddle.fluid). ----
# Vocabulary is loaded from a word-per-line file given as argv[1]; an
# explicit <unk> id is appended after loading.
word_dict = load_vocab(sys.argv[1])
word_dict["<unk>"] = len(word_dict)
#vocabulary size
dict_dim = len(word_dict)

# embedding dim
emb_dim = 128

# hidden dim
hid_dim = 128

# hidden dim2
hid_dim2 = 96

# class num
class_dim = 2

# Input word-id sequence; lod_level=1 marks it as a variable-length
# sequence (LoDTensor), one int64 id per time step.
data = fluid.layers.data(
    name="words", shape=[1], dtype="int64", lod_level=1)

# label data
label = fluid.layers.data(name="label", shape=[1], dtype="int64")

# Replicate the sub-graph below on every available device; each replica
# reads its slice of the minibatch via pd.read_input().
places = fluid.layers.get_places()
pd = fluid.layers.ParallelDo(places)
with pd.do():
    feat_ = pd.read_input(data)
    label_ = pd.read_input(label)
    emb = fluid.layers.embedding(input=feat_, 
                                 size=[dict_dim, emb_dim], 
                                 param_attr=fluid.ParamAttr(learning_rate=5.0))
    
    # NOTE(review): a later comment in this thread points out that
    # dynamic_lstm normally needs a preceding fc (input-to-hidden
    # projection); here the raw embedding is fed in directly.
    lstm_h, c = fluid.layers.dynamic_lstm(input=emb, size=hid_dim, is_reverse=False)
    lstm_max = fluid.layers.sequence_pool(input=lstm_h, pool_type='max')
    fc1 = fluid.layers.fc(input=lstm_max, size=hid_dim2, act='tanh')
    prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
    # BUG(review): uses the outer `label` instead of the per-device slice
    # `label_` read above; `label_` is never used.
    cost = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    pd.write_output(avg_cost)
    pd.write_output(prediction)

# BUG(review): the ParallelDo outputs are never retrieved via `pd()`
# (see the fix later in this thread); everything below keeps using the
# *inner* avg_cost/prediction variables, which is what produces the
# "Cannot find variable fc_1.tmp_2@GRAD" EnforceNotMet at run time.
accuracy = fluid.evaluator.Accuracy(input=prediction, label=label)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(avg_cost)
inference_program = fluid.default_main_program().clone()
with fluid.program_guard(inference_program):
    test_target = accuracy.metrics + accuracy.states
    inference_program = fluid.io.get_inference_program(test_target)
BATCH_SIZE = 4

# IMDB readers, shuffled within a 25000-sample buffer.
train_reader = paddle.batch(
    paddle.reader.shuffle(
        paddle.dataset.imdb.train(word_dict), buf_size=25000),
    batch_size=BATCH_SIZE)
test_reader = paddle.batch(
    paddle.reader.shuffle(
        paddle.dataset.imdb.test(word_dict), buf_size=25000),
    batch_size=BATCH_SIZE)

place = fluid.CPUPlace()
#place = fluid.CUDAPlace(0)
'''
def test(exe):
    accuracy.reset(exe)
    for batch_id, data in enumerate(test_reader()):
        input_seq = to_lodtensor(map(lambda x:x[0], data), place)
        y_data = np.array(map(lambda x: x[1], data)).astype("int64")
        y_data = y_data.reshape([-1, 1])
        acc = exe.run(inference_program, 
                      feed={"words": input_seq, 
                            "label": y_data})
    return accuracy.eval(exe)
'''
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=[data, label], place=place)
# Dump both programs for debugging before running startup.
print(fluid.default_startup_program())

print(fluid.default_main_program())

exe.run(fluid.default_startup_program())
PASS_NUM = 30
for pass_id in xrange(PASS_NUM):
    accuracy.reset(exe)
    for data in train_reader():
        cost_val, acc_val = exe.run(fluid.default_main_program(),
                                    feed=feeder.feed(data),
                                    fetch_list=[avg_cost, accuracy.metrics[0]])
        pass_acc = accuracy.eval(exe)
    #pass_test_acc = test(exe)
    # BUG(review): pass_test_acc is never assigned (the test() call above
    # is commented out), so this print raises NameError at end of pass 0.
    print("test_acc: %f" % pass_test_acc)

以上例子在docker环境中无法运行,commit id
commit 8dbbc9d
Merge: efc094f 13922fb
Author: Tao Luo luotao02@baidu.com
Date: Mon Feb 26 15:26:56 2018 +0800

Merge pull request #8566 from jacquesqiao/add-c-api-doc

add c-api quick start

提示的问题

Traceback (most recent call last):
  File "test_understand_sentiment_dynamic_lstm_dev_parallel.py", line 114, in <module>
    fetch_list=[avg_cost, accuracy.metrics[0]])
  File "/usr/local/lib/python2.7/dist-packages/paddle/fluid/executor.py", line 290, in run
    self.executor.run(program.desc, scope, 0, True, True)
paddle.fluid.core.EnforceNotMet: Cannot find variable fc_1.tmp_2@GRAD in the parent scope at [/home/dongdaxiang/Paddle/paddle/fluid/operators/detail/safe_ref.h:26]
PaddlePaddle Call Stacks:
0       0x7fc969dc60ccp paddle::platform::EnforceNotMet::EnforceNotMet(std::__exception_ptr::exception_ptr, char const*, int) + 572
1       0x7fc96a376640p
2       0x7fc96a38041fp paddle::operators::ParallelDoGradOp::RunImpl(paddle::framework::Scope const&, boost::variant<paddle::platform::CUDAPlace, paddle::platform::CPUPlace, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_> const&) const + 447
3       0x7fc96a4b3d98p paddle::framework::OperatorBase::Run(paddle::framework::Scope const&, boost::variant<paddle::platform::CUDAPlace, paddle::platform::CPUPlace, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_, boost::detail::variant::void_> const&) + 72
4       0x7fc969e67cb6p paddle::framework::Executor::Run(paddle::framework::ProgramDesc const&, paddle::framework::Scope*, int, bool, bool) + 1526
5       0x7fc969de3523p void pybind11::cpp_function::initialize<pybind11::cpp_function::initialize<void, paddle::framework::Executor, paddle::framework::ProgramDesc const&, paddle::framework::Scope*, int, bool, bool, pybind11::name, pybind11::is_method, pybind11::sibling>(void (paddle::framework::Executor::*)(paddle::framework::ProgramDesc const&, paddle::framework::Scope*, int, bool, bool), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::{lambda(paddle::framework::Executor*, paddle::framework::ProgramDesc const&, paddle::framework::Scope*, int, bool, bool)#1}, void, paddle::framework::Executor*, paddle::framework::ProgramDesc const&, paddle::framework::Scope*, int, bool, bool, pybind11::name, pybind11::is_method, pybind11::sibling>(pybind11::cpp_function::initialize<void, paddle::framework::Executor, paddle::framework::ProgramDesc const&, paddle::framework::Scope*, int, bool, bool, pybind11::name, pybind11::is_method, pybind11::sibling>(void (paddle::framework::Executor::*)(paddle::framework::ProgramDesc const&, paddle::framework::Scope*, int, bool, bool), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::{lambda(paddle::framework::Executor*, paddle::framework::ProgramDesc const&, paddle::framework::Scope*, int, bool, bool)#1}&&, void (*)(paddle::framework::Executor*, paddle::framework::ProgramDesc const&, paddle::framework::Scope*, int, bool, bool), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call) + 579
6       0x7fc969de00a4p pybind11::cpp_function::dispatcher(_object*, _object*, _object*) + 1236
7             0x4c37edp PyEval_EvalFrameEx + 31165
8             0x4b9ab6p PyEval_EvalCodeEx + 774
9             0x4c16e7p PyEval_EvalFrameEx + 22711
10            0x4b9ab6p PyEval_EvalCodeEx + 774
11            0x4eb30fp
12            0x4e5422p PyRun_FileExFlags + 130
13            0x4e3cd6p PyRun_SimpleFileExFlags + 390
14            0x493ae2p Py_Main + 1554
15      0x7fc99cd1b830p __libc_start_main + 240
16            0x4933e9p _start + 41
@lcy-seso
Copy link
Contributor

lcy-seso commented Feb 27, 2018

多线程程序的书写规则目前还没有文档。直接看上面的程序我看不出来错误,不熟悉这一部分,请 @reyoung 帮忙看下问题,感谢。

上面程序中有一个和parallel_do无关的问题,下面对dynamic_lstm 的调用之前需要接一个fc。类似这里,https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/tests/book/test_understand_sentiment.py#L105 ,是LSTM的input-to-hidden映射,和 dynamic_lstm 一起构成一个完整的LSTM单元。

lstm_h, c = fluid.layers.dynamic_lstm(input=emb, size=hid_dim, is_reverse=False)

@reyoung
Copy link
Collaborator

reyoung commented Feb 28, 2018

import paddle.fluid as fluid
import paddle.v2 as paddle

# ---- reyoung's corrected version: same network, but the ParallelDo
# ---- outputs are now fetched via pd() before building the loss. ----
word_dict = paddle.dataset.imdb.word_dict()
print('Load Dict Done')
# vocabulary size
dict_dim = len(word_dict)

# embedding dim
emb_dim = 128

# hidden dim
hid_dim = 128

# hidden dim2
hid_dim2 = 96

# class num
class_dim = 2

# Variable-length int64 word-id sequence (lod_level=1 -> LoDTensor).
data = fluid.layers.data(
    name="words", shape=[1], dtype="int64", lod_level=1)

# label data
label = fluid.layers.data(name="label", shape=[1], dtype="int64")

# Replicate the sub-graph on every device; each replica computes its own
# avg_cost and accuracy and registers them with write_output.
places = fluid.layers.get_places()
pd = fluid.layers.ParallelDo(places)
with pd.do():
    feat_ = pd.read_input(data)
    label_ = pd.read_input(label)
    emb = fluid.layers.embedding(input=feat_,
                                 size=[dict_dim, emb_dim],
                                 param_attr=fluid.ParamAttr(learning_rate=5.0))

    lstm_h, c = fluid.layers.dynamic_lstm(input=emb, size=hid_dim, is_reverse=False)
    lstm_max = fluid.layers.sequence_pool(input=lstm_h, pool_type='max')
    fc1 = fluid.layers.fc(input=lstm_max, size=hid_dim2, act='tanh')
    prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
    cost = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    acc = fluid.layers.accuracy(input=prediction, label=label)
    pd.write_output(avg_cost)
    pd.write_output(acc)

# avg_cost, prediction = pd()
# Key fix: calling pd() retrieves the per-device outputs registered above.
avg_cost_on_each_devs, acc_on_each_devs = pd()

# Average the per-device losses/accuracies into single scalars.
avg_cost = fluid.layers.mean(x=avg_cost_on_each_devs)
acc = fluid.layers.mean(x=acc_on_each_devs)

sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(avg_cost)
BATCH_SIZE = 4

train_reader = paddle.batch(paddle.dataset.imdb.train(word_dict), batch_size=BATCH_SIZE)
test_reader = paddle.batch(paddle.dataset.imdb.test(word_dict), batch_size=BATCH_SIZE)

place = fluid.CUDAPlace(0)

exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=[data, label], place=place)

exe.run(fluid.default_startup_program())
PASS_NUM = 30
for pass_id in xrange(PASS_NUM):
    for data in train_reader():
        # Fetch the device-averaged scalars built after pd().
        avg_cost_np, acc_np = exe.run(fluid.default_main_program(),
                                      feed=feeder.feed(data),
                                      fetch_list=[avg_cost, acc])
        print("avg loss: {0}, Acc: {1}".format(str(avg_cost_np), str(acc_np)))

调用完Parallel.Do之后,需要用 avg_cost_on_each_devs, acc_on_each_devs = pd()获得Parallel.Do的输出。

@gmcather
Copy link
Contributor

gmcather commented Mar 1, 2018

最终版本

import paddle.fluid as fluid
import paddle.v2 as paddle
import numpy as np
import sys

def load_vocab(filename):
    """Read a word-per-line vocabulary file into a token -> id dict.

    Ids follow line order from 0; lines are whitespace-stripped, and a
    repeated token keeps the id of its last occurrence (matching the
    original counter-based loop).

    Args:
        filename: path to the vocabulary file.

    Returns:
        dict of token (str) -> integer id.
    """
    with open(filename) as f:
        # enumerate supplies the id, replacing the manual `wid += 1` loop.
        return {line.strip(): wid for wid, line in enumerate(f)}

# ---- Final working version (gmcather): reyoung's pd() fix plus the
# ---- original file-based vocab, shuffled readers, and CPU training. ----
word_dict = load_vocab(sys.argv[1])
word_dict["<unk>"] = len(word_dict)

#word_dict = paddle.dataset.imdb.word_dict()
print('Load Dict Done')
# vocabulary size
dict_dim = len(word_dict)

# embedding dim
emb_dim = 128

# hidden dim
hid_dim = 128

# hidden dim2
hid_dim2 = 96

# class num
class_dim = 2

# Variable-length int64 word-id sequence (lod_level=1 -> LoDTensor).
data = fluid.layers.data(
    name="words", shape=[1], dtype="int64", lod_level=1)

# label data
label = fluid.layers.data(name="label", shape=[1], dtype="int64")

# One replica of the sub-graph per device; per-device loss/accuracy are
# registered with write_output and collected with pd() below.
places = fluid.layers.get_places()
pd = fluid.layers.ParallelDo(places)
with pd.do():
    feat_ = pd.read_input(data)
    label_ = pd.read_input(label)
    emb = fluid.layers.embedding(input=feat_,
                                 size=[dict_dim, emb_dim],
                                 param_attr=fluid.ParamAttr(learning_rate=5.0))

    lstm_h, c = fluid.layers.dynamic_lstm(input=emb, size=hid_dim, is_reverse=False)
    lstm_max = fluid.layers.sequence_pool(input=lstm_h, pool_type='max')
    fc1 = fluid.layers.fc(input=lstm_max, size=hid_dim2, act='tanh')
    prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
    cost = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    acc = fluid.layers.accuracy(input=prediction, label=label)
    pd.write_output(avg_cost)
    pd.write_output(acc)

# avg_cost, prediction = pd()
# pd() retrieves the outputs registered inside the with-block.
avg_cost_on_each_devs, acc_on_each_devs = pd()

# Reduce the per-device metrics to single scalars for optimize/fetch.
avg_cost = fluid.layers.mean(x=avg_cost_on_each_devs)
acc = fluid.layers.mean(x=acc_on_each_devs)

sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(avg_cost)
BATCH_SIZE = 4

#train_reader = paddle.batch(paddle.dataset.imdb.train(word_dict), batch_size=BATCH_SIZE)
#test_reader = paddle.batch(paddle.dataset.imdb.test(word_dict), batch_size=BATCH_SIZE)
train_reader = paddle.batch(
    paddle.reader.shuffle(
        paddle.dataset.imdb.train(word_dict), buf_size=25000),
    batch_size=BATCH_SIZE)
test_reader = paddle.batch(
    paddle.reader.shuffle(
        paddle.dataset.imdb.test(word_dict), buf_size=25000),
    batch_size=BATCH_SIZE)

place = fluid.CPUPlace()

exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=[data, label], place=place)

exe.run(fluid.default_startup_program())
PASS_NUM = 30
for pass_id in xrange(PASS_NUM):
    # Accumulate per-batch metrics and report once per pass.
    avg_cost_list, avg_acc_list = [], []
    for data in train_reader():
        avg_cost_np, acc_np = exe.run(fluid.default_main_program(),
                                      feed=feeder.feed(data),
                                      fetch_list=[avg_cost, acc])
        avg_cost_list.append(avg_cost_np)
        avg_acc_list.append(acc_np)
    print("avg loss: {0}, Acc: {1}".format(str(np.mean(avg_cost_list)), str(np.mean(avg_acc_list))))
        #print("avg loss: {0}, Acc: {1}".format(str(avg_cost_np), str(acc_np)))

@gmcather gmcather closed this as completed Mar 1, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

4 participants