Skip to content

Commit

Permalink
fix error message in broadcast/allreduce/gather (#27302)
Browse files Browse the repository at this point in the history
* fix error message
  • Loading branch information
ForFishes committed Sep 16, 2020
1 parent 4f9d652 commit c296618
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 6 deletions.
3 changes: 2 additions & 1 deletion paddle/fluid/operators/distributed_ops/allreduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
}
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif
}
};
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/operators/distributed_ops/broadcast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ template <typename T>
// Kernel for the broadcast op defined in broadcast_op.cc (presumably the one
// registered for CPU places -- the registration is not visible in this hunk).
// Compute() does no work; it unconditionally raises an error stating that the
// op can only run on a GPU place.
class BroadcastOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// Diff view: the line below is the pre-change form (bare message string),
// which this commit removes.
PADDLE_THROW("Broadcast op can run on gpu place only for now.");
// Diff view: the two lines below are the replacement, wrapping the same
// message in a typed platform::errors::PreconditionNotMet error.
PADDLE_THROW(platform::errors::PreconditionNotMet(
"Broadcast op can run on gpu place only for now."));
}
};

Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,11 @@ class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
<< " From " << root_dev_id << " to " << dev_id;

if (ctx.Attr<bool>("sync_mode")) {
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
}
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif
}
};
Expand Down
17 changes: 15 additions & 2 deletions paddle/fluid/operators/gather_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,21 @@ class GatherOp : public framework::OperatorWithKernel {
"Output(Out) of GatherOp should not be null."));

auto index_dims = ctx->GetInputDim("Index");
PADDLE_ENFORCE(index_dims.size() == 1 ||
(index_dims.size() == 2 && index_dims[1] == 1));

if (index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(
index_dims[1], 1,
platform::errors::InvalidArgument(
"The last dim of index should be 1 when it is 2D, but we get %d",
index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
index_dims.size(), 1,
platform::errors::InvalidArgument(
"The index should be 1D, when it is not 2D, but we get %d",
index_dims.size()));
}

int batch_size = ctx->GetInputDim("Index")[0];
framework::DDim output_dims(ctx->GetInputDim("X"));
output_dims[0] = batch_size;
Expand Down
38 changes: 38 additions & 0 deletions python/paddle/fluid/tests/unittests/test_broadcast_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core


class TestBroadcastOpCpu(OpTest):
    """Run the ``broadcast`` op's output check on a CPU place.

    The check is best-effort: if it raises an ordinary exception, a notice is
    printed and the test still passes.
    """

    def setUp(self):
        """Declare the op under test, its input, attributes and expected output."""
        self.op_type = "broadcast"
        # Named ``x`` rather than ``input`` to avoid shadowing the builtin.
        x = np.random.random((100, 2)).astype("float32")
        # The expected output is the input itself (``x[:]`` is a view of ``x``).
        np_out = x[:]
        self.inputs = {"X": x}
        self.attrs = {"sync_mode": False, "root": 0}
        self.outputs = {"Out": np_out}

    def test_check_output_cpu(self):
        """Attempt the output check on ``core.CPUPlace()`` and tolerate failure."""
        try:
            self.check_output_with_place(place=core.CPUPlace())
        except Exception:
            # Catch ``Exception`` instead of using a bare ``except`` so that
            # KeyboardInterrupt and SystemExit still propagate and the test
            # run can be interrupted.
            print("do not support cpu test, skip")


if __name__ == "__main__":
unittest.main()

0 comments on commit c296618

Please sign in to comment.