diff --git a/paddle/fluid/operators/matmul_op_npu.cc b/paddle/fluid/operators/matmul_op_npu.cc new file mode 100644 index 0000000000000..b97d0abefda5c --- /dev/null +++ b/paddle/fluid/operators/matmul_op_npu.cc @@ -0,0 +1,183 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +template +class MatMulNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* out = ctx.Output("Out"); + bool transpose_x = ctx.Attr("transpose_X"); + bool transpose_y = ctx.Attr("transpose_Y"); + + if (x->dims().size() == 2) { + out->mutable_data(ctx.GetPlace()); + + const auto& runner = NpuOpRunner( + "MatMul", {*x, *y}, {*out}, + {{"transpose_x1", transpose_x}, {"transpose_x2", transpose_y}}); + + auto stream = + ctx.template device_context() + .stream(); + runner.Run(stream); + + } else if (x->dims().size() > 2) { + out->mutable_data(ctx.GetPlace()); + + const auto& runner = + NpuOpRunner("BatchMatMul", {*x, *y}, {*out}, + {{"adj_x1", transpose_x}, {"adj_x2", transpose_y}}); + + auto stream = + ctx.template device_context() + .stream(); + runner.Run(stream); + } + } +}; + +template +class MatMulGradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + bool transpose_y = ctx.Attr("transpose_Y"); + auto stream = + ctx.template device_context() + .stream(); + + if (x->dims().size() == 2) { + if (transpose_y) { + if (dx) { + dx->mutable_data(ctx.GetPlace()); + const auto& runner_dx = + NpuOpRunner("MatMul", {*dout, *y}, {*dx}, + {{"transpose_x1", false}, {"transpose_x2", false}}); + + runner_dx.Run(stream); + } + if (dy) { + dy->mutable_data(ctx.GetPlace()); + const auto& runner_dy = + NpuOpRunner("MatMul", {*dout, *x}, {*dy}, + {{"transpose_x1", true}, {"transpose_x2", false}}); + + runner_dy.Run(stream); + } + + } else { + if (dx) { + dx->mutable_data(ctx.GetPlace()); + const auto& runner_dx = + NpuOpRunner("MatMul", {*dout, *y}, {*dx}, + {{"transpose_x1", false}, {"transpose_x2", true}}); + + runner_dx.Run(stream); + } + if (dy) { + dy->mutable_data(ctx.GetPlace()); + const auto& runner_dy = + NpuOpRunner("MatMul", {*x, *dout}, {*dy}, + {{"transpose_x1", true}, {"transpose_x2", false}}); + + runner_dy.Run(stream); + } + } + } else if (x->dims().size() > 2) { + if (transpose_y) { + if (dx) { + dx->mutable_data(ctx.GetPlace()); + const auto& runner_dx = + NpuOpRunner("BatchMatMul", {*dout, *y}, {*dx}, + {{"adj_x1", false}, {"adj_x2", false}}); + + runner_dx.Run(stream); + } + if (dy) { + dy->mutable_data(ctx.GetPlace()); + const auto& runner_dy = + NpuOpRunner("BatchMatMul", {*dout, *x}, {*dy}, + {{"adj_x1", true}, {"adj_x2", false}}); + + runner_dy.Run(stream); + } + } else { + if (dx) { + dx->mutable_data(ctx.GetPlace()); + const auto& runner_dx = + NpuOpRunner("BatchMatMul", {*dout, *y}, {*dx}, + {{"adj_x1", false}, {"adj_x2", true}}); + + runner_dx.Run(stream); + } + if (dy) { + dy->mutable_data(ctx.GetPlace()); + if ((x->dims().size() == 3) && (dout->dims().size() == 3) && + (dy->dims().size() == 2)) { + framework::Tensor dout_; + dout_.ShareDataWith(*dout); + std::vector vec_dim = framework::vectorize(dout_.dims()); + std::vector vec_dim_v{vec_dim[0] * vec_dim[1], vec_dim[2]}; + dout_.Resize(framework::make_ddim(vec_dim_v)); + + framework::Tensor x_; + x_.ShareDataWith(*x); + std::vector vec_dim_x = framework::vectorize(x_.dims()); + std::vector vec_dim_x_v{vec_dim_x[0] * vec_dim_x[1], + vec_dim_x[2]}; + x_.Resize(framework::make_ddim(vec_dim_x_v)); + const auto& runner_dy = + NpuOpRunner("MatMul", {x_, dout_}, {*dy}, + {{"transpose_x1", true}, {"transpose_x2", false}}); + runner_dy.Run(stream); + } else { + const auto& runner_dy = + NpuOpRunner("BatchMatMul", {*x, *dout}, {*dy}, + {{"adj_x1", true}, {"adj_x2", false}}); + runner_dy.Run(stream); + } + } + } + } + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_NPU_KERNEL( + matmul, ops::MatMulNPUKernel, + ops::MatMulNPUKernel); +REGISTER_OP_NPU_KERNEL( + matmul_grad, + ops::MatMulGradNPUKernel, + ops::MatMulGradNPUKernel); diff --git a/paddle/fluid/operators/memcpy_op.h b/paddle/fluid/operators/memcpy_op.h index 63a41cc723731..ecd266858024e 100644 --- a/paddle/fluid/operators/memcpy_op.h +++ b/paddle/fluid/operators/memcpy_op.h @@ -51,17 +51,14 @@ class MemcpyFunctor { } else if (dst_place_type_ == 1) { framework::TensorCopy(lod_tensor, dev_ctx_.GetPlace(), dev_ctx_, &out_tensor); - } + } else if (dst_place_type_ == 0) { + framework::TensorCopySync(lod_tensor, platform::CPUPlace(), &out_tensor); #ifdef PADDLE_WITH_ASCEND_CL - else if (dst_place_type_ == 0) { // NOLINT - framework::TensorCopy(lod_tensor, platform::CPUPlace(), dev_ctx_, - &out_tensor); } else if (dst_place_type_ == 4) { framework::TensorCopy(lod_tensor, dev_ctx_.GetPlace(), dev_ctx_, &out_tensor); - } #endif - else { // NOLINT + } else { PADDLE_THROW(platform::errors::Unimplemented( "memcpy dst_place_type: %d is not supported yet.", dst_place_type_)); } diff --git a/paddle/fluid/operators/optimizers/adam_op.cc b/paddle/fluid/operators/optimizers/adam_op.cc index edc75bda4abdf..130e10a1f8de3 100644 --- a/paddle/fluid/operators/optimizers/adam_op.cc +++ b/paddle/fluid/operators/optimizers/adam_op.cc @@ -122,7 +122,8 @@ framework::OpKernelType AdamOp::GetExpectedKernelType( framework::OpKernelType AdamOp::GetKernelTypeForVar( const std::string &var_name, const framework::Tensor &tensor, const framework::OpKernelType &expected_kernel_type) const { - if (var_name == "Beta1Pow" || var_name == "Beta2Pow") { + if (var_name == "Beta1Pow" || var_name == "Beta2Pow" || + var_name == "SkipUpdate") { return expected_kernel_type; } else { return framework::OpKernelType(expected_kernel_type.data_type_, diff --git a/paddle/fluid/operators/optimizers/adam_op_npu.cc b/paddle/fluid/operators/optimizers/adam_op_npu.cc index 8b33dc64c4e4f..d0de480c1a0cc 100644 --- a/paddle/fluid/operators/optimizers/adam_op_npu.cc +++ b/paddle/fluid/operators/optimizers/adam_op_npu.cc @@ -141,7 +141,7 @@ class AdamNPUKernel : public framework::OpKernel { if (ctx.HasInput("Beta2Tensor")) { beta2_tensor = ctx.Input("Beta2Tensor"); - PADDLE_ENFORCE_EQ(beta1_tensor->numel(), 1, + PADDLE_ENFORCE_EQ(beta2_tensor->numel(), 1, platform::errors::InvalidArgument( "Input(Beta2Tensor) size must be 1, but get %d", beta2_tensor->numel())); diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator.py b/python/paddle/fluid/contrib/mixed_precision/decorator.py index 7a646e069db35..22eb2d20f3db7 100644 --- a/python/paddle/fluid/contrib/mixed_precision/decorator.py +++ b/python/paddle/fluid/contrib/mixed_precision/decorator.py @@ -400,6 +400,10 @@ def apply_gradients(self, params_grads): name="update_loss_scaling") # Pass found_inf to adam, to skip update for not only param, but also momentum and beta_pow if isinstance(self._optimizer, paddle.fluid.optimizer.Adam): + # NOTE(zhiqiu): Since found_inf needs to be on cpu in adam op, we + # copy it in advance to avoid multiple time copies. + found_inf = paddle.tensor.creation._memcpy(found_inf, + paddle.CPUPlace()) self._optimizer._set_auxiliary_var('found_inf', found_inf) optimize_ops = self._optimizer.apply_gradients(params_grads) return optimize_ops diff --git a/python/paddle/fluid/tests/unittests/npu/test_mixed_precision_npu.py b/python/paddle/fluid/tests/unittests/npu/test_mixed_precision_npu.py new file mode 100644 index 0000000000000..193b9eb4e0aca --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_mixed_precision_npu.py @@ -0,0 +1,30 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import sys +import paddle +sys.path.append("..") +import test_mixed_precision + +paddle.enable_static() + + +class AMPTestNpu(test_mixed_precision.AMPTest): + def setUp(self): + self.place = paddle.NPUPlace(0) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_memcpy_op.py b/python/paddle/fluid/tests/unittests/test_memcpy_op.py index 38e9379bc1667..3fecef9397c63 100755 --- a/python/paddle/fluid/tests/unittests/test_memcpy_op.py +++ b/python/paddle/fluid/tests/unittests/test_memcpy_op.py @@ -144,32 +144,6 @@ def test_SELECTED_ROWS(self): feed={}, fetch_list=[selected_row_var.name, pinned_var.name]) - def test_OTHER_PLACE_NotImplementedError(self): - main_program, pinned_var = self.get_prog() - lod_tensor_var = main_program.global_block().create_var( \ - name="lod_tensor_0", dtype="float32", persistable=False, stop_gradient=True) - main_program.global_block().append_op( - type="fill_constant", - outputs={"Out": lod_tensor_var}, - attrs={ - "shape": lod_tensor_var.shape, - "dtype": lod_tensor_var.dtype, - "value": 1.0, - "place_type": 0 - }) - main_program.global_block().append_op( - type='memcpy', - inputs={'X': pinned_var}, - outputs={'Out': lod_tensor_var}, - attrs={'dst_place_type': 0, }) - with self.assertRaises(NotImplementedError): - place = fluid.CUDAPlace(0) - exe = fluid.Executor(place) - lod_tensor_var_, pinned_ = exe.run( - main_program, - feed={}, - fetch_list=[lod_tensor_var.name, pinned_var.name]) - class TestMemcpyApi(unittest.TestCase): def test_api(self): diff --git a/python/paddle/fluid/tests/unittests/test_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_mixed_precision.py index 89d40e9314e50..57ea7ad1aa250 100644 --- a/python/paddle/fluid/tests/unittests/test_mixed_precision.py +++ b/python/paddle/fluid/tests/unittests/test_mixed_precision.py @@ -47,6 +47,9 @@ def forward(self, x): class AMPTest(unittest.TestCase): + def setUp(self): + self.place = paddle.CUDAPlace(0) + def net(self): input_size = 4096 output_size = 4096 @@ -82,7 +85,8 @@ def test_skip_update(self): fetch_list = [ loss, weight, moment1, beta_pow1, 'find_infinite_scale.tmp_0' ] - exe = paddle.static.Executor(paddle.CUDAPlace(0)) + + exe = paddle.static.Executor(self.place) train_data = [ np.random.rand(batch_size, input_size).astype(np.float32)