【pir】add tensor to array kernel etc (#60703)
* merge

* modify kernel

* modify net

* modify print
xiaoguoguo626807 committed Jan 11, 2024
1 parent f8eff51 commit c173503
Showing 11 changed files with 337 additions and 14 deletions.
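
In short, this commit registers a new pd_op.tensor_to_array op and wires it in as the VJP (gradient) of pd_op.array_to_tensor. As a hedged sketch of the intended pairing, using only the APIs added or touched below (the builder setup and the values x and out_grad are assumed to already exist):

// Sketch: forward op and its new gradient op, per the diff below.
// x is a pir::Value of DenseTensorArrayType; out_grad is the cotangent
// of the stacked/concatenated tensor produced by array_to_tensor.
auto [out, out_index] =
    paddle::dialect::array_to_tensor(x, /*axis=*/0, /*use_stack=*/true);
pir::OpResult x_grad = paddle::dialect::tensor_to_array(
    x, out_grad, /*axis=*/0, /*use_stack=*/true);  // a DenseTensorArray again
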
11 changes: 11 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.cc
@@ -173,6 +173,17 @@ std::tuple<pir::OpResult, pir::OpResult> array_to_tensor(pir::Value x,
  return std::make_tuple(array_to_tensor.result(0), array_to_tensor.result(1));
}

pir::OpResult tensor_to_array(pir::Value x,
                              pir::Value out_grad,
                              int axis,
                              bool use_stack) {
  auto tensor_to_array = ApiBuilder::Instance()
                             .GetBuilder()
                             ->Build<paddle::dialect::TensorToArrayOp>(
                                 x, out_grad, axis, use_stack);
  return tensor_to_array.result(0);
}

pir::OpResult add_n_array(const std::vector<pir::Value>& inputs) {
  auto inputs_combine_op =
      ApiBuilder::Instance().GetBuilder()->Build<pir::CombineOp>(inputs);
5 changes: 5 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.h
@@ -74,6 +74,11 @@ std::tuple<pir::OpResult, pir::OpResult> array_to_tensor(pir::Value x,
                                                         int axis,
                                                         bool use_stack);

pir::OpResult tensor_to_array(pir::Value x,
                              pir::Value out_grad,
                              int axis,
                              bool use_stack);

pir::OpResult add_n_array(const std::vector<pir::Value>& inputs);

pir::OpResult slice_array_dense(pir::Value input, pir::Value starts);
159 changes: 156 additions & 3 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.cc
@@ -22,9 +22,9 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op,
paddle::dialect::ArrayReadOp, paddle::dialect::ArrayWrite_Op,
paddle::dialect::SliceArrayOp, paddle::dialect::SliceArrayDenseOp,
paddle::dialect::AssignArray_Op, paddle::dialect::ArrayToTensorOp,
-    paddle::dialect::SelectInputOp, paddle::dialect::IncrementOp,
-    paddle::dialect::Increment_Op, paddle::dialect::ShapeBroadcastOp,
-    paddle::dialect::MemcpyD2hMultiIoOp
+    paddle::dialect::TensorToArrayOp, paddle::dialect::SelectInputOp,
+    paddle::dialect::IncrementOp, paddle::dialect::Increment_Op,
+    paddle::dialect::ShapeBroadcastOp, paddle::dialect::MemcpyD2hMultiIoOp
#else

#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
@@ -1085,6 +1085,7 @@ void SplitGradOp::Build(pir::Builder &builder,
                                 dense_x_grad.offset());
  argument_outputs.push_back(x_grad_dense_tensor_type);
  argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
  ::pir::PassStopGradientsDefaultly(argument);
}

void SplitGradOp::Build(pir::Builder &builder,
@@ -1142,6 +1143,7 @@ void SplitGradOp::Build(pir::Builder &builder,
                                 dense_x_grad.offset());
  argument_outputs.push_back(x_grad_dense_tensor_type);
  argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
  ::pir::PassStopGradientsDefaultly(argument);
}

void SplitGradOp::VerifySig() {
@@ -1250,6 +1252,7 @@ void CreateArrayOp::Build(pir::Builder &builder,
                                              dense_out.layout());
  argument_outputs.push_back(out_dense_tensor_type);
  argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
  ::pir::PassStopGradientsDefaultly(argument);
}

void CreateArrayOp::VerifySig() {
@@ -1885,6 +1888,7 @@ void ArrayToTensorOp::Build(pir::Builder &builder,  // NOLINT
                                        dense_out_index.offset());
  argument_outputs.push_back(out_index_dense_tensor_type);
  argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
  ::pir::PassStopGradientsDefaultly(argument);
}

void ArrayToTensorOp::VerifySig() {
@@ -1940,6 +1944,153 @@ void ArrayToTensorOp::InferMeta(phi::InferMetaContext *infer_meta) {
  fn(infer_meta);
}

const char *TensorToArrayOp::attributes_name[2] = {"axis", "use_stack"};

OpInfoTuple TensorToArrayOp::GetOpInfo() {
  std::vector<paddle::dialect::OpInputInfo> inputs = {
      paddle::dialect::OpInputInfo("x",
                                   "paddle::dialect::DenseTensorArrayType",
                                   false,
                                   false,
                                   false,
                                   true),
      paddle::dialect::OpInputInfo("out_grad",
                                   "paddle::dialect::DenseTensorType",
                                   false,
                                   false,
                                   false,
                                   true)};

  std::vector<paddle::dialect::OpAttributeInfo> attributes = {
      paddle::dialect::OpAttributeInfo("axis", "pir::Int32Attribute", ""),
      paddle::dialect::OpAttributeInfo("use_stack", "pir::BoolAttribute", "")};

  std::vector<paddle::dialect::OpOutputInfo> outputs = {
      paddle::dialect::OpOutputInfo(
          "out", "paddle::dialect::DenseTensorArrayType", false, false)};

  paddle::dialect::OpRunTimeInfo run_time_info =
      paddle::dialect::OpRunTimeInfo("TensorToArrayInferMeta",
                                     {"x", "axis", "use_stack"},
                                     "tensor_to_array",
                                     {"x", "axis", "use_stack"},
                                     {"x"},
                                     {},
                                     {},
                                     {});
  return std::make_tuple(
      inputs, attributes, outputs, run_time_info, "tensor_to_array");
}

void TensorToArrayOp::Build(pir::Builder &builder,             // NOLINT
                            pir::OperationArgument &argument,  // NOLINT
                            pir::Value x_,
                            pir::Value out_grad_,
                            int axis,
                            bool use_stack) {
  VLOG(4) << "Start build TensorToArrayOp";
  VLOG(4) << "Builder construction inputs";
  argument.AddInputs({x_, out_grad_});

  VLOG(4) << "Builder construction attributes";
  pir::Attribute attr_axis =
      pir::Int32Attribute::get(pir::IrContext::Instance(), axis);
  argument.AddAttribute("axis", attr_axis);
  pir::Attribute attr_use_stack =
      pir::BoolAttribute::get(pir::IrContext::Instance(), use_stack);
  argument.AddAttribute("use_stack", attr_use_stack);

  VLOG(4) << "Builder construction outputs";
  paddle::dialect::DenseTensorArrayType x =
      x_.type().dyn_cast<paddle::dialect::DenseTensorArrayType>();
  paddle::dialect::IrTensor dense_x(
      paddle::dialect::TransToPhiDataType(x.dtype()), {}, x.data_layout(), {});

  paddle::dialect::DenseTensorType out_grad =
      out_grad_.type().dyn_cast<paddle::dialect::DenseTensorType>();
  paddle::dialect::IrTensor dense_out_grad(
      paddle::dialect::TransToPhiDataType(out_grad.dtype()),
      out_grad.dims(),
      out_grad.data_layout(),
      out_grad.lod(),
      out_grad.offset());

  VLOG(4) << "Builder construction meta_x, meta_out_grad";
  paddle::dialect::IrMetaTensor meta_out_grad(&dense_out_grad);
  paddle::dialect::IrMetaTensor meta_x(&dense_x);

  paddle::dialect::IrTensor dense_x_grad;
  paddle::dialect::IrMetaTensor meta_x_grad(&dense_x_grad);

  phi::TensorToArrayInferMeta(
      meta_x, meta_out_grad, axis, use_stack, &meta_x_grad);

  std::vector<pir::Type> argument_outputs;
  pir::Type out_dense_tensor_array_type =
      paddle::dialect::DenseTensorArrayType::get(
          pir::IrContext::Instance(),
          paddle::dialect::TransToIrDataType(dense_x_grad.dtype()),
          dense_x_grad.layout());
  argument_outputs.push_back(out_dense_tensor_array_type);
  argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
  ::pir::PassStopGradientsDefaultly(argument);
}

void TensorToArrayOp::VerifySig() {
  VLOG(4) << "Start Verifying inputs, outputs and attributes for: "
             "TensorToArrayOp.";
  VLOG(4) << "Verifying inputs:";
  {
    auto input_size = num_operands();
    PADDLE_ENFORCE_EQ(
        input_size,
        2u,
        phi::errors::PreconditionNotMet(
            "The size %d of inputs must be equal to 2.", input_size));

    PADDLE_ENFORCE((*this)
                       ->operand_source(0)
                       .type()
                       .isa<paddle::dialect::DenseTensorArrayType>(),
                   phi::errors::PreconditionNotMet(
                       "Type validation failed for the 0th input."));
    PADDLE_ENFORCE((*this)
                       ->operand_source(1)
                       .type()
                       .isa<paddle::dialect::DenseTensorType>(),
                   phi::errors::PreconditionNotMet(
                       "Type validation failed for the 1st input."));
  }

  VLOG(4) << "Verifying attributes:";
  {
    auto &attributes = this->attributes();
    PADDLE_ENFORCE(attributes.count("axis") > 0, "axis does not exist.");
    PADDLE_ENFORCE(attributes.count("use_stack") > 0,
                   "use_stack does not exist.");
  }

  VLOG(4) << "Verifying outputs:";
  {
    auto output_size = num_results();
    PADDLE_ENFORCE_EQ(
        output_size,
        1u,
        phi::errors::PreconditionNotMet(
            "The size %d of outputs must be equal to 1.", output_size));
    PADDLE_ENFORCE(
        (*this)->result(0).type().isa<paddle::dialect::DenseTensorArrayType>(),
        phi::errors::PreconditionNotMet(
            "Type validation failed for the 0th output."));
  }
  VLOG(4) << "End Verifying for: TensorToArrayOp.";
}

void TensorToArrayOp::InferMeta(phi::InferMetaContext *infer_meta) {
  auto fn = PD_INFER_META(phi::TensorToArrayInferMeta);
  fn(infer_meta);
}

const char *SliceArrayOp::attributes_name[2] = {"starts", "ends"};

OpInfoTuple SliceArrayOp::GetOpInfo() {
@@ -2161,6 +2312,7 @@ void SliceArrayDenseOp::Build(pir::Builder &builder,  // NOLINT
                                 dense_out.offset());
  argument_outputs.push_back(out_dense_tensor_type);
  argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
  ::pir::PassStopGradientsDefaultly(argument);
}

void SliceArrayDenseOp::InferMeta(phi::InferMetaContext *infer_meta) {
@@ -3136,6 +3288,7 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SliceArrayOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SliceArrayDenseOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AssignArray_Op)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayToTensorOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::TensorToArrayOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IncrementOp)
36 changes: 33 additions & 3 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.h
@@ -316,8 +316,10 @@ class ArrayWrite_Op : public pir::Op<ArrayWrite_Op,
      const std::vector<std::vector<bool>> &stop_gradients);
};

-class ArrayToTensorOp
-    : public pir::Op<ArrayToTensorOp, OpYamlInfoInterface, InferMetaInterface> {
+class ArrayToTensorOp : public pir::Op<ArrayToTensorOp,
+                                       OpYamlInfoInterface,
+                                       paddle::dialect::VjpInterface,
+                                       InferMetaInterface> {
 public:
  using Op::Op;
  static const char *name() { return "pd_op.array_to_tensor"; }
@@ -332,7 +334,34 @@ class ArrayToTensorOp
  void VerifySig();
  pir::Value x() { return operand_source(0); }
  pir::OpResult out() { return result(0); }
-  pir::OpResult out_index() { return result(2); }
+  pir::OpResult out_index() { return result(1); }
  static void InferMeta(phi::InferMetaContext *infer_meta);
  static std::vector<std::vector<pir::OpResult>> Vjp(
      pir::Operation *op,
      const std::vector<std::vector<pir::Value>> &inputs_,
      const std::vector<std::vector<pir::Value>> &outputs,
      const std::vector<std::vector<pir::Value>> &out_grads,
      const std::vector<std::vector<bool>> &stop_gradients);
};

class TensorToArrayOp
    : public pir::Op<TensorToArrayOp, OpYamlInfoInterface, InferMetaInterface> {
 public:
  using Op::Op;
  static const char *name() { return "pd_op.tensor_to_array"; }
  static constexpr uint32_t attributes_num = 2;
  static const char *attributes_name[attributes_num];
  static OpInfoTuple GetOpInfo();
  static void Build(pir::Builder &builder,             // NOLINT
                    pir::OperationArgument &argument,  // NOLINT
                    pir::Value x,
                    pir::Value out_grad,
                    int axis,
                    bool use_stack);
  void VerifySig();
  pir::Value x() { return operand_source(0); }
  pir::Value out_grad() { return operand_source(1); }
  pir::OpResult x_grad() { return result(0); }
  static void InferMeta(phi::InferMetaContext *infer_meta);
};

@@ -630,6 +659,7 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SliceArrayOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SliceArrayDenseOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AssignArray_Op)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayToTensorOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::TensorToArrayOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IncrementOp)
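
For orientation, a minimal sketch of constructing the new op directly through a pir::Builder and reading it back via the accessors declared above; the builder, context, and input values are assumptions, not part of the diff:

// Sketch only. Build<TensorToArrayOp> forwards to TensorToArrayOp::Build.
auto op = builder.Build<paddle::dialect::TensorToArrayOp>(
    x, out_grad, /*axis=*/0, /*use_stack=*/true);
pir::Value array_in = op.x();         // operand 0: DenseTensorArrayType
pir::Value grad_in = op.out_grad();   // operand 1: DenseTensorType
pir::OpResult x_grad = op.x_grad();   // result 0: DenseTensorArrayType
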
47 changes: 46 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc
@@ -256,7 +256,7 @@ std::vector<std::vector<pir::OpResult>> ArrayReadOp::Vjp(
      2,
      platform::errors::InvalidArgument(
          "Array_read op's outputs size should be 1, but now is %d.",
-          outputs.size()));
+          out_grads.size()));

VLOG(6) << "Vjp prepare call Array_read's vjp inteface";

@@ -269,5 +269,50 @@
  std::vector<std::vector<pir::OpResult>> res;
  return res;
}

std::vector<std::vector<pir::OpResult>> ArrayToTensorOp::Vjp(
    pir::Operation* op,
    const std::vector<std::vector<pir::Value>>& inputs_,
    const std::vector<std::vector<pir::Value>>& outputs,
    const std::vector<std::vector<pir::Value>>& out_grads,
    const std::vector<std::vector<bool>>& stop_gradients) {
  PADDLE_ENFORCE_EQ(
      inputs_.size(),
      1,
      platform::errors::InvalidArgument(
          "Array_to_tensor op's inputs size should be 1, but now is %d.",
          inputs_.size()));
  PADDLE_ENFORCE_EQ(
      outputs.size(),
      2,
      platform::errors::InvalidArgument(
          "Array_to_tensor op's outputs size should be 2, but now is %d.",
          outputs.size()));

  PADDLE_ENFORCE_EQ(
      out_grads.size(),
      2,
      platform::errors::InvalidArgument(
          "Array_to_tensor op's out_grads size should be 2, but now is %d.",
          out_grads.size()));

  VLOG(6) << "Vjp prepare attributes of array_to_tensor_grad";
  int axis = op->attribute("axis").dyn_cast<pir::Int32Attribute>().data();
  bool use_stack =
      op->attribute("use_stack").dyn_cast<pir::BoolAttribute>().data();

  VLOG(6) << "Vjp prepare call ArrayToTensor's vjp interface";

  pir::OpResult tensor_res = paddle::dialect::tensor_to_array(
      inputs_[0][0], out_grads[0][0], axis, use_stack);

  std::vector<std::vector<pir::OpResult>> res(1);
  res[0].resize(1);
  if (!stop_gradients[0][0]) {
    res[0][0] = tensor_res;
  }
  return res;
}

} // namespace dialect
} // namespace paddle
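
A hedged sketch of how the autodiff engine is expected to invoke this VJP; the engine-side plumbing (fwd_op and the nested vectors) is assumed here, not part of this diff:

// inputs_:   {{x}}                        -- forward inputs of array_to_tensor
// outputs:   {{out}, {out_index}}         -- forward results
// out_grads: {{out_grad}, {out_index_grad}}
std::vector<std::vector<bool>> stop_gradients = {{false}};
auto grads = paddle::dialect::ArrayToTensorOp::Vjp(
    fwd_op, inputs, outputs, out_grads, stop_gradients);
// grads[0][0] holds the DenseTensorArray gradient for x, or stays null
// when stop_gradients[0][0] is true.
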
9 changes: 9 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -175,6 +175,15 @@ void ArrayToTensorInferMeta(const MetaTensor& x,
  out_index->set_dims(common::make_ddim({-1}));
}

void TensorToArrayInferMeta(const MetaTensor& x,
                            const MetaTensor& out_grad,
                            int axis,
                            bool use_stack,
                            MetaTensor* x_grad) {
  x_grad->set_dtype(x.dtype());
  x_grad->set_layout(x.layout());
}

void ArgMinMaxInferMeta(const MetaTensor& x,
                        const Scalar& axis,
                        bool keepdims,
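
One design note, read with some hedging: TensorToArrayInferMeta propagates only dtype and layout to x_grad, presumably because the per-element dims of a tensor array are not statically known and are left to the kernel at run time. A caller-side view of that contract (the DenseTensor variables are assumptions):

// Sketch: MetaTensor wrappers over phi::DenseTensor objects x, out_grad,
// and x_grad that are assumed to exist in the caller.
phi::MetaTensor meta_x(&x), meta_out_grad(&out_grad), meta_x_grad(&x_grad);
phi::TensorToArrayInferMeta(
    meta_x, meta_out_grad, /*axis=*/0, /*use_stack=*/true, &meta_x_grad);
// Afterwards meta_x_grad carries x's dtype and layout; dims stay unset.
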
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -68,6 +68,12 @@ void ArrayToTensorInferMeta(const MetaTensor& x,
                            MetaTensor* out_index,
                            MetaConfig config = MetaConfig());

void TensorToArrayInferMeta(const MetaTensor& x,
                            const MetaTensor& out_grad,
                            int axis,
                            bool use_stack,
                            MetaTensor* x_grad);

void AsRealInferMeta(const MetaTensor& input, MetaTensor* output);

void AsComplexInferMeta(const MetaTensor& input, MetaTensor* output);