Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pten] add concat pten kernel #38955

Merged
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1902,6 +1902,10 @@ void OperatorWithKernel::BuildPtenKernelContext(
std::type_index(typeid(std::string))) {
pt_kernel_context->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr))));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(int))) {
pt_kernel_context->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(int, attr))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to Scalar when construct "
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/imperative/prepared_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,10 @@ static void BuildDygraphPtenKernelContext(
std::type_index(typeid(std::string))) {
kernel_ctx->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr))));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(int))) {
kernel_ctx->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(int, attr))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to Scalar when construct "
Expand Down
15 changes: 13 additions & 2 deletions paddle/fluid/operators/concat_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ limitations under the License. */
#include <string>
#include <vector>

#include "paddle/pten/kernels/funcs/concat_funcs.h"

#ifdef PADDLE_WITH_MKLDNN
#include <paddle/fluid/platform/mkldnn_helper.h>
#endif
Expand Down Expand Up @@ -56,8 +58,8 @@ class ConcatOp : public framework::OperatorWithKernel {
size_t axis =
ComputeAxis(static_cast<int64_t>(ctx->Attrs().Get<int>("axis")),
static_cast<int64_t>(inputs_dims[0].size()));
framework::DDim out_dims =
ComputeAndCheckShape(ctx->IsRuntime(), inputs_dims, axis);
framework::DDim out_dims = pten::funcs::ComputeAndCheckShape(
ctx->IsRuntime(), inputs_dims, axis);
if (out_dims[axis] < 0) {
out_dims[axis] = -1;
}
Expand Down Expand Up @@ -102,6 +104,15 @@ class ConcatOp : public framework::OperatorWithKernel {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}

// Maps this fluid concat op onto the pten "concat" kernel signature.
// The concat axis argument is taken from the optional AxisTensor input
// when it is fed; otherwise it falls back to the "axis" attribute.
framework::KernelSignature GetExpectedPtenKernelArgs(
    const framework::ExecutionContext &ctx) const override {
  const char *axis_arg_name =
      ctx.HasInput("AxisTensor") ? "AxisTensor" : "axis";
  return framework::KernelSignature("concat", {"X"}, {axis_arg_name}, {"Out"});
}
};

class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
Expand Down
111 changes: 11 additions & 100 deletions paddle/fluid/operators/concat_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,54 +22,11 @@ limitations under the License. */
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/operators/utils.h"

#include "paddle/pten/kernels/concat_kernel.h"
#include "paddle/pten/kernels/funcs/concat_funcs.h"

