[PTen] Remove cached kernel context (PaddlePaddle#38953)
* remove cached kernel context

* revert dataloader format change
chenwhql committed Jan 15, 2022
1 parent 1053b1d commit 35d2b71
Showing 12 changed files with 82 additions and 236 deletions.
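The pattern repeated across these files: the operator no longer reuses a `pten::KernelContext` cached on a member (or a static), it builds a short-lived context on the stack for every run and passes it explicitly through `BuildPtenKernelContext`, the kernel call, and `WriteBackToOutputs`. A minimal sketch of the new call sequence, assuming `op_with_kernel`, `runtime_context`, and `dev_ctx` are already set up as in the hunks below:

```cpp
// Sketch only -- the surrounding operator/runtime setup is assumed, not shown here.
pten::KernelContext pt_kernel_context;  // stack-local, one per kernel run

// Fill inputs, outputs, and attributes from the fluid runtime context.
op_with_kernel->BuildPtenKernelContext(runtime_context, dev_ctx,
                                       &pt_kernel_context);

// Run the pten kernel against the freshly built context.
(*op_with_kernel->PtenKernel())(&pt_kernel_context);

// Copy results back into the fluid output variables.
op_with_kernel->WriteBackToOutputs(&runtime_context, &pt_kernel_context);

// No ClearData() call anymore: the context simply goes out of scope here.
```

Because the context lives only for the duration of one run, the old cleanup step (`pt_kernel_context_->ClearData()`) and the slot-reuse logic inside `BuildPtenKernelContext` can both be dropped.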
9 changes: 5 additions & 4 deletions paddle/fluid/framework/new_executor/interpretercore.cc
@@ -418,15 +418,16 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
VLOG(4) << "Run pten kernel: " << op->Type();
VLOG(4) << instr_node.InnerRuntimeContext().get() << " "
<< &instr_node.DeviceContext();
pten::KernelContext pt_kernel_context;
op_with_kernel->BuildPtenKernelContext(
*instr_node.InnerRuntimeContext().get(),
const_cast<platform::DeviceContext*>(&instr_node.DeviceContext()));
const_cast<platform::DeviceContext*>(&instr_node.DeviceContext()),
&pt_kernel_context);

(*instr_node.PtenKernel())(instr_node.PtenKernelContext());
(*instr_node.PtenKernel())(&pt_kernel_context);

op_with_kernel->WriteBackToOutputs(
instr_node.InnerRuntimeContext().get());
instr_node.PtenKernelContext()->ClearData();
instr_node.InnerRuntimeContext().get(), &pt_kernel_context);
} else {
instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
}
11 changes: 6 additions & 5 deletions paddle/fluid/framework/new_executor/interpretercore_util.cc
@@ -425,13 +425,14 @@ void build_op_func_list(const platform::Place& place,
}

if (run_pten_kernel) {
op_with_kernel->BuildPtenKernelContext(runtime_context, dev_ctx);
pten::KernelContext pt_kernel_context;
op_with_kernel->BuildPtenKernelContext(runtime_context, dev_ctx,
&pt_kernel_context);
op_func_node.pt_kernel_ = op_with_kernel->PtenKernel();
op_func_node.pt_kernel_context_ = op_with_kernel->PtenKernelContext();

(*op_func_node.pt_kernel_)(op_func_node.pt_kernel_context_);
op_with_kernel->WriteBackToOutputs(&runtime_context);
op_func_node.pt_kernel_context_->ClearData();
(*op_func_node.pt_kernel_)(&pt_kernel_context);
op_with_kernel->WriteBackToOutputs(&runtime_context,
&pt_kernel_context);
} else {
op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
op_func_node.kernel_func_(exec_ctx);
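With the old and new lines interleaved above, the pten branch of `build_op_func_list` can be hard to read; after this change it amounts to roughly the following (a sketch, with the enclosing loop and kernel lookup omitted):

```cpp
// Sketch of the pten branch after this change (enclosing loop omitted).
if (run_pten_kernel) {
  pten::KernelContext pt_kernel_context;
  op_with_kernel->BuildPtenKernelContext(runtime_context, dev_ctx,
                                         &pt_kernel_context);
  // Only the kernel pointer is cached on the op_func_node now.
  op_func_node.pt_kernel_ = op_with_kernel->PtenKernel();

  (*op_func_node.pt_kernel_)(&pt_kernel_context);
  op_with_kernel->WriteBackToOutputs(&runtime_context, &pt_kernel_context);
} else {
  op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
  op_func_node.kernel_func_(exec_ctx);
}
```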
4 changes: 0 additions & 4 deletions paddle/fluid/framework/new_executor/new_executor_defs.cc
@@ -688,10 +688,6 @@ pten::Kernel* Instruction::PtenKernel() const {
return op_func_node_.pt_kernel_;
}

pten::KernelContext* Instruction::PtenKernelContext() const {
return op_func_node_.pt_kernel_context_;
}

OpFuncType Instruction::KernelType() const { return op_func_node_.type_; }

OperatorBase* Instruction::OpBase() const {
5 changes: 1 addition & 4 deletions paddle/fluid/framework/new_executor/new_executor_defs.h
@@ -299,8 +299,7 @@ struct OpFuncNode {
platform::DeviceContext* dev_ctx_; // not owned

// fit for pten kernel
pten::Kernel* pt_kernel_{nullptr}; // not owned
pten::KernelContext* pt_kernel_context_{nullptr}; // not onwed
pten::Kernel* pt_kernel_{nullptr}; // not owned

OpFuncType type_;
};
@@ -322,8 +321,6 @@ class Instruction {

pten::Kernel* PtenKernel() const;

pten::KernelContext* PtenKernelContext() const;

OpFuncType KernelType() const;

OperatorBase* OpBase() const;
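After dropping the accessor and the `pt_kernel_context_` field, `OpFuncNode` only keeps the kernel pointer between runs. A trimmed view of the struct as it stands after this change (members not touched by this diff omitted):

```cpp
// Trimmed sketch of OpFuncNode after this change; unrelated members omitted.
struct OpFuncNode {
  platform::DeviceContext* dev_ctx_;  // not owned

  // fit for pten kernel
  pten::Kernel* pt_kernel_{nullptr};  // not owned; the cached KernelContext* is gone

  OpFuncType type_;
};
```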
122 changes: 34 additions & 88 deletions paddle/fluid/framework/operator.cc
@@ -1192,13 +1192,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
platform::RecordEvent record_event("compute",
platform::EventRole::kInnerOp);
if (run_pten_kernel_) {
if (pt_kernel_context_ == nullptr) {
pt_kernel_context_.reset(new pten::KernelContext());
}
BuildPtenKernelContext(*runtime_ctx, dev_ctx);
(*pt_kernel_)(pt_kernel_context_.get());
WriteBackToOutputs(runtime_ctx);
pt_kernel_context_->ClearData();
pten::KernelContext pt_kernel_context;
BuildPtenKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context);
(*pt_kernel_)(&pt_kernel_context);
WriteBackToOutputs(runtime_ctx, &pt_kernel_context);
} else {
(*kernel_func_)(
ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
@@ -1791,18 +1788,9 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
}

void OperatorWithKernel::BuildPtenKernelContext(
const RuntimeContext& ctx, platform::DeviceContext* dev_ctx) const {
if (pt_kernel_context_ == nullptr) {
pt_kernel_context_.reset(new pten::KernelContext());
}
// TODO(chenweihang): now only work for very simple case,
// many cases need to be deal with later:
// 1. the input and output are not tensor
// 2. the dispensbale, duplicable input and output
// 3. needless attributes remove
// 4. use pt Tensor directly
// 5. kernel input is not DenseTensor
pt_kernel_context_->SetDeviceContext(dev_ctx);
const RuntimeContext& ctx, platform::DeviceContext* dev_ctx,
pten::KernelContext* pt_kernel_context) const {
pt_kernel_context->SetDeviceContext(dev_ctx);

auto& input_names = std::get<0>(pt_kernel_signature_->args);
auto& attr_names = std::get<1>(pt_kernel_signature_->args);
Expand Down Expand Up @@ -1836,77 +1824,39 @@ void OperatorWithKernel::BuildPtenKernelContext(

// calcute the start and end index of the input tensors
size_t start_idx =
(i == 0 ? 0 : pt_kernel_context_->InputRangeAt(i - 1).second);
(i == 0 ? 0 : pt_kernel_context->InputRangeAt(i - 1).second);
size_t end_idx = start_idx + ins_vector.size();
auto current_vector_size = pt_kernel_context_->InputsSize();

// If the memory needed is less than the current memory allocated, we will
// reuse the current memory by using ReMakePtenDenseTensorFromVar.
// Otherwise,we will create new storage.
for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
if (current_vector_size > start_idx + offset) {
auto& input_ptr =
pt_kernel_context_->MutableInputPtrAt(start_idx + offset);
if (input_ptr == nullptr) {
input_ptr = experimental::MakePtenTensorBaseFromVar(
*ins_vector[offset], in_def);
} else {
experimental::ReMakePtenDenseTensorFromVar(
*ins_vector[offset], in_def,
pt_kernel_context_->MutableInputAt<pten::DenseTensor>(start_idx +
offset));
}
} else {
pt_kernel_context_->EmplaceBackInputWithoutSetRange(
experimental::MakePtenTensorBaseFromVar(*ins_vector[offset],
in_def));
}
pt_kernel_context->EmplaceBackInputWithoutSetRange(
experimental::MakePtenTensorBaseFromVar(*ins_vector[offset], in_def));
}
pt_kernel_context_->AssignInputRange(std::make_pair(start_idx, end_idx), i);
pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx), i);
}

for (size_t i = 0; i < output_names.size(); ++i) {
auto& out_def = output_defs.at(i);
auto& outs_vector = ctx.outputs.at(output_names[i]);

size_t start_idx =
(i == 0 ? 0 : pt_kernel_context_->OutputRangeAt(i - 1).second);
(i == 0 ? 0 : pt_kernel_context->OutputRangeAt(i - 1).second);
size_t end_idx = start_idx + outs_vector.size();
auto current_vector_size = pt_kernel_context_->OutputsSize();

// If the memory needed is less than the current memory allocated, we will
// reuse the current memory by using ReMakePtenDenseTensorFromVar.
// Otherwise,we will create new storage.
for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
if (current_vector_size > start_idx + offset) {
auto* buffer_tensor =
pt_kernel_context_->MutableOutputAt<pten::DenseTensor>(start_idx +
offset);
if (buffer_tensor) {
experimental::ReMakePtenDenseTensorFromVar(outs_vector[offset],
out_def, buffer_tensor);
}
} else {
pt_kernel_context_->EmplaceBackOutputWithoutSetRange(
experimental::MakePtenTensorBaseFromVar(outs_vector[offset],
out_def));
}
pt_kernel_context->EmplaceBackOutputWithoutSetRange(
experimental::MakePtenTensorBaseFromVar(outs_vector[offset],
out_def));
}

// Deal with the case that some outputs are NULL when run the kernel.
// For example : the outputs of matmul_grad are dx and dy,
// sometimes dx or dy may be NULL.
if (outs_vector.empty()) {
if (current_vector_size > start_idx) {
pt_kernel_context_->SetOutputWithoutSetRange(start_idx, {nullptr});
} else {
pt_kernel_context_->EmplaceBackOutputWithoutSetRange({nullptr});
}
pt_kernel_context->EmplaceBackOutputWithoutSetRange({nullptr});
end_idx = start_idx + 1;
}

pt_kernel_context_->AssignOutputRange(std::make_pair(start_idx, end_idx),
i);
pt_kernel_context->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
}

for (size_t i = 0; i < attr_names.size(); ++i) {
@@ -1915,11 +1865,11 @@ void OperatorWithKernel::BuildPtenKernelContext(
if (attr_iter != Attrs().end()) { // shape is in the attribute
if (std::type_index(attr_iter->second.type()) ==
std::type_index(typeid(std::vector<int64_t>))) {
pt_kernel_context_->EmplaceBackAttr(std::move(pten::ScalarArray(
pt_kernel_context->EmplaceBackAttr(std::move(pten::ScalarArray(
BOOST_GET_CONST(std::vector<int64_t>, attr_iter->second))));
} else if (std::type_index(attr_iter->second.type()) ==
std::type_index(typeid(std::vector<int32_t>))) {
pt_kernel_context_->EmplaceBackAttr(std::move(pten::ScalarArray(
pt_kernel_context->EmplaceBackAttr(std::move(pten::ScalarArray(
BOOST_GET_CONST(std::vector<int32_t>, attr_iter->second))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
@@ -1930,10 +1880,10 @@ void OperatorWithKernel::BuildPtenKernelContext(
} else { // shape is in the input
auto& ins_vector = ctx.inputs.at(attr_names[i]);
if (ins_vector.size() == 1) { // ShapeTensor
pt_kernel_context_->EmplaceBackAttr(std::move(
pt_kernel_context->EmplaceBackAttr(std::move(
experimental::MakePtenScalarArrayFromVar(*ins_vector.front())));
} else { // ShapeTensorList
pt_kernel_context_->EmplaceBackAttr(std::move(
pt_kernel_context->EmplaceBackAttr(std::move(
experimental::MakePtenScalarArrayFromVarList(ins_vector)));
}
}
@@ -1946,11 +1896,11 @@ void OperatorWithKernel::BuildPtenKernelContext(
if (attr_iter != Attrs().end()) { // scalar is in the attribute
auto& attr = Attrs().at(attr_names[i]);
if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
pt_kernel_context_->EmplaceBackAttr(
pt_kernel_context->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::string))) {
pt_kernel_context_->EmplaceBackAttr(
pt_kernel_context->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
@@ -1960,25 +1910,25 @@ void OperatorWithKernel::BuildPtenKernelContext(
}
} else {
auto& ins_vector = ctx.inputs.at(attr_names[i]);
pt_kernel_context_->EmplaceBackAttr(std::move(
pt_kernel_context->EmplaceBackAttr(std::move(
experimental::MakePtenScalarFromVar(*ins_vector.front())));
}

} else {
// TODO(chenweihang): support other attrs later
auto& attr = Attrs().at(attr_names[i]);
if (attr_defs[i].type_index == std::type_index(typeid(int))) {
pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(pten::DataType))) {
auto data_type = pten::TransToPtenDataType(
static_cast<framework::proto::VarType::Type>(
BOOST_GET_CONST(int, attr)));
pt_kernel_context_->EmplaceBackAttr(data_type);
pt_kernel_context->EmplaceBackAttr(data_type);
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int64_t>))) {
if (std::type_index(attr.type()) ==
@@ -1987,7 +1937,7 @@ void OperatorWithKernel::BuildPtenKernelContext(
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
vector_int_attr.end());
pt_kernel_context_->EmplaceBackAttr(vector_int64_attr);
pt_kernel_context->EmplaceBackAttr(vector_int64_attr);
}
// TODO(YuanRisheng) Need support vector<int64_t> attr

@@ -2001,20 +1951,16 @@ void OperatorWithKernel::BuildPtenKernelContext(
}
}

void OperatorWithKernel::WriteBackToOutputs(RuntimeContext* ctx) const {
// auto& input_names = std::get<0>(pt_kernel_signature_->args);
// auto& attr_names = std::get<1>(pt_kernel_signature_->args);
void OperatorWithKernel::WriteBackToOutputs(
RuntimeContext* ctx, pten::KernelContext* pt_kernel_context) const {
auto& output_names = std::get<2>(pt_kernel_signature_->args);

// pt_kernel_context_

for (size_t i = 0; i < output_names.size(); ++i) {
auto& outs_vector = ctx->outputs.at(output_names[i]);

auto& range_pair = pt_kernel_context_->OutputRangeAt(i);
auto pten_outs =
pt_kernel_context_->MutableOutputBetween<pten::DenseTensor>(
range_pair.first, range_pair.second);
auto& range_pair = pt_kernel_context->OutputRangeAt(i);
auto pten_outs = pt_kernel_context->MutableOutputBetween<pten::DenseTensor>(
range_pair.first, range_pair.second);

for (size_t j = 0; j < pten_outs.size(); ++j) {
if (pten_outs[j]) {
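Since the context is rebuilt from scratch on every run, the reuse branches based on `ReMakePtenDenseTensorFromVar` disappear and each input and output is wrapped fresh. A condensed sketch of the simplified input loop; the declarations of `in_def` and `ins_vector` sit in the collapsed part of the hunk and are assumed here:

```cpp
// Condensed sketch of the new per-input handling (names taken from this diff;
// the in_def / ins_vector declarations live in the collapsed context above).
for (size_t i = 0; i < input_names.size(); ++i) {
  // Start and end index of this argument's tensors inside the kernel context.
  size_t start_idx =
      (i == 0 ? 0 : pt_kernel_context->InputRangeAt(i - 1).second);
  size_t end_idx = start_idx + ins_vector.size();

  for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
    // Always build a fresh pten tensor; there is no cached slot to reuse.
    pt_kernel_context->EmplaceBackInputWithoutSetRange(
        experimental::MakePtenTensorBaseFromVar(*ins_vector[offset], in_def));
  }
  pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx), i);
}
```

The output loop follows the same shape, plus the empty-`outs_vector` case (e.g. `matmul_grad` with a missing `dx` or `dy`), which now simply emplaces a `{nullptr}` slot instead of branching on the cached vector size.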
13 changes: 4 additions & 9 deletions paddle/fluid/framework/operator.h
@@ -589,16 +589,14 @@ class OperatorWithKernel : public OperatorBase {
void ChoosePtenKernel(const ExecutionContext& ctx) const;

void BuildPtenKernelContext(const RuntimeContext& ctx,
platform::DeviceContext* dev_ctx) const;
platform::DeviceContext* dev_ctx,
pten::KernelContext* pt_kernel_context) const;

void WriteBackToOutputs(RuntimeContext* ctx) const;
void WriteBackToOutputs(RuntimeContext* ctx,
pten::KernelContext* pt_kernel_context) const;

pten::Kernel* PtenKernel() const { return pt_kernel_.get(); }

pten::KernelContext* PtenKernelContext() const {
return pt_kernel_context_.get();
}

const OpKernelType* kernel_type() const { return kernel_type_.get(); }

private:
@@ -657,9 +655,6 @@
mutable bool run_pten_kernel_ = false;
mutable std::unique_ptr<KernelSignature> pt_kernel_signature_;
mutable std::unique_ptr<pten::Kernel> pt_kernel_;
// In order to reduce the compatibility phase
// performance overhead, temporarily cache KernelContext
mutable std::unique_ptr<pten::KernelContext> pt_kernel_context_;
};

extern bool OpSupportGPU(const std::string& op_type);
15 changes: 5 additions & 10 deletions paddle/fluid/imperative/layer.cc
@@ -409,8 +409,6 @@ void VarBase::_CopyGradientFrom(const VarBase& src) {
}
}

pten::KernelContext OpBase::pt_kernel_context_;

void OpBase::SetType(const std::string& type) {
op_ = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
}
@@ -426,8 +424,7 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
const NameVarMap<VarType>& outs,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const platform::Place& place,
pten::KernelContext* pt_kernel_context) {
const platform::Place& place) {
auto* op_kernel = dynamic_cast<const framework::OperatorWithKernel*>(&op);
PADDLE_ENFORCE_NOT_NULL(
op_kernel, platform::errors::PermissionDenied(
@@ -468,8 +465,8 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
* after the execution of op, but the original input is directly
* overwritten in the previous dynamic graph implemention.
*/
auto prepared_op = PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs,
default_attrs, pt_kernel_context);
auto prepared_op =
PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs, default_attrs);
auto tmp_ins_ptr =
PrepareData<VarType>(*op_kernel, ins, prepared_op.kernel_type());
if (tmp_ins_ptr == nullptr) {
@@ -497,8 +494,7 @@ void OpBase::Run(const framework::OperatorBase& op,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const platform::Place& place) {
OpBaseRunImpl<VarBase>(op, ins, outs, attrs, default_attrs, place,
&pt_kernel_context_);
OpBaseRunImpl<VarBase>(op, ins, outs, attrs, default_attrs, place);
}

void OpBase::Run(const framework::OperatorBase& op,
@@ -507,8 +503,7 @@ void OpBase::Run(const framework::OperatorBase& op,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const platform::Place& place) {
OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, default_attrs, place,
&pt_kernel_context_);
OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, default_attrs, place);
}

void ClearNoNeedBufferInputs(OpBase* op) {
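On the imperative (dygraph) side the shared static `OpBase::pt_kernel_context_` is deleted, so neither `OpBaseRunImpl` nor `PreparedOp::Prepare` takes a context pointer any longer. As a sketch, the `OpBase::Run` overloads now reduce to plain forwarding calls:

```cpp
// Sketch of OpBase::Run after this change: no KernelContext threaded through.
void OpBase::Run(const framework::OperatorBase& op,
                 const NameVarMap<VarBase>& ins,
                 const NameVarMap<VarBase>& outs,
                 const framework::AttributeMap& attrs,
                 const framework::AttributeMap& default_attrs,
                 const platform::Place& place) {
  OpBaseRunImpl<VarBase>(op, ins, outs, attrs, default_attrs, place);
}
```

The `VariableWrapper` overload changes the same way; the per-run context is presumably built further down the prepared-op path, in files of this commit not shown above.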