Commit

[NPU] fix elementwise_mul to support broadcast, test=develop (PaddlePaddle#36258)

* [NPU] fix elementwise_mul to support broadcast, test=develop

* remove debug files, test=develop

* add axis support, test=develop
qili93 committed Oct 12, 2021
1 parent b3f6eed commit 09778f4
Showing 2 changed files with 258 additions and 148 deletions.
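
In short, the commit teaches the NPU elementwise_mul kernels to handle shapes that only match after broadcasting: the forward kernel gains a broadcast fallback (NpuElementWiseOpBroadcast), and the backward kernel reduces the broadcast-shaped gradient back to each input's shape with ReduceSumD. The axes to reduce are picked by the new ReduceDims helper in the diff below; here is a minimal standalone sketch of that selection rule (plain C++ with a hypothetical AxesToReduce name, not the Paddle code itself):

#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the axis selection in ReduceDims (see the diff below): an axis of
// the broadcast shape must be summed away if it lies left of `axis`, right of
// the original rank, or grew during broadcasting.
std::vector<int64_t> AxesToReduce(const std::vector<int64_t>& brd_dims,
                                  const std::vector<int64_t>& org_dims,
                                  int64_t axis) {
  std::vector<int64_t> axes;
  const int64_t brd_size = brd_dims.size();
  const int64_t org_size = org_dims.size();
  for (int64_t i = 0; i < brd_size; ++i) {
    if (i < axis || i >= org_size + axis) {
      axes.push_back(i);  // axis created entirely by broadcasting
    } else if (brd_dims[i] > org_dims[i - axis]) {
      axes.push_back(i);  // original dim of 1 stretched to brd_dims[i]
    }
  }
  return axes;
}

int main() {
  // y with shape {3} broadcast against x with shape {2, 3, 4} at axis 1:
  // axes 0 and 2 must be reduced to bring dy back to y's shape.
  for (int64_t a : AxesToReduce({2, 3, 4}, {3}, 1)) std::cout << a << " ";
  std::cout << "\n";  // prints: 0 2
}
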
132 changes: 93 additions & 39 deletions paddle/fluid/operators/elementwise/elementwise_mul_op_npu.cc
@@ -12,67 +12,127 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#ifdef PADDLE_WITH_ASCEND_CL
 #include <memory>
 #include <string>
 
 #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
+#include "paddle/fluid/operators/elementwise/elementwise_npu.h"
 #include "paddle/fluid/operators/npu_op_runner.h"
 
 namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
 
-template <typename DeviceContext, typename T>
+using NPUDeviceContext = platform::NPUDeviceContext;
+
+template <typename T>
+static void ReduceDims(const framework::ExecutionContext& ctx,
+                       const aclrtStream& stream, const int axis,
+                       const framework::DDim& ddims,
+                       const framework::DDim& brd_ddims, const Tensor& in,
+                       Tensor* out) {
+  std::vector<int64_t> axes;
+  int64_t brd_size = brd_ddims.size();
+  int64_t org_size = ddims.size();
+  // int64_t diff = brd_dims.size() - dims.size();
+  for (int64_t i = 0; i < brd_size; ++i) {
+    if (i < axis || i >= org_size + axis) {
+      axes.push_back(i);
+      continue;
+    }
+    if (brd_ddims[i] > ddims[i - axis]) {
+      axes.push_back(i);
+    }
+  }
+  // LOG(INFO) << "axes = " << framework::make_ddim(axes).to_str();
+  out->mutable_data<T>(ctx.GetPlace());
+  const auto& runner = NpuOpRunner("ReduceSumD", {in}, {*out},
+                                   {{"axes", axes}, {"keep_dims", false}});
+  runner.Run(stream);
+}
+
+template <typename T>
 class ElementwiseMulNPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    auto& dev_ctx = ctx.template device_context<NPUDeviceContext>();
     auto* x = ctx.Input<Tensor>("X");
     auto* y = ctx.Input<Tensor>("Y");
+
     auto* out = ctx.Output<Tensor>("Out");
+    out->mutable_data<T>(ctx.GetPlace());
+
+    int axis = ctx.Attr<int>("axis");
+
+    bool direct_compute = false;
+    auto x_dims = x->dims();
+    auto y_dims = y->dims();
+    axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
+    if (x_dims.size() >= y_dims.size()) {
+      direct_compute = x_dims.size() == (y_dims.size() + axis);
+    } else {
+      direct_compute = y_dims.size() == (x_dims.size() + axis);
+    }
 
-    auto place = ctx.GetPlace();
-
-    out->mutable_data<T>(place);
-
-    auto stream =
-        ctx.template device_context<paddle::platform::NPUDeviceContext>()
-            .stream();
+    auto stream = ctx.template device_context<NPUDeviceContext>().stream();
 
-    const auto& runner = NpuOpRunner("Mul", {*x, *y}, {*out}, {});
-    runner.Run(stream);
+    if (direct_compute) {
+      const auto& runner = NpuOpRunner("Mul", {*x, *y}, {*out}, {});
+      runner.Run(stream);
+    } else {
+      Tensor trans_x, trans_y;
+      NpuElementWiseOpBroadcast<T>(dev_ctx, x, y, axis, &trans_x, &trans_y);
+      const auto& runner = NpuOpRunner("Mul", {trans_x, trans_y}, {*out}, {});
+      runner.Run(stream);
+    }
   }
 };
 
-template <typename DeviceContext, typename T>
+template <typename T>
 class ElementwiseMulGradNPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    auto& dev_ctx = ctx.template device_context<NPUDeviceContext>();
     auto* x = ctx.Input<Tensor>("X");
     auto* y = ctx.Input<Tensor>("Y");
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
 
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    int axis = ctx.Attr<int>("axis");
 
-    auto place = ctx.GetPlace();
+    axis = (axis == -1 ? std::abs(x->dims().size() - y->dims().size()) : axis);
+    auto stream = ctx.template device_context<NPUDeviceContext>().stream();
 
-    auto stream =
-        ctx.template device_context<paddle::platform::NPUDeviceContext>()
-            .stream();
+    Tensor trans_x, trans_y;
+    NpuElementWiseOpBroadcast<T>(dev_ctx, x, y, axis, &trans_x, &trans_y);
 
     if (dx) {
-      dx->mutable_data<T>(place);
-      const auto& runner_dx = NpuOpRunner("Mul", {*dout, *y}, {*dx}, {});
-      runner_dx.Run(stream);
+      if (dx->dims() == dout->dims()) {
+        dx->mutable_data<T>(ctx.GetPlace());
+        const auto& runner_dx = NpuOpRunner("Mul", {*dout, trans_y}, {*dx}, {});
+        runner_dx.Run(stream);
+      } else {
+        Tensor dx_temp(x->type());
+        dx_temp.Resize(trans_x.dims());
+        dx_temp.mutable_data<T>(ctx.GetPlace());
+        const auto& runner_dx =
+            NpuOpRunner("Mul", {*dout, trans_y}, {dx_temp}, {});
+        runner_dx.Run(stream);
+        ReduceDims<T>(ctx, stream, axis, dx->dims(), trans_x.dims(), dx_temp,
+                      dx);
+      }
     }
 
     if (dy) {
-      dy->mutable_data<T>(place);
-      const auto& runner_dy = NpuOpRunner("Mul", {*x, *dout}, {*dy}, {});
-      runner_dy.Run(stream);
+      if (dy->dims() == dout->dims()) {
+        dy->mutable_data<T>(ctx.GetPlace());
+        const auto& runner_dy = NpuOpRunner("Mul", {trans_x, *dout}, {*dy}, {});
+        runner_dy.Run(stream);
+      } else {
+        Tensor dy_temp(y->type());
+        dy_temp.Resize(trans_y.dims());
+        dy_temp.mutable_data<T>(ctx.GetPlace());
+        const auto& runner_dy =
+            NpuOpRunner("Mul", {trans_x, *dout}, {dy_temp}, {});
+        runner_dy.Run(stream);
+        ReduceDims<T>(ctx, stream, axis, dy->dims(), trans_y.dims(), dy_temp,
+                      dy);
+      }
     }
   }
 };
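
A note on the forward kernel's fast path above: direct_compute holds exactly when the lower-rank operand already lines up with the trailing dimensions of the higher-rank one, so "Mul" can run without materializing broadcast copies. A standalone illustration of that check under the same axis convention (plain C++; DimsMatchDirectly is a hypothetical stand-in for the inline logic, not a Paddle symbol):

#include <cstdlib>
#include <iostream>
#include <vector>

// Stand-in for the kernel's direct_compute test: resolve axis == -1 to the
// rank difference, then require rank(larger) == rank(smaller) + axis, i.e.
// the smaller shape aligns with the larger shape's trailing dimensions.
bool DimsMatchDirectly(const std::vector<int>& x_dims,
                       const std::vector<int>& y_dims, int axis) {
  const int xr = static_cast<int>(x_dims.size());
  const int yr = static_cast<int>(y_dims.size());
  if (axis == -1) axis = std::abs(xr - yr);
  return (xr >= yr) ? (xr == yr + axis) : (yr == xr + axis);
}

int main() {
  // {2, 3, 4} * {4} with default axis: resolved axis 2, direct "Mul" is fine.
  std::cout << DimsMatchDirectly({2, 3, 4}, {4}, -1) << "\n";  // 1
  // {2, 3, 4} * {3} with axis 1: 3 != 1 + 1, so the broadcast path is taken.
  std::cout << DimsMatchDirectly({2, 3, 4}, {3}, 1) << "\n";   // 0
}
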
@@ -82,15 +142,9 @@ class ElementwiseMulGradNPUKernel : public framework::OpKernel<T> {
 
 namespace ops = paddle::operators;
 
-REGISTER_OP_NPU_KERNEL(
-    elementwise_mul,
-    ops::ElementwiseMulNPUKernel<paddle::platform::NPUDeviceContext, float>,
-    ops::ElementwiseMulNPUKernel<paddle::platform::NPUDeviceContext,
-                                 paddle::platform::float16>);
+REGISTER_OP_NPU_KERNEL(elementwise_mul, ops::ElementwiseMulNPUKernel<float>,
+                       ops::ElementwiseMulNPUKernel<paddle::platform::float16>);
 
 REGISTER_OP_NPU_KERNEL(
-    elementwise_mul_grad,
-    ops::ElementwiseMulGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
-    ops::ElementwiseMulGradNPUKernel<paddle::platform::NPUDeviceContext,
-                                     paddle::platform::float16>);
-#endif
+    elementwise_mul_grad, ops::ElementwiseMulGradNPUKernel<float>,
+    ops::ElementwiseMulGradNPUKernel<paddle::platform::float16>);
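
Why Mul followed by ReduceSumD is the right gradient: for out = x · broadcast(y), each element of y feeds every position it was broadcast to, so dy is dout · x summed over the broadcast axes (and symmetrically for dx). A tiny numeric check under assumed shapes x: {2, 2}, y: {2} (plain C++ loops standing in for the two NPU ops):

#include <iostream>

int main() {
  // out[i][j] = x[i][j] * y[j]: y broadcasts along axis 0
  // (axis == -1 resolves to rank(x) - rank(y) == 1).
  double x[2][2] = {{1, 2}, {3, 4}};
  double dout[2][2] = {{1, 1}, {1, 1}};  // upstream gradient of ones

  // dy_temp = dout * x in the broadcast shape, then ReduceSumD over the
  // broadcast axis {0} -- the same two steps the grad kernel issues.
  double dy[2] = {0, 0};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j) dy[j] += dout[i][j] * x[i][j];

  std::cout << dy[0] << " " << dy[1] << "\n";  // 4 6 (column sums of x)
}
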