namespace paddle {
namespace operators {
// Infers the output shape of a concat from the input shapes and validates
// that every non-concat dimension agrees across inputs.
//
// is_runtime:  true during execution (all dims are concrete); false at
//              compile/infer-shape time, where -1 marks an unknown dim.
// inputs_dims: shapes of all tensors to concatenate; all entries must have
//              the same rank as inputs_dims[0] (enforced below).
// axis:        the dimension along which the tensors are concatenated.
//
// Returns the output DDim: the concat axis is the sum of the inputs' axis
// extents (or -1 if any extent is unknown at compile time); every other
// dimension is copied from the inputs, which must match.
static inline framework::DDim ComputeAndCheckShape(
    const bool is_runtime, const std::vector<framework::DDim>& inputs_dims,
    const size_t axis) {
  const size_t n = inputs_dims.size();
  auto out_dims = inputs_dims[0];
  size_t in_zero_dims_size = out_dims.size();
  for (size_t i = 1; i < n; i++) {
    // Every input must have the same rank as input[0].
    PADDLE_ENFORCE_EQ(inputs_dims[i].size(), out_dims.size(),
                      platform::errors::InvalidArgument(
                          "The shape of input[0] and input[%d] "
                          "is expected to be equal."
                          "But received input[0]'s shape = "
                          "[%s], input[%d]'s shape = [%s].",
                          i, inputs_dims[0], i, inputs_dims[i]));
    for (size_t j = 0; j < in_zero_dims_size; j++) {
      if (j == axis) {
        if (is_runtime) {
          // At runtime all dims are concrete: accumulate the concat extent.
          out_dims[axis] += inputs_dims[i][j];
        } else {
          // At compile time, -1 on either side makes the result unknown.
          if (inputs_dims[i][j] == -1 || out_dims[j] == -1) {
            out_dims[axis] = -1;
          } else {
            out_dims[axis] += inputs_dims[i][j];
          }
        }
      } else {
        // Non-concat dims must match; at compile time only compare when
        // both sides are known (> 0).
        bool check_shape =
            is_runtime || (inputs_dims[0][j] > 0 && inputs_dims[i][j] > 0);
        if (check_shape) {
          // check all shape in run time
          PADDLE_ENFORCE_EQ(inputs_dims[0][j], inputs_dims[i][j],
                            platform::errors::InvalidArgument(
                                "The %d-th dimension of input[0] and input[%d] "
                                "is expected to be equal."
                                "But received input[0]'s shape = "
                                "[%s], input[%d]'s shape = [%s].",
                                j, i, inputs_dims[0], i, inputs_dims[i]));
        }
        // A dim still unknown in the accumulated shape (-1) can be filled
        // in from a later input that knows it.
        if (!is_runtime && out_dims[j] == -1 && inputs_dims[i][j] > 0) {
          out_dims[j] = inputs_dims[i][j];
        }
      }
    }
  }
  return out_dims;
}

static inline int64_t ComputeAxis(int64_t axis, int64_t rank) {
PADDLE_ENFORCE_EQ(
Expand Down Expand Up @@ -109,67 +66,21 @@ class ConcatKernel : public framework::OpKernel<T> {
ins_dims[i] = ins[i]->dims();
}

framework::DDim out_dims = ComputeAndCheckShape(true, ins_dims, axis);
framework::DDim out_dims =
pten::funcs::ComputeAndCheckShape(true, ins_dims, axis);
out->Resize(out_dims);
}
auto place = ctx.GetPlace();
out->mutable_data<T>(place);

// If axis is 0, the lod of the output is not the same as inputs.
if (axis == 0 && ins[0]->lod().size() > 0) {
size_t lod_size_0 = ins[0]->lod().size();
size_t lod_size = lod_size_0;
for (size_t i = 1; i < ins.size(); ++i) {
if (ins[i]->lod().size() > 0) {
PADDLE_ENFORCE_EQ(
ins[i]->lod().size(), lod_size_0,
platform::errors::Unimplemented(
"The lod level of all input LoDTensors should be same. "
"Maybe different lod level of input LoDTensors can concat,"
"it is not supported currently. The lod level of %dth input "
"is %d and first input is %d.",
i, ins[i]->lod().size(), lod_size_0));
} else {
lod_size = 0;
break;
}
}
if (lod_size) {
auto* out_lod = out->mutable_lod();
for (size_t i = 1; i < ins.size(); ++i) {
auto in_lod = ConvertToLengthBasedLoD(ins[i]->lod());
AppendLoD(out_lod, in_lod);
}
}
// call new kernel
auto& dev_ctx = ctx.device_context<DeviceContext>();
std::vector<pten::DenseTensor> pt_ins;
for (auto& in : ins) {
pt_ins.push_back(*in);
}

// Sometimes direct copies will be faster, this maybe need deeply analysis.
if (axis == 0 && ins.size() < 10) {
size_t output_offset = 0;
for (auto* in : ins) {
if (!in || in->numel() == 0UL) {
continue;
}
auto in_stride = framework::stride_numel(in->dims());
auto out_stride = framework::stride_numel(out->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
out->data<T>() + output_offset, out_stride,
in->data<T>(), in_stride, in_stride[axis]);
output_offset += in_stride[axis];
}
} else {
std::vector<framework::Tensor> inputs;
for (size_t j = 0; j < ins.size(); ++j) {
if (ins[j] && ins[j]->numel() > 0) {
inputs.push_back(*ins[j]);
} else {
continue;
}
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
paddle::operators::math::ConcatFunctor<DeviceContext, T> concat_functor;
concat_functor(dev_ctx, inputs, static_cast<int>(axis), out);
}
pten::ConcatKernel<T>(dev_ctx, pt_ins, axis, out);
}
};

Expand Down
81 changes: 11 additions & 70 deletions paddle/fluid/operators/math/concat_and_split.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/concat_and_split.h"

#include "paddle/pten/kernels/cpu/concat_and_split.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议直接把concat_and_split.cc迁移过来,马上我们也要把它移过来了,可以下个PR再做

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#endif
Expand Down Expand Up @@ -42,36 +44,9 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& context,
const std::vector<framework::Tensor>& input, int axis,
framework::Tensor* output) {
// TODO(zcd): Add input data validity checking
size_t num = input.size();

int64_t rows = 1;
auto dim_0 = input[0].dims();
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
int64_t out_rows = rows, out_cols = 0;

std::vector<int64_t> input_cols(input.size());
for (size_t i = 0; i < num; ++i) {
int64_t t_cols = input[i].numel() / rows;
out_cols += t_cols;
input_cols[i] = t_cols;
}
auto cpu_place = context.GetPlace();

// computation
auto output_data = output->data<T>();
int64_t col_idx = 0;
for (size_t j = 0; j < num; ++j) {
int64_t col_len = input_cols[j];
auto input_data = input[j].data<T>();
for (int64_t k = 0; k < out_rows; ++k) {
memory::Copy(cpu_place, output_data + k * out_cols + col_idx, cpu_place,
input_data + k * col_len, sizeof(T) * col_len);
}
col_idx += col_len;
}
std::vector<pten::DenseTensor> pt_input{input.begin(), input.end()};
pten::ConcatImpl<T, platform::CPUDeviceContext>(context, pt_input, axis,
output);
}
};

Expand All @@ -86,46 +61,12 @@ class SplitFunctor<platform::CPUDeviceContext, T> {
const framework::Tensor& input,
const std::vector<const framework::Tensor*>& ref_inputs,
const int axis, std::vector<framework::Tensor*>* outputs) {
// NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3
// tensors of shape [0,1,4]
if (input.numel() == 0) {
return;
}

// TODO(zcd): Add input data validity checking
size_t num = outputs->size();

int input_rows = 1;
auto dim_0 = ref_inputs[0]->dims();
for (int i = 0; i < axis; ++i) {
input_rows *= dim_0[i];
}

int input_cols = 0;

std::vector<int64_t> output_cols(outputs->size());
for (size_t i = 0; i < num; ++i) {
int t_cols = ref_inputs[i]->numel() / input_rows;
input_cols += t_cols;
output_cols[i] = t_cols;
}
auto cpu_place = context.GetPlace();

// computation
for (int k = 0; k < input_rows; ++k) {
const T* src_ptr = input.data<T>() + k * input_cols;
int col_idx = 0;
for (size_t j = 0; j < num; ++j) {
int col_len = output_cols[j];
auto* out_tensor = outputs->at(j);
if (out_tensor != nullptr) {
T* dst_ptr = out_tensor->data<T>() + k * col_len;
memory::Copy(cpu_place, dst_ptr, cpu_place, src_ptr + col_idx,
sizeof(T) * col_len);
}
col_idx += col_len;
}
}
std::vector<const pten::DenseTensor*> pt_ref_inputs{ref_inputs.begin(),
ref_inputs.end()};
std::vector<pten::DenseTensor*> pt_outputs{outputs->begin(),
outputs->end()};
pten::SplitImpl<T, platform::CPUDeviceContext>(
context, input, pt_ref_inputs, axis, &pt_outputs);
}
};

Expand Down
Loading