【PTen】Add dot and matmul grad kernel in pten (#38713)
* refactor matmul directory in pten

* fix merge conflict

* add dot_grad kernel

* add dot_grad kernel in pten

* add matmul_grad kernel

* update the code

* delete useless code in fluid

* fix some bugs when running the matmul grad kernel

* fix merge conflict

* refactor some code

* refactor code
zyfncg committed Jan 11, 2022
1 parent 5b940c4 commit be81771
Showing 40 changed files with 3,336 additions and 2,467 deletions.
3 changes: 3 additions & 0 deletions cmake/pten_kernel.cmake
@@ -79,6 +79,9 @@ function(kernel_library TARGET)
endif()

list(APPEND all_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.h)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/impl/${TARGET}_impl.h)
list(APPEND all_srcs ${CMAKE_CURRENT_SOURCE_DIR}/impl/${TARGET}_impl.h)
endif()
list(APPEND all_srcs ${common_srcs})
list(APPEND all_srcs ${cpu_srcs})
list(APPEND all_srcs ${gpu_srcs})
26 changes: 22 additions & 4 deletions paddle/fluid/framework/operator.cc
@@ -1880,16 +1880,32 @@ void OperatorWithKernel::BuildPtenKernelContext(
// Otherwise, we will create new storage.
for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
if (current_vector_size > start_idx + offset) {
-        experimental::ReMakePtenDenseTensorFromVar(
-            outs_vector[offset], out_def,
-            pt_kernel_context_->MutableOutputAt<pten::DenseTensor>(start_idx +
-                                                                   offset));
+        auto* buffer_tensor =
+            pt_kernel_context_->MutableOutputAt<pten::DenseTensor>(start_idx +
+                                                                   offset);
+        if (buffer_tensor) {
+          experimental::ReMakePtenDenseTensorFromVar(outs_vector[offset],
+                                                     out_def, buffer_tensor);
+        }
} else {
pt_kernel_context_->EmplaceBackOutputWithoutSetRange(
experimental::MakePtenTensorBaseFromVar(outs_vector[offset],
out_def));
}
}

// Deal with the case that some outputs are NULL when running the kernel.
// For example: the outputs of matmul_grad are dx and dy,
// sometimes dx or dy may be NULL.
if (outs_vector.empty()) {
if (current_vector_size > start_idx) {
pt_kernel_context_->SetOutputWithoutSetRange(start_idx, {nullptr});
} else {
pt_kernel_context_->EmplaceBackOutputWithoutSetRange({nullptr});
}
end_idx = start_idx + 1;
}

pt_kernel_context_->AssignOutputRange(std::make_pair(start_idx, end_idx),
i);
}
@@ -2002,7 +2018,9 @@ void OperatorWithKernel::WriteBackToOutputs(RuntimeContext* ctx) const {
range_pair.first, range_pair.second);

for (size_t j = 0; j < pten_outs.size(); ++j) {
-      experimental::MakeVariableFromPtenTensor(pten_outs[j], outs_vector[j]);
+      if (pten_outs[j]) {
+        experimental::MakeVariableFromPtenTensor(pten_outs[j], outs_vector[j]);
+      }
}
}
}
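The changes above let BuildPtenKernelContext carry placeholder (nullptr) output slots and let WriteBackToOutputs skip them, because grad ops such as matmul_grad may be asked to produce only dX or only dY. Below is a minimal, self-contained sketch of that calling pattern; it is not code from this commit, and the names Tensor and DotGradSketch are illustrative stand-ins, assuming only that a grad routine ignores any output slot the caller left null.

// Minimal sketch (illustrative only): a grad routine that skips output
// slots the caller did not request, mirroring how matmul_grad's dX or dY
// may be absent.
#include <cstddef>
#include <iostream>
#include <vector>

using Tensor = std::vector<float>;  // stand-in for a real dense tensor type

void DotGradSketch(const Tensor& x, const Tensor& y, const Tensor& dout,
                   Tensor* dx, Tensor* dy) {
  if (dx) {  // only produce dX when the caller asked for it
    dx->resize(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) (*dx)[i] = dout[0] * y[i];
  }
  if (dy) {  // likewise for dY
    dy->resize(y.size());
    for (std::size_t i = 0; i < y.size(); ++i) (*dy)[i] = dout[0] * x[i];
  }
}

int main() {
  Tensor x{1.f, 2.f}, y{3.f, 4.f}, dout{1.f};
  Tensor dx;
  DotGradSketch(x, y, dout, &dx, /*dy=*/nullptr);  // dY is not requested
  std::cout << dx[0] << " " << dx[1] << "\n";      // prints "3 4"
  return 0;
}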
2 changes: 1 addition & 1 deletion paddle/fluid/framework/pten_utils.cc
@@ -99,7 +99,7 @@ KernelSignatureMap& KernelSignatureMap::Instance() {
const auto& op_type = pair.first;
const auto* op_proto = pair.second.proto_;
if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type) &&
-        op_proto != nullptr) {
+        op_proto) {
KernelArgsNameMakerByOpProto maker(op_proto);
VLOG(10) << "Register kernel signature for " << op_type;
auto success = kernel_signature_map_->map_
66 changes: 47 additions & 19 deletions paddle/fluid/imperative/prepared_operator.cc
@@ -338,19 +338,41 @@ static void BuildDygraphPtenKernelContext(

for (size_t i = 0; i < output_names.size(); ++i) {
auto& out_def = output_defs.at(i);
-    auto& outs_vector = outs.at(output_names[i]);

    size_t start_idx = (i == 0 ? 0 : kernel_ctx->OutputRangeAt(i - 1).second);
-    size_t end_idx = start_idx + outs_vector.size();
    auto current_vector_size = kernel_ctx->OutputsSize();

+    auto iter = outs.find(output_names[i]);
+    if (iter == outs.end()) {
+      if (current_vector_size > start_idx) {
+        kernel_ctx->SetOutputWithoutSetRange(start_idx, {nullptr});
+      } else {
+        kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr});
+      }
+      kernel_ctx->AssignOutputRange(std::make_pair(start_idx, start_idx + 1),
+                                    i);
+      continue;
+    }
+
+    auto& outs_vector = iter->second;
+    size_t end_idx = start_idx + outs_vector.size();

// If the memory needed is less than the current memory allocated, we will
// reuse the current memory by using ReMakePtenDenseTensorFromVar.
// Otherwise, we will create new storage.
for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
if (current_vector_size > start_idx + offset) {
-        experimental::ReMakePtenDenseTensorFromVar(
-            outs_vector[offset]->MutableVar(), out_def,
-            kernel_ctx->MutableOutputAt<pten::DenseTensor>(start_idx + offset));
+        auto* buffer_tensor =
+            kernel_ctx->MutableOutputAt<pten::DenseTensor>(start_idx + offset);
+        if (buffer_tensor) {
+          experimental::ReMakePtenDenseTensorFromVar(
+              outs_vector[offset]->MutableVar(), out_def, buffer_tensor);
+        } else {
+          kernel_ctx->SetOutputWithoutSetRange(
+              start_idx + offset,
+              experimental::MakePtenTensorBaseFromVar(
+                  outs_vector[offset]->MutableVar(), out_def));
+        }
} else {
kernel_ctx->EmplaceBackOutputWithoutSetRange(
experimental::MakePtenTensorBaseFromVar(
@@ -465,15 +487,18 @@ static void WriteBackToOutputs(
auto& output_names = std::get<2>(pt_kernel_signature.args);

for (size_t i = 0; i < output_names.size(); ++i) {
-    auto& outs_vector = outs.at(output_names[i]);
+    auto iter = outs.find(output_names[i]);
+    if (iter != outs.end()) {
+      auto& outs_vector = iter->second;

-    auto& range_pair = kernel_ctx->OutputRangeAt(i);
-    auto pten_outs = kernel_ctx->MutableOutputBetween<pten::DenseTensor>(
-        range_pair.first, range_pair.second);
+      auto& range_pair = kernel_ctx->OutputRangeAt(i);
+      auto pten_outs = kernel_ctx->MutableOutputBetween<pten::DenseTensor>(
+          range_pair.first, range_pair.second);

-    for (size_t j = 0; j < pten_outs.size(); ++j) {
-      experimental::MakeVariableFromPtenTensor(pten_outs[j],
-                                               outs_vector[j]->MutableVar());
+      for (size_t j = 0; j < pten_outs.size(); ++j) {
+        experimental::MakeVariableFromPtenTensor(pten_outs[j],
+                                                 outs_vector[j]->MutableVar());
+      }
}
}
}
@@ -529,6 +554,7 @@ static void PreparedOpRunImpl(
template <typename VarType>
static void PreparedOpRunPtImpl(
const framework::OperatorBase& op,
const framework::OpKernelType& kernel_type,
const framework::KernelSignature& pt_kernel_signature,
const pten::Kernel& pt_kernel, pten::KernelContext* pt_kernel_context,
platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
@@ -558,17 +584,19 @@ static void PreparedOpRunPtImpl(
pt_kernel_context->ClearData();

// TODO(chenweihang): add debug flags later
-  // TODO(chenweihang): deal with complex cases later
+  if (framework::IsComplexType(kernel_type.data_type_)) {
+    HandleComplexGradToRealGrad<VarType>(outs);
+  }
}

void PreparedOp::Run(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
if (run_pten_kernel_) {
-    PreparedOpRunPtImpl<VarBase>(op_, pt_kernel_signature_, pt_kernel_,
-                                 pt_kernel_context_, dev_ctx_, ins, outs, attrs,
-                                 default_attrs);
+    PreparedOpRunPtImpl<VarBase>(op_, kernel_type_, pt_kernel_signature_,
+                                 pt_kernel_, pt_kernel_context_, dev_ctx_, ins,
+                                 outs, attrs, default_attrs);
} else {
PreparedOpRunImpl<VarBase>(op_, ctx_, kernel_type_, func_, dev_ctx_, ins,
outs, attrs, default_attrs);
@@ -580,9 +608,9 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
if (run_pten_kernel_) {
-    PreparedOpRunPtImpl<VariableWrapper>(op_, pt_kernel_signature_, pt_kernel_,
-                                         pt_kernel_context_, dev_ctx_, ins,
-                                         outs, attrs, default_attrs);
+    PreparedOpRunPtImpl<VariableWrapper>(
+        op_, kernel_type_, pt_kernel_signature_, pt_kernel_, pt_kernel_context_,
+        dev_ctx_, ins, outs, attrs, default_attrs);
} else {
PreparedOpRunImpl<VariableWrapper>(op_, ctx_, kernel_type_, func_, dev_ctx_,
ins, outs, attrs, default_attrs);
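PreparedOpRunPtImpl now also receives the OpKernelType so that, after a pten kernel runs, HandleComplexGradToRealGrad can be called when the kernel's data type is complex. The snippet below is only a simplified stand-in for that idea, an assumption for illustration rather than the framework routine: when a gradient was computed in a complex type but the corresponding variable holds real data, only the real part of the gradient is kept.

// Simplified stand-in (illustrative assumption, not HandleComplexGradToRealGrad):
// keep only the real part of a complex-valued gradient for a real-valued variable.
#include <complex>
#include <iostream>
#include <vector>

std::vector<float> ComplexGradToReal(const std::vector<std::complex<float>>& grad) {
  std::vector<float> real_grad;
  real_grad.reserve(grad.size());
  for (const auto& g : grad) {
    real_grad.push_back(g.real());  // drop the imaginary component
  }
  return real_grad;
}

int main() {
  std::vector<std::complex<float>> g{{1.5f, 2.0f}, {-0.5f, 1.0f}};
  for (float v : ComplexGradToReal(g)) std::cout << v << " ";  // prints "1.5 -0.5"
  return 0;
}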
2 changes: 1 addition & 1 deletion paddle/fluid/operators/conj_op.h
@@ -39,7 +39,7 @@ class ConjKernel : public framework::OpKernel<T> {
auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);

// call new kernel
-    pten::Conj<T>(dev_ctx, *pt_x.get(), pt_out.get());
+    pten::ConjKernel<T>(dev_ctx, *pt_x.get(), pt_out.get());
}
};

7 changes: 7 additions & 0 deletions paddle/fluid/operators/dot_op.cc
@@ -117,6 +117,13 @@ class DotGradOp : public framework::OperatorWithKernel {
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}

framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext& ctx) const override {
return framework::KernelSignature(
"dot_grad", {"X", "Y", framework::GradVarName("Out")}, {},
{framework::GradVarName("X"), framework::GradVarName("Y")});
}
};

template <typename T>
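GetExpectedPtenKernelArgs above maps DotGradOp onto a pten kernel named "dot_grad", taking X, Y, and Out@GRAD as inputs, no attributes, and X@GRAD / Y@GRAD as outputs. A declaration consistent with that three-in/two-out layout could look like the sketch below; treat it as an assumed shape for orientation only, since the exact signature is defined in the pten dot_grad kernel files added elsewhere in this PR.

// Sketch of a declaration matching the KernelSignature above: inputs in the
// order ("X", "Y", "Out@GRAD"), outputs ("X@GRAD", "Y@GRAD"). Illustrative,
// not the exact API added by this commit.
namespace pten {

class DenseTensor;  // forward declaration; the real headers are omitted here

template <typename T, typename ContextT>
void DotGradKernel(const ContextT& dev_ctx,
                   const DenseTensor& x,
                   const DenseTensor& y,
                   const DenseTensor& dout,
                   DenseTensor* dx,   // may be nullptr when dX is not requested
                   DenseTensor* dy);  // may be nullptr when dY is not requested

}  // namespace pten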