[NPU] fix cast op (#32121)
* fix npu kernel of cast op to handle casting to same dtype

* add comments
zhiqiu committed Apr 7, 2021
1 parent 4638fe9 commit 78959a3
Showing 2 changed files with 50 additions and 18 deletions.
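For context: the first commit message line refers to casts where out_dtype equals the input's dtype. Per the NOTE comment added in the kernel below, the NPU Cast op may return wrong values in that case, so the kernel now short-circuits it with a plain tensor copy. A minimal user-level sketch of the affected scenario, assuming an NPU build of Paddle and an available device "npu:0" (the device string and its availability are assumptions, not part of this commit):

# Same-dtype cast: before this fix it was routed to the NPU Cast op
# (wrong values possible); after it, the kernel takes a TensorCopy path.
import numpy as np
import paddle

paddle.set_device('npu:0')  # assumption: NPU build with device 0 present
x = paddle.to_tensor(np.arange(6, dtype='int32'))
y = paddle.cast(x, 'int32')  # out_dtype == in_dtype hits the new branch
assert np.array_equal(y.numpy(), x.numpy())  # expected to hold after the fix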
44 changes: 26 additions & 18 deletions paddle/fluid/operators/cast_op_npu.cc
File mode changed: 100755 → 100644
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_ASCEND_CL
#include <memory>
#include <string>

@@ -41,52 +40,61 @@ class CastNPUKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
int dtype = ctx.Attr<int>("out_dtype");

auto* out = ctx.Output<Tensor>("Out");

auto place = ctx.GetPlace();

if (x->type() == dtype) {
// NOTE(zhiqiu): NPU cast op may result in wrong value, so
// add special case here.
VLOG(4) << "cast to same dtype:" << dtype;
out->mutable_data(place, x->type());
framework::TensorCopy(
*x, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), out);
return;
}

auto iter = DTYPE_2_ACL_DTYPE.find(
static_cast<framework::proto::VarType::Type>(dtype));
int aclDtype = iter->second;

if (dtype == framework::proto::VarType::FP32) {
out->mutable_data<float>(place);
} else if (dtype == framework::proto::VarType::FP16) {
out->mutable_data<paddle::platform::float16>(place);
} else if (dtype == framework::proto::VarType::INT16) {
out->mutable_data<int16_t>(place);
} else if (dtype == framework::proto::VarType::INT32) {
out->mutable_data<int32_t>(place);
} else if (dtype == framework::proto::VarType::INT64) {
out->mutable_data<int64_t>(place);
} else if (dtype == framework::proto::VarType::FP64) {
out->mutable_data<double>(place);
} else if (dtype == framework::proto::VarType::BOOL) {
out->mutable_data<bool>(place);
}

auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();

auto runner = NpuOpRunner("Cast", {*x}, {*out},
{{"dst_type", static_cast<int32_t>(aclDtype)}});
runner.Run(stream);
}
};
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_NPU_KERNEL(
cast, ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int16_t>,
ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int32_t>,
ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int64_t>,
ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int>,
ops::CastNPUKernel<paddle::platform::NPUDeviceContext, bool>,
ops::CastNPUKernel<paddle::platform::NPUDeviceContext, double>,
ops::CastNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::CastNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
#endif
24 changes: 24 additions & 0 deletions python/paddle/fluid/tests/unittests/npu/test_cast_op_npu.py
@@ -50,6 +50,7 @@ def set_npu(self):
def test_check_output(self):
self.check_output_with_place(self.place, check_dygraph=False)


class TestCast2(OpTest):
def setUp(self):
self.set_npu()
@@ -71,5 +72,28 @@ def set_npu(self):
def test_check_output(self):
self.check_output_with_place(self.place, check_dygraph=False, atol=1e-3)


class TestCast3(OpTest):
def setUp(self):
self.set_npu()
self.op_type = "cast"
self.place = paddle.NPUPlace(0)

ipt = np.random.random(size=[10, 10]) + 1
self.inputs = {'X': ipt.astype('int32')}
self.outputs = {'Out': ipt.astype('int32')}

self.attrs = {
'in_dtype': int(core.VarDesc.VarType.INT32),
'out_dtype': int(core.VarDesc.VarType.INT32)
}

def set_npu(self):
self.__class__.use_npu = True

def test_check_output(self):
self.check_output_with_place(self.place, check_dygraph=False, atol=1e-3)


if __name__ == '__main__':
unittest.main()
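For completeness, a sketch of running only the new test case on an NPU machine, assuming the file path in the diff header above and an NPU build of Paddle:

# Runs just TestCast3 from the test module added above; assumes the
# working directory is python/paddle/fluid/tests/unittests/npu/.
import unittest

from test_cast_op_npu import TestCast3  # class name from the diff above

if __name__ == '__main__':
    unittest.main(defaultTest='TestCast3')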
