Fix fetch_op_handle #10454

Merged (5 commits, May 9, 2018)
Changes from 3 commits
4 changes: 3 additions & 1 deletion paddle/fluid/framework/details/fetch_op_handle.cc
@@ -49,7 +49,9 @@ void FetchOpHandle::RunImpl() {
      platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
  for (auto *input : inputs_) {
    auto *var = static_cast<VarHandle *>(input);
-    var->generated_op_->Wait(cpu_ctx);
+    if (var->generated_op_) {
+      var->generated_op_->Wait(cpu_ctx);
+    }
  }
  tensors_.resize(inputs_.size());
  auto *var_handle = static_cast<VarHandle *>(inputs_[0]);
4 changes: 3 additions & 1 deletion paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
@@ -36,7 +36,9 @@ void NCCLAllReduceOpHandle::RunImpl() {
  // Wait input done
  for (auto *in : inputs_) {
    auto &p = static_cast<VarHandle *>(in)->place_;
-    in->generated_op_->Wait(dev_ctxes_[p]);
+    if (in->generated_op_) {
+      in->generated_op_->Wait(dev_ctxes_[p]);
+    }
  }

  auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
4 changes: 3 additions & 1 deletion paddle/fluid/framework/details/send_op_handle.cc
@@ -32,7 +32,9 @@ void SendOpHandle::RunImpl() {
    if (in->DebugString() == "dummy") {  // HACK
      continue;
    }
-    in->generated_op_->Wait(dev_ctxes_[p]);
+    if (in->generated_op_) {
+      in->generated_op_->Wait(dev_ctxes_[p]);
+    }
  }
  auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
  // FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead
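All three guards handle the same case: an input VarHandle whose generated_op_ is null, i.e. a variable that no operator in this SSA graph produces, so there is nothing to wait on. The exact trigger is not spelled out in the PR; one plausible example is fetching a persistable parameter, which is only written by the startup program. A minimal Python-side sketch of how that case is reached, assuming a ParallelExecutor pe, a DataFeeder feeder, a batch data, and a program main built as in the test added below:

# Hedged sketch: fetch variables that no op in the main graph generates,
# e.g. the trainable parameters themselves. Before this fix, waiting on
# their (non-existent) generating op could dereference a null pointer.
param_names = [p.name for p in main.global_block().all_parameters()]
ret = pe.run(param_names, feed=feeder.feed(data))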
114 changes: 95 additions & 19 deletions python/paddle/fluid/tests/unittests/test_parallel_executor.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import numpy
+import numpy as np
import unittest

import paddle.fluid as fluid
@@ -243,7 +243,7 @@ def run_executor(exe, feed, fetch_list, program=None):
        begin = time.time()
        first_loss, = run_executor(
            exe=exe, feed=feed_dict, fetch_list=[loss.name])
-        first_loss = numpy.array(first_loss)
+        first_loss = np.array(first_loss)

        for i in xrange(iter):
            run_executor(exe=exe, feed=feed_dict, fetch_list=[])
@@ -256,7 +256,7 @@ def run_executor(exe, feed, fetch_list, program=None):
        print "%.4f Instance per second" % (
            (batch_size * iter + 2) / (end - begin))

-        last_loss = numpy.array(last_loss)
+        last_loss = np.array(last_loss)

        print first_loss, last_loss
        # self.assertGreater(first_loss[0], last_loss[0])
@@ -284,8 +284,8 @@ def check_simple_fc_convergence(self):
        self.check_network_convergence(simple_fc_net)
        self.check_network_convergence(simple_fc_net, allow_op_delay=True)

-        img = numpy.zeros(shape=[32, 784], dtype='float32')
-        label = numpy.ones(shape=[32, 1], dtype='int64')
+        img = np.zeros(shape=[32, 784], dtype='float32')
+        label = np.ones(shape=[32, 1], dtype='int64')
        self.check_network_convergence(
            simple_fc_net, feed_dict={"image": img,
                                      "label": label})
@@ -294,8 +294,8 @@ def test_simple_fc(self):
        self.check_simple_fc_convergence()

    def check_simple_fc_parallel_accuracy(self):
-        img = numpy.zeros(shape=[32, 784], dtype='float32')
-        label = numpy.ones(shape=[32, 1], dtype='int64')
+        img = np.zeros(shape=[32, 784], dtype='float32')
+        label = np.ones(shape=[32, 1], dtype='int64')
        single_first_loss, single_last_loss = self.check_network_convergence(
            method=simple_fc_net,
            seed=1000,
@@ -319,8 +319,8 @@ def test_simple_fc_parallel_accuracy(self):

    def check_batchnorm_fc_convergence(self):
        self.check_network_convergence(fc_with_batchnorm)
-        img = numpy.zeros(shape=[32, 784], dtype='float32')
-        label = numpy.ones(shape=[32, 1], dtype='int64')
+        img = np.zeros(shape=[32, 784], dtype='float32')
+        label = np.ones(shape=[32, 1], dtype='int64')
        self.check_network_convergence(
            fc_with_batchnorm, feed_dict={"image": img,
                                          "label": label})
@@ -404,9 +404,6 @@ class ModelHyperParams(object):
    dropout = 0.1


-import numpy as np
-
-
def prepare_batch_input(insts, src_pad_idx, trg_pad_idx, n_head):
    """
    Pad the instances to the max sequence length in batch, and generate the
@@ -533,9 +530,8 @@ def check_network_convergence(self):
            opt.minimize(loss)

            batch_size = 32
-            image = numpy.random.normal(size=(batch_size,
-                                              784)).astype('float32')
-            label = numpy.random.randint(0, 10, (batch_size, 1), dtype="int64")
+            image = np.random.normal(size=(batch_size, 784)).astype('float32')
+            label = np.random.randint(0, 10, (batch_size, 1), dtype="int64")

            place = fluid.CUDAPlace(0)
            exe = fluid.Executor(place)
@@ -552,12 +548,12 @@ def check_network_convergence(self):

            for i in xrange(5):
                test_loss, = test_exe.run([loss.name], feed=feed_dict)
-                test_loss = numpy.array(test_loss)
+                test_loss = np.array(test_loss)

                train_loss, = train_exe.run([loss.name], feed=feed_dict)
-                train_loss = numpy.array(train_loss)
+                train_loss = np.array(train_loss)
                self.assertTrue(
-                    numpy.allclose(
+                    np.allclose(
                        train_loss, test_loss, atol=1e-8),
                    "Train loss: " + str(train_loss) + "\n Test loss:" +
                    str(test_loss))
@@ -712,7 +708,7 @@ def check_network_convergence(self, is_sparse):
            data = train_data()
            for i in xrange(10):
                cur_batch = next(data)
-                print map(numpy.array,
+                print map(np.array,
                          pe.run(feed=feeder.feed(cur_batch),
                                 fetch_list=[avg_cost.name]))[0]

@@ -721,3 +717,83 @@ def test_update_sparse_parameter(self):

    def test_update_dense_parameter(self):
        self.check_network_convergence(is_sparse=False)


# Test fetching all the variables of the global_block.

import paddle.dataset.flowers as flowers


def lenet(data, class_dim):
    conv1 = fluid.layers.conv2d(data, 32, 5, 1, act=None)
    bn1 = fluid.layers.batch_norm(conv1, act='relu')
    pool1 = fluid.layers.pool2d(bn1, 2, 'max', 2)
    conv2 = fluid.layers.conv2d(pool1, 50, 5, 1, act=None)
    bn2 = fluid.layers.batch_norm(conv2, act='relu')
    pool2 = fluid.layers.pool2d(bn2, 2, 'max', 2)

    fc1 = fluid.layers.fc(pool2, size=500, act='relu')
    fc2 = fluid.layers.fc(fc1, size=class_dim, act='softmax')

    return fc2


class TestFetchOp(unittest.TestCase):
    def parallel_exe(self, train_inputs, seed):
        main = fluid.Program()
        startup = fluid.Program()
        startup.random_seed = seed
        with fluid.program_guard(main, startup):
            data = fluid.layers.data(
                name='image', shape=[3, 224, 224], dtype='float32')
            label = fluid.layers.data(name='label', shape=[1], dtype='int64')
            out = lenet(data, class_dim=102)
            loss = fluid.layers.cross_entropy(input=out, label=label)
            loss = fluid.layers.mean(loss)

            opt = fluid.optimizer.Momentum(
                learning_rate=0.1,
                momentum=0.9,
                regularization=fluid.regularizer.L2Decay(1e-4))

            opt.minimize(loss)

        # TODO(zcd): I found that once the memory optimizer is enabled,
        # parallel_exe doesn't fetch some variables, such as conv2d_0.b_0@GRAD
        # and conv2d_1.b_0@GRAD. Those variables should not be pruned.
        # fluid.memory_optimize(main)
Contributor: Nice! Recently, qingqing reported that se-resnext accuracy improved after memory_optimize was turned off.

Contributor (author): Maybe this shows that memory_optimize has some problems.
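For reference, a small sketch (an assumption, not part of the PR) of how one might list the gradient variables that the TODO above says get pruned when fluid.memory_optimize(main) is enabled; it relies only on the main program built above:

# Hedged sketch: enumerate the gradient variables the fetch loop below will
# try to retrieve. According to the TODO above, with fluid.memory_optimize(main)
# enabled, some of these (e.g. conv2d_0.b_0@GRAD) reportedly cannot be fetched.
grad_names = [name for name in main.global_block().vars
              if name.endswith('@GRAD')]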


        place = fluid.CUDAPlace(0)
        exe = fluid.Executor(place)
        exe.run(startup)

        feeder = fluid.DataFeeder(place=place, feed_list=[data, label])
        pe = fluid.ParallelExecutor(
            use_cuda=True, loss_name=loss.name, main_program=main)

        fetch_list = []
        all_vars = main.global_block().vars
        for k, v in all_vars.iteritems():
            if 'velocity' not in k:
Contributor: Why is velocity special?

Contributor (author): velocity can also be fetched; I just don't think it is necessary to fetch it.
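For context, the velocity variables are the Momentum optimizer's per-parameter accumulators, created when opt.minimize(loss) is called; their names contain the string "velocity" (exact naming may vary by version), which is what the filter keys on. A minimal sketch of the same idea, assuming the main program built above:

# Hedged sketch: list the Momentum "velocity" accumulators. Skipping names
# that contain "velocity" just keeps this optimizer state out of the fetch list.
velocity_names = [name for name in main.global_block().vars
                  if 'velocity' in name]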

                fetch_list.append(k)

        for data in train_inputs:
            ret = pe.run(fetch_list, feed=feeder.feed(data))
            for i in range(len(fetch_list)):
                print("%s - %s" % (fetch_list[i], np.sum(ret[i])))
Contributor: Let's try to avoid printing a lot in tests.

Contributor (author): OK, I will fix this.
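One possible follow-up, sketched here only to illustrate the reviewer's suggestion (not necessarily what a later commit does), is to replace the per-variable prints with a quiet sanity check over the same fetch_list and ret from the loop above:

# Hedged sketch: still exercise the fetch path for every variable, but
# assert the fetched values are finite instead of printing them all.
for name, value in zip(fetch_list, ret):
    self.assertTrue(np.isfinite(np.sum(value)),
                    "fetched variable %s is not finite" % name)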


    def test_update_sparse_parameter(self):
        tst_reader = paddle.batch(flowers.test(use_xmap=False), batch_size=16)
        tst_reader_iter = tst_reader()

        seed = 100
        iters = 4
        train_inputs = []
        for i in range(iters):
            train_inputs.append(tst_reader_iter.next())

        self.parallel_exe(train_inputs, seed)


if __name__ == '__main__':
    unittest.main()