Skip to content

Commit

Permalink
change all Pt to Pten
Browse files Browse the repository at this point in the history
  • Loading branch information
chenwhql committed Oct 20, 2021
1 parent ab8db2d commit d3674e9
Show file tree
Hide file tree
Showing 11 changed files with 75 additions and 70 deletions.
22 changes: 11 additions & 11 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1146,7 +1146,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
if (FLAGS_run_pt_kernel &&
pten::KernelFactory::Instance().ContainsKernel(type_.c_str())) {
if (pt_kernel_signature_.get() == nullptr || pt_kernel_.get() == nullptr) {
ChoosePtKernel(exe_ctx);
ChoosePtenKernel(exe_ctx);
}
run_pt_kernel_ = pt_kernel_->IsValid();
}
Expand Down Expand Up @@ -1192,7 +1192,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
platform::RecordEvent record_event("compute",
platform::EventRole::kInnerOp);
if (run_pt_kernel_) {
auto op_kernel_ctx = BuildPtKernelContext(*runtime_ctx, *dev_ctx);
auto op_kernel_ctx = BuildPtenKernelContext(*runtime_ctx, *dev_ctx);
(*pt_kernel_)(&op_kernel_ctx);
} else {
(*kernel_func_)(
Expand Down Expand Up @@ -1282,26 +1282,26 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
return expected_kernel_key;
}

void OperatorWithKernel::ChoosePtKernel(const ExecutionContext& ctx) const {
void OperatorWithKernel::ChoosePtenKernel(const ExecutionContext& ctx) const {
pt_kernel_signature_.reset(
new KernelSignature(this->GetExpectedPtKernelArgs(ctx)));
new KernelSignature(this->GetExpectedPtenKernelArgs(ctx)));

VLOG(1) << KernelSignatureToString(*pt_kernel_signature_.get());

kernel_type_.reset(new OpKernelType(InnerGetExpectedKernelType(ctx)));

auto pt_kernel_name = pten::KernelName(pt_kernel_signature_->first);
auto pt_kernel_key = TransOpKernelTypeToPtKernelKey(*kernel_type_.get());
auto pt_kernel_key = TransOpKernelTypeToPtenKernelKey(*kernel_type_.get());
pt_kernel_.reset(
new pten::Kernel(pten::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_kernel_key)));

if (pt_kernel_->IsValid()) {
VLOG(1) << "Static mode ChoosePtKernel - kernel name: " << pt_kernel_name
VLOG(1) << "Static mode ChoosePtenKernel - kernel name: " << pt_kernel_name
<< " | kernel key: " << pt_kernel_key
<< " | kernel: " << *pt_kernel_;
} else {
VLOG(1) << "Static mode ChoosePtKernel - kernel `" << pt_kernel_name
VLOG(1) << "Static mode ChoosePtenKernel - kernel `" << pt_kernel_name
<< "` not found.";
}
}
Expand Down Expand Up @@ -1774,7 +1774,7 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(
tensor.layout());
}

