【pir】add tensor_to_array kernel etc. #60703

Merged · 5 commits · Jan 11, 2024

Changes from all commits
11 changes: 11 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.cc
@@ -173,6 +173,17 @@ std::tuple<pir::OpResult, pir::OpResult> array_to_tensor(pir::Value x,
return std::make_tuple(array_to_tensor.result(0), array_to_tensor.result(1));
}

pir::OpResult tensor_to_array(pir::Value x,
pir::Value out_grad,
[Review thread on `out_grad` — see the usage sketch after this function]

Contributor: The name carries `grad` here — does that mean this tensor_to_array is specifically the backward API?

Contributor (Author): Currently only the backward pass needs this op; no forward API is provided.

int axis,
bool use_stack) {
auto tensor_to_array = ApiBuilder::Instance()
.GetBuilder()
->Build<paddle::dialect::TensorToArrayOp>(
x, out_grad, axis, use_stack);
return tensor_to_array.result(0);
}
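As the review thread above notes, this API is only reachable from the backward pass. A minimal sketch of the intended pairing — assuming a builder is already installed via ApiBuilder, `x` is a DenseTensorArrayType value, and `out_grad` is the gradient of the stacked result; the variable names are illustrative, not part of the patch:

// Forward: collapse the tensor array into a single tensor (plus an index tensor).
auto [out, out_index] =
    paddle::dialect::array_to_tensor(x, /*axis=*/0, /*use_stack=*/true);
// Backward: split the incoming gradient back into an array. This emits the
// pd_op.tensor_to_array op added by this PR; no forward-facing API exists.
pir::OpResult x_grad =
    paddle::dialect::tensor_to_array(x, out_grad, /*axis=*/0, /*use_stack=*/true);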

pir::OpResult add_n_array(const std::vector<pir::Value>& inputs) {
auto inputs_combine_op =
ApiBuilder::Instance().GetBuilder()->Build<pir::CombineOp>(inputs);
5 changes: 5 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.h
@@ -74,6 +74,11 @@ std::tuple<pir::OpResult, pir::OpResult> array_to_tensor(pir::Value x,
int axis,
bool use_stack);

pir::OpResult tensor_to_array(pir::Value x,
pir::Value out_grad,
int axis,
bool use_stack);

pir::OpResult add_n_array(const std::vector<pir::Value>& inputs);

pir::OpResult slice_array_dense(pir::Value input, pir::Value starts);
159 changes: 156 additions & 3 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.cc
@@ -22,9 +22,9 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op,
paddle::dialect::ArrayReadOp, paddle::dialect::ArrayWrite_Op,
paddle::dialect::SliceArrayOp, paddle::dialect::SliceArrayDenseOp,
paddle::dialect::AssignArray_Op, paddle::dialect::ArrayToTensorOp,
-paddle::dialect::SelectInputOp, paddle::dialect::IncrementOp,
-paddle::dialect::Increment_Op, paddle::dialect::ShapeBroadcastOp,
-paddle::dialect::MemcpyD2hMultiIoOp
+paddle::dialect::TensorToArrayOp, paddle::dialect::SelectInputOp,
+paddle::dialect::IncrementOp, paddle::dialect::Increment_Op,
+paddle::dialect::ShapeBroadcastOp, paddle::dialect::MemcpyD2hMultiIoOp
#else

#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
@@ -1085,6 +1085,7 @@ void SplitGradOp::Build(pir::Builder &builder,
dense_x_grad.offset());
argument_outputs.push_back(x_grad_dense_tensor_type);
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
::pir::PassStopGradientsDefaultly(argument);
}

void SplitGradOp::Build(pir::Builder &builder,
@@ -1142,6 +1143,7 @@ void SplitGradOp::Build(pir::Builder &builder,
dense_x_grad.offset());
argument_outputs.push_back(x_grad_dense_tensor_type);
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
::pir::PassStopGradientsDefaultly(argument);
}

void SplitGradOp::VerifySig() {
@@ -1250,6 +1252,7 @@ void CreateArrayOp::Build(pir::Builder &builder,
dense_out.layout());
argument_outputs.push_back(out_dense_tensor_type);
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
::pir::PassStopGradientsDefaultly(argument);
}

void CreateArrayOp::VerifySig() {
@@ -1885,6 +1888,7 @@ void ArrayToTensorOp::Build(pir::Builder &builder, // NOLINT
dense_out_index.offset());
argument_outputs.push_back(out_index_dense_tensor_type);
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
::pir::PassStopGradientsDefaultly(argument);
}

void ArrayToTensorOp::VerifySig() {
@@ -1940,6 +1944,153 @@ void ArrayToTensorOp::InferMeta(phi::InferMetaContext *infer_meta) {
fn(infer_meta);
}

const char *TensorToArrayOp::attributes_name[2] = {"axis", "use_stack"};

OpInfoTuple TensorToArrayOp::GetOpInfo() {
std::vector<paddle::dialect::OpInputInfo> inputs = {
paddle::dialect::OpInputInfo("x",
"paddle::dialect::DenseTensorArrayType",
false,
false,
false,
true),
paddle::dialect::OpInputInfo("out_grad",
"paddle::dialect::DenseTensorType",
false,
false,
false,
true)};

std::vector<paddle::dialect::OpAttributeInfo> attributes = {
paddle::dialect::OpAttributeInfo("axis", "pir::Int32Attribute", ""),
paddle::dialect::OpAttributeInfo("use_stack", "pir::BoolAttribute", "")};

std::vector<paddle::dialect::OpOutputInfo> outputs = {
paddle::dialect::OpOutputInfo(
"out", "paddle::dialect::DenseTensorArrayType", false, false)};

paddle::dialect::OpRunTimeInfo run_time_info =
paddle::dialect::OpRunTimeInfo("TensorToArrayInferMeta",
{"x", "axis", "use_stack"},
"tensor_to_array",
{"x", "axis", "use_stack"},
{"x"},
{},
{},
{});
return std::make_tuple(
inputs, attributes, outputs, run_time_info, "tensor_to_array");
}

void TensorToArrayOp::Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
pir::Value x_,
pir::Value out_grad_,
int axis,
bool use_stack) {
VLOG(4) << "Start build TensorToArrayOp";
VLOG(4) << "Builder construction inputs";
argument.AddInputs({x_, out_grad_});

VLOG(4) << "Builder construction attributes";
pir::Attribute attr_axis =
pir::Int32Attribute::get(pir::IrContext::Instance(), axis);
argument.AddAttribute("axis", attr_axis);
pir::Attribute attr_use_stack =
pir::BoolAttribute::get(pir::IrContext::Instance(), use_stack);
argument.AddAttribute("use_stack", attr_use_stack);

VLOG(4) << "Builder construction outputs";
paddle::dialect::DenseTensorArrayType x =
x_.type().dyn_cast<paddle::dialect::DenseTensorArrayType>();
paddle::dialect::IrTensor dense_x(
paddle::dialect::TransToPhiDataType(x.dtype()), {}, x.data_layout(), {});

paddle::dialect::DenseTensorType out_grad =
out_grad_.type().dyn_cast<paddle::dialect::DenseTensorType>();
paddle::dialect::IrTensor dense_out_grad(
paddle::dialect::TransToPhiDataType(out_grad.dtype()),
out_grad.dims(),
out_grad.data_layout(),
out_grad.lod(),
out_grad.offset());

VLOG(4) << "Builder construction meta_x, meta_out_grad";
paddle::dialect::IrMetaTensor meta_out_grad(&dense_out_grad);
paddle::dialect::IrMetaTensor meta_x(&dense_x);

paddle::dialect::IrTensor dense_x_grad;
paddle::dialect::IrMetaTensor meta_x_grad(&dense_x_grad);

phi::TensorToArrayInferMeta(
meta_x, meta_out_grad, axis, use_stack, &meta_x_grad);

std::vector<pir::Type> argument_outputs;
pir::Type out_dense_tensor_array_type =
paddle::dialect::DenseTensorArrayType::get(
pir::IrContext::Instance(),
paddle::dialect::TransToIrDataType(dense_x_grad.dtype()),
dense_x_grad.layout());
argument_outputs.push_back(out_dense_tensor_array_type);
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
::pir::PassStopGradientsDefaultly(argument);
}

void TensorToArrayOp::VerifySig() {
VLOG(4) << "Start Verifying inputs, outputs and attributes for: "
"TensorToArrayOp.";
VLOG(4) << "Verifying inputs:";
{
auto input_size = num_operands();
PADDLE_ENFORCE_EQ(
input_size,
2u,
phi::errors::PreconditionNotMet(
"The size %d of inputs must be equal to 2.", input_size));

PADDLE_ENFORCE((*this)
->operand_source(0)
.type()
.isa<paddle::dialect::DenseTensorArrayType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 0th input."));
PADDLE_ENFORCE((*this)
->operand_source(1)
.type()
.isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 1th input."));
}

VLOG(4) << "Verifying attributes:";
{
auto &attributes = this->attributes();
PADDLE_ENFORCE(attributes.count("axis") > 0, "axis does not exist.");
PADDLE_ENFORCE(attributes.count("use_stack") > 0,
"use_stack does not exist.");
}

VLOG(4) << "Verifying outputs:";
{
auto output_size = num_results();
PADDLE_ENFORCE_EQ(
output_size,
1u,
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 1.", output_size));
PADDLE_ENFORCE(
(*this)->result(0).type().isa<paddle::dialect::DenseTensorArrayType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 0th output."));
}
VLOG(4) << "End Verifying for: TensorToArrayOp.";
}

void TensorToArrayOp::InferMeta(phi::InferMetaContext *infer_meta) {
auto fn = PD_INFER_META(phi::TensorToArrayInferMeta);
fn(infer_meta);
}

const char *SliceArrayOp::attributes_name[2] = {"starts", "ends"};

OpInfoTuple SliceArrayOp::GetOpInfo() {
@@ -2161,6 +2312,7 @@ void SliceArrayDenseOp::Build(pir::Builder &builder, // NOLINT
dense_out.offset());
argument_outputs.push_back(out_dense_tensor_type);
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
::pir::PassStopGradientsDefaultly(argument);
}

void SliceArrayDenseOp::InferMeta(phi::InferMetaContext *infer_meta) {
@@ -3136,6 +3288,7 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SliceArrayOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SliceArrayDenseOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AssignArray_Op)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayToTensorOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::TensorToArrayOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IncrementOp)
36 changes: 33 additions & 3 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.h
@@ -316,8 +316,10 @@ class ArrayWrite_Op : public pir::Op<ArrayWrite_Op,
const std::vector<std::vector<bool>> &stop_gradients);
};

-class ArrayToTensorOp
-: public pir::Op<ArrayToTensorOp, OpYamlInfoInterface, InferMetaInterface> {
+class ArrayToTensorOp : public pir::Op<ArrayToTensorOp,
+OpYamlInfoInterface,
+paddle::dialect::VjpInterface,
+InferMetaInterface> {
public:
using Op::Op;
static const char *name() { return "pd_op.array_to_tensor"; }
@@ -332,7 +334,34 @@ class ArrayToTensorOp
void VerifySig();
pir::Value x() { return operand_source(0); }
pir::OpResult out() { return result(0); }
-pir::OpResult out_index() { return result(2); }
+pir::OpResult out_index() { return result(1); }
static void InferMeta(phi::InferMetaContext *infer_meta);
static std::vector<std::vector<pir::OpResult>> Vjp(
pir::Operation *op,
const std::vector<std::vector<pir::Value>> &inputs_,
const std::vector<std::vector<pir::Value>> &outputs,
const std::vector<std::vector<pir::Value>> &out_grads,
const std::vector<std::vector<bool>> &stop_gradients);
};

class TensorToArrayOp
: public pir::Op<TensorToArrayOp, OpYamlInfoInterface, InferMetaInterface> {
public:
using Op::Op;
static const char *name() { return "pd_op.tensor_to_array"; }
static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num];
static OpInfoTuple GetOpInfo();
static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
pir::Value x,
pir::Value out_grad,
int axis,
bool use_stack);
void VerifySig();
pir::Value x() { return operand_source(0); }
pir::Value out_grad() { return operand_source(1); }
pir::OpResult x_grad() { return result(0); }
static void InferMeta(phi::InferMetaContext *infer_meta);
};

@@ -630,6 +659,7 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SliceArrayOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SliceArrayDenseOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AssignArray_Op)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayToTensorOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::TensorToArrayOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IncrementOp)
47 changes: 46 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc
@@ -257,7 +257,7 @@ std::vector<std::vector<pir::OpResult>> ArrayReadOp::Vjp(
2,
platform::errors::InvalidArgument(
"Array_read op's outputs size should be 1, but now is %d.",
-outputs.size()));
+out_grads.size()));

VLOG(6) << "Vjp prepare call Array_read's vjp inteface";

@@ -270,5 +270,50 @@
std::vector<std::vector<pir::OpResult>> res;
return res;
}

std::vector<std::vector<pir::OpResult>> ArrayToTensorOp::Vjp(
pir::Operation* op,
const std::vector<std::vector<pir::Value>>& inputs_,
const std::vector<std::vector<pir::Value>>& outputs,
const std::vector<std::vector<pir::Value>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
PADDLE_ENFORCE_EQ(
inputs_.size(),
1,
platform::errors::InvalidArgument(
"Array_read op's inputs size should be 1, but now is %d.",
inputs_.size()));
PADDLE_ENFORCE_EQ(
outputs.size(),
2,
platform::errors::InvalidArgument(
"Array_read op's outputs size should be 2, but now is %d.",
outputs.size()));

PADDLE_ENFORCE_EQ(
out_grads.size(),
2,
platform::errors::InvalidArgument(
"Array_read op's outputs size should be 2, but now is %d.",
out_grads.size()));

VLOG(6) << "Vjp prepare Prepare attributes of array_to_tensor_grad";
int axis = op->attribute("axis").dyn_cast<pir::Int32Attribute>().data();
bool use_stack =
op->attribute("use_stack").dyn_cast<pir::BoolAttribute>().data();

VLOG(6) << "Vjp prepare call ArrayToTensor's vjp inteface";

pir::OpResult tensor_res = paddle::dialect::tensor_to_array(
inputs_[0][0], out_grads[0][0], axis, use_stack);

std::vector<std::vector<pir::OpResult>> res(1);
res[0].resize(1);
if (!stop_gradients[0][0]) {
res[0][0] = tensor_res;
}
return res;
}

} // namespace dialect
} // namespace paddle
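For context, a hedged sketch of how the autodiff machinery would call the new hook; the vector nesting (one inner vector per forward input/output) is inferred from the size checks above, and every value name is illustrative:

// op points at the forward pd_op.array_to_tensor operation; x, out, out_index,
// out_grad and out_index_grad are pir::Values taken from the enclosing program.
std::vector<std::vector<pir::OpResult>> grads =
    paddle::dialect::ArrayToTensorOp::Vjp(
        op,
        /*inputs_=*/{{x}},
        /*outputs=*/{{out}, {out_index}},
        /*out_grads=*/{{out_grad}, {out_index_grad}},
        /*stop_gradients=*/{{false}});
// grads[0][0] holds the array-typed x_grad produced by pd_op.tensor_to_array;
// it stays null when stop_gradients[0][0] is true.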
9 changes: 9 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -175,6 +175,15 @@ void ArrayToTensorInferMeta(const MetaTensor& x,
out_index->set_dims(common::make_ddim({-1}));
}

void TensorToArrayInferMeta(const MetaTensor& x,
const MetaTensor& out_grad,
int axis,
bool use_stack,
MetaTensor* x_grad) {
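// Note: only dtype and layout are propagated. A tensor-array type carries no
// static per-element dims (DenseTensorArrayType::get in manual_op.cc above
// takes just a dtype and a layout), so element shapes are presumably resolved
// at kernel time by splitting/unstacking out_grad along `axis`.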
x_grad->set_dtype(x.dtype());
x_grad->set_layout(x.layout());
}

void ArgMinMaxInferMeta(const MetaTensor& x,
const Scalar& axis,
bool keepdims,
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -68,6 +68,12 @@ void ArrayToTensorInferMeta(const MetaTensor& x,
MetaTensor* out_index,
MetaConfig config = MetaConfig());

void TensorToArrayInferMeta(const MetaTensor& x,
const MetaTensor& out_grad,
int axis,
bool use_stack,
MetaTensor* x_grad);

void AsRealInferMeta(const MetaTensor& input, MetaTensor* output);

void AsComplexInferMeta(const MetaTensor& input, MetaTensor* output);