KernelSignature OperatorWithKernel::GetExpectedPtKernelArgs(
KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
const ExecutionContext& ctx) const {
if (KernelSignatureMap::Instance().Has(Type())) {
return *(KernelSignatureMap::Instance().GetNullable(Type()));
Expand All @@ -1786,7 +1786,7 @@ KernelSignature OperatorWithKernel::GetExpectedPtKernelArgs(
}
}

pten::KernelContext OperatorWithKernel::BuildPtKernelContext(
pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
const RuntimeContext& ctx, const platform::DeviceContext& dev_ctx) const {
VLOG(1) << RuntimeContextDebugString(ctx);

Expand Down Expand Up @@ -1834,7 +1834,7 @@ pten::KernelContext OperatorWithKernel::BuildPtKernelContext(
std::vector<std::shared_ptr<pten::TensorBase>> tmp_inputs;

for (auto var : ins_vector) {
auto pt_in = framework::InputVariableToPtTensor(*var, in_def);
auto pt_in = framework::InputVariableToPtenTensor(*var, in_def);
tmp_inputs.emplace_back(pt_in);
}
op_kernel_ctx.EmplaceBackInputs(tmp_inputs);
Expand All @@ -1846,7 +1846,7 @@ pten::KernelContext OperatorWithKernel::BuildPtKernelContext(

std::vector<std::shared_ptr<pten::TensorBase>> tmp_outputs;
for (auto var : outs_vector) {
auto pt_out = framework::OutputVariableToPtTensor(var, out_def);
auto pt_out = framework::OutputVariableToPtenTensor(var, out_def);
tmp_outputs.emplace_back(pt_out);
}
op_kernel_ctx.EmplaceBackOutputs(tmp_outputs);
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -537,9 +537,9 @@ class OperatorWithKernel : public OperatorBase {
* output arguments registered in the original OpMaker do not match in some
* cases, so we use map to record the arguments required by the kernel.
* When selecting Kernel during Op execution, select the arguments of the
* original Op according to the GetExpectedPtKernelArgs returned arguments.
* original Op according to the GetExpectedPtenKernelArgs returned arguments.
*/
virtual KernelSignature GetExpectedPtKernelArgs(
virtual KernelSignature GetExpectedPtenKernelArgs(
const ExecutionContext& ctx) const;

private:
Expand Down Expand Up @@ -583,9 +583,9 @@ class OperatorWithKernel : public OperatorBase {
const std::string& name) const;

/* member functions for adapting to pten lib */
void ChoosePtKernel(const ExecutionContext& ctx) const;
void ChoosePtenKernel(const ExecutionContext& ctx) const;

pten::KernelContext BuildPtKernelContext(
pten::KernelContext BuildPtenKernelContext(
const RuntimeContext& ctx, const platform::DeviceContext& dev_ctx) const;

protected:
Expand Down
41 changes: 22 additions & 19 deletions paddle/fluid/framework/pten_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,17 @@ std::shared_ptr<pten::DenseTensor> MakeTensorImpl<pten::DenseTensor>(
const LoDTensor& tensor, const platform::Place& place,
proto::VarType::Type type) {
return MakeTensorImpl<pten::DenseTensor, LoDTensor>(
tensor, pten::TransToPtBackend(place), pten::TransToPtDataType(type),
pten::TransToPtDataLayout(tensor.layout()));
tensor, pten::TransToPtenBackend(place), pten::TransToPtenDataType(type),
pten::TransToPtenDataLayout(tensor.layout()));
}

template <>
std::shared_ptr<pten::DenseTensor> MakeTensorImpl<pten::DenseTensor>(
const Tensor& tensor, const platform::Place& place,
proto::VarType::Type type) {
return MakeTensorImpl<pten::DenseTensor, Tensor>(
tensor, pten::TransToPtBackend(place), pten::TransToPtDataType(type),
pten::TransToPtDataLayout(tensor.layout()));
tensor, pten::TransToPtenBackend(place), pten::TransToPtenDataType(type),
pten::TransToPtenDataLayout(tensor.layout()));
}

template <>
Expand All @@ -93,7 +93,7 @@ void ShareTensorImpl<pten::DenseTensor>(pten::DenseTensor* tensor_impl,
pten::TransToProtoVarType(tensor_impl->data_type()));
}

std::shared_ptr<pten::TensorBase> InputVariableToPtTensor(
std::shared_ptr<pten::TensorBase> InputVariableToPtenTensor(
const framework::Variable& variable, const pten::TensorArgDef& arg_def) {
auto expected_place = pten::TransToFluidPlace(arg_def.backend);

Expand Down Expand Up @@ -138,7 +138,7 @@ std::shared_ptr<pten::TensorBase> InputVariableToPtTensor(
return nullptr;
}

std::shared_ptr<pten::TensorBase> OutputVariableToPtTensor(
std::shared_ptr<pten::TensorBase> OutputVariableToPtenTensor(
framework::Variable* variable, const pten::TensorArgDef& arg_def) {
// mutable_data before run kernel, to avoid share output form
// KernelContext to original tensor
Expand Down Expand Up @@ -170,7 +170,8 @@ std::shared_ptr<pten::TensorBase> OutputVariableToPtTensor(
return nullptr;
}

OpKernelType TransPtKernelKeyToOpKernelType(const pten::KernelKey& kernel_key) {
OpKernelType TransPtenKernelKeyToOpKernelType(
const pten::KernelKey& kernel_key) {
proto::VarType::Type data_type =
pten::TransToProtoVarType(kernel_key.dtype());
platform::Place place = pten::TransToFluidPlace(kernel_key.backend());
Expand All @@ -187,9 +188,9 @@ OpKernelType TransPtKernelKeyToOpKernelType(const pten::KernelKey& kernel_key) {
return OpKernelType(data_type, place, data_layout, library_type);
}

pten::KernelKey TransOpKernelTypeToPtKernelKey(
pten::KernelKey TransOpKernelTypeToPtenKernelKey(
const OpKernelType& kernel_type) {
pten::Backend backend = pten::TransToPtBackend(kernel_type.place_);
pten::Backend backend = pten::TransToPtenBackend(kernel_type.place_);
if (kernel_type.library_type_ == LibraryType::kMKLDNN) {
backend = pten::Backend::MKLDNN;
} else if (kernel_type.library_type_ == LibraryType::kCUDNN) {
Expand All @@ -198,9 +199,9 @@ pten::KernelKey TransOpKernelTypeToPtKernelKey(
// do
}
paddle::experimental::DataLayout layout =
pten::TransToPtDataLayout(kernel_type.data_layout_);
pten::TransToPtenDataLayout(kernel_type.data_layout_);
paddle::experimental::DataType dtype =
pten::TransToPtDataType(kernel_type.data_type_);
pten::TransToPtenDataType(kernel_type.data_type_);
return pten::KernelKey(backend, layout, dtype);
}

Expand All @@ -215,16 +216,17 @@ KernelArgsNameMakerByOpProto::GetInputArgsNames() {
auto& in = op_proto_->inputs()[i];
auto& in_name = in.name();
if ((in.has_extra() && in.extra()) || (in.has_quant() && in.quant())) {
VLOG(1) << "Parse PtKernel input: skip extra & quant input - " << in_name;
VLOG(1) << "Parse PtenKernel input: skip extra & quant input - "
<< in_name;
continue;
}
// If contains dispensable input, we should override the
// GetExpectedPtKernelArgs method self
// GetExpectedPtenKernelArgs method self
if (in.has_dispensable() && in.dispensable()) {
VLOG(1) << "Parse PtKernel input: skip dispensable input - " << in_name;
VLOG(1) << "Parse PtenKernel input: skip dispensable input - " << in_name;
continue;
}
VLOG(1) << "Parse PtKernel input: " << in_name;
VLOG(1) << "Parse PtenKernel input: " << in_name;
input_names_.emplace_back(in_name);
}
return input_names_;
Expand All @@ -236,7 +238,7 @@ KernelArgsNameMakerByOpProto::GetOutputArgsNames() {
auto& out = op_proto_->outputs()[i];
auto& out_name = out.name();
// TODO(chenweihang): outputs also need skip some cases
VLOG(1) << "Parse PtKernel output: " << out_name;
VLOG(1) << "Parse PtenKernel output: " << out_name;
output_names_.emplace_back(out_name);
}
return output_names_;
Expand All @@ -250,16 +252,17 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
if (attr_name == "use_mkldnn" || attr_name == "op_role" ||
attr_name == "op_role_var" || attr_name == "op_namescope" ||
attr_name == "op_callstack" || attr_name == "op_device") {
VLOG(1) << "Parse PtKernel attribute: skip needless attr - " << attr_name;
VLOG(1) << "Parse PtenKernel attribute: skip needless attr - "
<< attr_name;
continue;
}
if ((attr.has_extra() && attr.extra()) ||
(attr.has_quant() && attr.quant())) {
VLOG(1) << "Parse PtKernel attribute: skip extra & quant attr - "
VLOG(1) << "Parse PtenKernel attribute: skip extra & quant attr - "
<< attr_name;
continue;
}
VLOG(1) << "Parse PtKernel attribute: " << attr_name;
VLOG(1) << "Parse PtenKernel attribute: " << attr_name;
attr_names_.emplace_back(attr_name);
}

Expand Down
38 changes: 20 additions & 18 deletions paddle/fluid/framework/pten_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,37 +33,39 @@ namespace framework {

/* tensor translate */

template <typename PtTensorImplT, typename VariableT>
std::shared_ptr<PtTensorImplT> MakeTensorImpl(
template <typename PtenTensorImplT, typename VariableT>
std::shared_ptr<PtenTensorImplT> MakeTensorImpl(
const VariableT& tensor, pten::Backend backend,
paddle::experimental::DataType dtype,
paddle::experimental::DataLayout layout);

template <typename PtTensorImplT>
std::shared_ptr<PtTensorImplT> MakeTensorImpl(const LoDTensor& tensor,
const platform::Place& place,
proto::VarType::Type type);
template <typename PtenTensorImplT>
std::shared_ptr<PtenTensorImplT> MakeTensorImpl(const LoDTensor& tensor,
const platform::Place& place,
proto::VarType::Type type);

template <typename PtTensorImplT>
std::shared_ptr<PtTensorImplT> MakeTensorImpl(const Tensor& tensor,
const platform::Place& place,
proto::VarType::Type type);
template <typename PtenTensorImplT>
std::shared_ptr<PtenTensorImplT> MakeTensorImpl(const Tensor& tensor,
const platform::Place& place,
proto::VarType::Type type);

template <typename PtTensorImplT>
void ShareTensorImpl(PtTensorImplT* tensor_impl, LoDTensor* out);
template <typename PtenTensorImplT>
void ShareTensorImpl(PtenTensorImplT* tensor_impl, LoDTensor* out);

template <typename PtTensorImplT>
void ShareTensorImpl(PtTensorImplT* tensor_impl, Tensor* out);
template <typename PtenTensorImplT>
void ShareTensorImpl(PtenTensorImplT* tensor_impl, Tensor* out);

std::shared_ptr<pten::TensorBase> InputVariableToPtTensor(
std::shared_ptr<pten::TensorBase> InputVariableToPtenTensor(
const framework::Variable& variable, const pten::TensorArgDef& arg_def);
std::shared_ptr<pten::TensorBase> OutputVariableToPtTensor(
std::shared_ptr<pten::TensorBase> OutputVariableToPtenTensor(
framework::Variable* variable, const pten::TensorArgDef& arg_def);

/* Kernel Key translate */

OpKernelType TransPtKernelKeyToOpKernelType(const pten::KernelKey& kernel_key);
pten::KernelKey TransOpKernelTypeToPtKernelKey(const OpKernelType& kernel_type);
OpKernelType TransPtenKernelKeyToOpKernelType(
const pten::KernelKey& kernel_key);
pten::KernelKey TransOpKernelTypeToPtenKernelKey(
const OpKernelType& kernel_type);

/* Kernel Args parse */

Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/pten_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ TEST(TcmptUtils, MakeTensor) {
ASSERT_EQ(dense_x->data_type(), pten::DataType::FLOAT32);
}

TEST(TcmptUtils, VarToPtTensor) {
TEST(TcmptUtils, VarToPtenTensor) {
// 1. create Variable
Variable v;
auto selected_rows = v.GetMutable<SelectedRows>();
Expand All @@ -57,7 +57,7 @@ TEST(TcmptUtils, VarToPtTensor) {
auto tensor_def = pten::TensorArgDef(expect_backend, pten::DataLayout::NCHW,
pten::DataType::INT32);
// 2. test API
auto tensor_x = InputVariableToPtTensor(v, tensor_def);
auto tensor_x = InputVariableToPtenTensor(v, tensor_def);
// 3. check result
ASSERT_EQ(tensor_x->backend(), expect_backend);
ASSERT_EQ(tensor_x->data_type(), pten::DataType::INT32);
Expand Down
14 changes: 7 additions & 7 deletions paddle/fluid/imperative/prepared_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,12 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,

if (FLAGS_run_pt_kernel &&
pten::KernelFactory::Instance().ContainsKernel(op.Type().c_str())) {
auto pt_kernel_signature = op.GetExpectedPtKernelArgs(dygraph_exe_ctx);
auto pt_kernel_signature = op.GetExpectedPtenKernelArgs(dygraph_exe_ctx);

VLOG(1) << framework::KernelSignatureToString(pt_kernel_signature);

auto pt_kernel_name = pten::KernelName(pt_kernel_signature.first);
auto pt_kernel_key = TransOpKernelTypeToPtKernelKey(expected_kernel_key);
auto pt_kernel_key = TransOpKernelTypeToPtenKernelKey(expected_kernel_key);
auto pt_kernel = pten::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_kernel_key);

Expand All @@ -171,7 +171,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature,
pt_kernel, dev_ctx);
} else {
VLOG(1) << "Dynamic mode ChoosePtKernel - kernel `" << pt_kernel_name
VLOG(1) << "Dynamic mode ChoosePtenKernel - kernel `" << pt_kernel_name
<< "` not found.";
}
}
Expand Down Expand Up @@ -243,7 +243,7 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
}

template <typename VarType>
static pten::KernelContext BuildDygraphPtKernelContext(
static pten::KernelContext BuildDygraphPtenKernelContext(
const framework::KernelSignature& pt_kernel_signature,
const pten::Kernel& pt_kernel, const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
Expand Down Expand Up @@ -292,7 +292,7 @@ static pten::KernelContext BuildDygraphPtKernelContext(
for (auto var : ins_vector) {
const auto& variable = var->Var();

auto pt_in = framework::InputVariableToPtTensor(variable, in_def);
auto pt_in = framework::InputVariableToPtenTensor(variable, in_def);
tmp_inputs.emplace_back(pt_in);
}
op_kernel_ctx.EmplaceBackInputs(tmp_inputs);
Expand All @@ -306,7 +306,7 @@ static pten::KernelContext BuildDygraphPtKernelContext(
for (auto var : outs_vector) {
auto* variable = var->MutableVar();

auto pt_out = framework::OutputVariableToPtTensor(variable, out_def);
auto pt_out = framework::OutputVariableToPtenTensor(variable, out_def);
tmp_outputs.emplace_back(pt_out);
}
op_kernel_ctx.EmplaceBackOutputs(tmp_outputs);
Expand Down Expand Up @@ -401,7 +401,7 @@ static void PreparedOpRunPtImpl(
static_cast<const framework::OperatorWithKernel&>(op).InferShape(
&infer_shape_ctx);

auto op_kernel_ctx = BuildDygraphPtKernelContext<VarType>(
auto op_kernel_ctx = BuildDygraphPtenKernelContext<VarType>(
pt_kernel_signature, pt_kernel, ins, outs, attrs, default_attrs,
*dev_ctx);

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/fill_any_like_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class FillAnyLikeOp : public framework::OperatorWithKernel {
tensor.layout());
}

framework::KernelSignature GetExpectedPtKernelArgs(
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
return std::make_pair(
"fill_any_like",
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/scale_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class ScaleOp : public framework::OperatorWithKernel {
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}

framework::KernelSignature GetExpectedPtKernelArgs(
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
if (ctx.HasInput("ScaleTensor")) {
return std::make_pair(
Expand Down
Loading

0 comments on commit d3674e9

Please sign in to comment.