Commit fb224ab

polish kernel factory and kernel registry

chenwhql committed Oct 21, 2021
1 parent 76a588e commit fb224ab
Showing 6 changed files with 54 additions and 71 deletions.
25 changes: 4 additions & 21 deletions paddle/fluid/framework/operator.cc
@@ -1080,20 +1080,6 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
this->InferShape(&infer_shape_ctx);
}

static std::string RuntimeContextDebugString(const RuntimeContext& ctx) {
std::stringstream ss;
ss << "RuntimeContext(Inputs: ";
for (auto& var_pair : ctx.inputs) {
ss << var_pair.first << ", ";
}
ss << "Outputs: ";
for (auto& var_pair : ctx.outputs) {
ss << var_pair.first << ", ";
}
ss << ")";
return ss.str();
}

void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const {
// To reduce the elapsed time of HasAttr, we use bool variable to record the
@@ -1144,7 +1130,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// and RCOM backend, the XPU, NPU and MKLDNN will be supported in the second
// phase
if (FLAGS_run_pt_kernel &&
pten::KernelFactory::Instance().ContainsKernel(type_.c_str())) {
pten::KernelFactory::Instance().HasCompatiblePtenKernel(type_)) {
if (pt_kernel_signature_.get() == nullptr || pt_kernel_.get() == nullptr) {
ChoosePtenKernel(exe_ctx);
}
@@ -1651,10 +1637,9 @@ void OperatorWithKernel::ParseInputDataType(
if (t != nullptr) {
PADDLE_ENFORCE_EQ(
t->IsInitialized(), true,
platform::errors::InvalidArgument(
"The Tensor in the %s Op's Input Variable %s(%s) is "
"not initialized.",
Type(), name, Inputs().at(name).at(i)));
platform::errors::InvalidArgument("The %s Op's Input Variable `%s` "
"contains uninitialized Tensor.",
Type(), name));
proto::VarType::Type tmp = t->type();
PADDLE_ENFORCE(tmp == *data_type || *data_type == default_data_type,
platform::errors::InvalidArgument(
@@ -1789,8 +1774,6 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(

pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
const RuntimeContext& ctx, const platform::DeviceContext& dev_ctx) const {
VLOG(1) << RuntimeContextDebugString(ctx);

// TODO(chenweihang): now only work for very simple case,
// many cases need to be deal with later:
// 1. the input and output are not tensor
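The hunks above swap the per-call lookup `ContainsKernel(type_.c_str())` for `HasCompatiblePtenKernel(type_)`, a plain membership test against a set of op types filled once at kernel-registration time (see the kernel_factory.h and kernel_registry.h changes below). What follows is a minimal, self-contained sketch of the resulting dispatch branch, not the real fluid dispatch path: FLAGS_run_pt_kernel is modeled as a plain bool rather than a gflag, and KernelFactory is trimmed to the two members this commit touches.

// Sketch of the dispatch decision in OperatorWithKernel::RunImpl after
// this change. Simplified stand-ins; compiles with any C++11 compiler.
#include <iostream>
#include <string>
#include <unordered_set>

namespace pten {

class KernelFactory {
 public:
  static KernelFactory& Instance() {
    static KernelFactory g_op_kernel_factory;
    return g_op_kernel_factory;
  }

  void InsertCompatibleOpType(const std::string& op_type) {
    compatible_op_types_.insert(op_type);
  }

  // O(1) lookup on the fluid op type; the removed ContainsKernel built a
  // temporary KernelName from a const char* on every call instead.
  bool HasCompatiblePtenKernel(const std::string& op_type) const {
    return compatible_op_types_.count(op_type) > 0;
  }

 private:
  std::unordered_set<std::string> compatible_op_types_;
};

}  // namespace pten

bool FLAGS_run_pt_kernel = true;  // stand-in for the real flag

void RunImpl(const std::string& op_type) {
  if (FLAGS_run_pt_kernel &&
      pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type)) {
    std::cout << op_type << " -> pten kernel path\n";
  } else {
    std::cout << op_type << " -> fluid kernel path\n";
  }
}

int main() {
  pten::KernelFactory::Instance().InsertCompatibleOpType("scale");
  RunImpl("scale");   // pten kernel path
  RunImpl("conv2d");  // fluid kernel path
  return 0;
}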
2 changes: 1 addition & 1 deletion paddle/fluid/imperative/prepared_operator.cc
@@ -153,7 +153,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;

if (FLAGS_run_pt_kernel &&
pten::KernelFactory::Instance().ContainsKernel(op.Type().c_str())) {
pten::KernelFactory::Instance().HasCompatiblePtenKernel(op.Type())) {
auto pt_kernel_signature = op.GetExpectedPtenKernelArgs(dygraph_exe_ctx);

VLOG(1) << framework::KernelSignatureToString(pt_kernel_signature);
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/op_function_generator.cc
@@ -557,7 +557,7 @@ GenerateOpFunctions() {
// since only OperatorWithKernel can run in dygraph mode.
// if the pten lib contains op kernel, we still generate ops method
if (!all_kernels.count(op_type) &&
!pten::KernelFactory::Instance().ContainsKernel(op_type.c_str())) {
!pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type)) {
continue;
}

18 changes: 13 additions & 5 deletions paddle/pten/core/kernel_factory.cc
@@ -19,16 +19,24 @@

namespace pten {

uint32_t KernelKey::Hash::operator()(const KernelKey& key) const {
uint32_t hash_value = 0;
// |----31-20------|---19-12---|---11-8----|---7-0---|
// | For extension | DataType | DataLayout | Backend |
hash_value |= static_cast<uint8_t>(key.backend());
hash_value |=
(static_cast<uint8_t>(key.layout()) << KernelKey::kBackendBitLength);
hash_value |=
(static_cast<uint16_t>(key.dtype())
<< (KernelKey::kBackendBitLength + KernelKey::kDataTypeBitLength));
return hash_value;
}

KernelFactory& KernelFactory::Instance() {
static KernelFactory g_op_kernel_factory;
return g_op_kernel_factory;
}

bool KernelFactory::ContainsKernel(const char* kernel_name) const {
auto iter = kernels_.find(KernelName(kernel_name, ""));
return (iter != kernels_.end());
}

Kernel KernelFactory::SelectKernel(const KernelName& kernel_name,
const KernelKey& kernel_key) const {
auto iter = kernels_.find(kernel_name);
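The KernelKey hash moved here from the header, where it previously lived inline in the constructor (see kernel_factory.h below). A standalone illustration of the bit packing follows; the three enums are hypothetical stand-ins for pten's Backend/DataLayout/DataType, kDataTypeBitLength is assumed to be 8 (its definition sits in a collapsed part of the header hunk), and the shifts are copied from the function above.

// Standalone demo of the KernelKey hash layout; stand-in enums, not pten's.
#include <cstdint>
#include <cstdio>

enum class Backend : uint8_t { UNDEFINED = 0, CPU = 1, CUDA = 2 };
enum class DataLayout : uint8_t { UNDEFINED = 0, NCHW = 1, NHWC = 2 };
enum class DataType : uint8_t { UNDEFINED = 0, FLOAT32 = 1, INT64 = 2 };

constexpr int kBackendBitLength = 8;
constexpr int kDataTypeBitLength = 8;  // assumed; defined in the real header

// |----31-20------|---19-12---|---11-8----|---7-0---|
// | For extension | DataType  | DataLayout | Backend |
uint32_t HashKernelKey(Backend backend, DataLayout layout, DataType dtype) {
  uint32_t hash_value = 0;
  hash_value |= static_cast<uint8_t>(backend);
  hash_value |= static_cast<uint8_t>(layout) << kBackendBitLength;
  hash_value |= static_cast<uint16_t>(dtype)
                << (kBackendBitLength + kDataTypeBitLength);
  return hash_value;
}

int main() {
  // CPU = 0x01, NCHW << 8 = 0x0100, FLOAT32 << 16 = 0x010000
  printf("0x%08x\n",
         HashKernelKey(Backend::CPU, DataLayout::NCHW, DataType::FLOAT32));
  // prints 0x00010101
  return 0;
}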
77 changes: 34 additions & 43 deletions paddle/pten/core/kernel_factory.h
@@ -17,6 +17,7 @@
#include <ostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "paddle/pten/common/backend.h"
@@ -37,10 +38,10 @@ using DataLayout = paddle::experimental::DataLayout;
/**
* [ Naming considerations ]
*
* The tensor Compute library contains many kernels, and the computation
* The tensor operation library contains many kernels, and the computation
* in each specific scenario is represented by a kernel.
*
* We directly named it `Kernel` instead of `OpKernel`, the tensor Compute
* We directly named it `Kernel` instead of `OpKernel`, the tensor operation
* library here and fluid are independent, preventing developers from
* misunderstanding the relationship between the two concepts.
*/
@@ -52,10 +53,7 @@ using KernelFn = void (*)(KernelContext* ctx);
class KernelName final {
public:
KernelName(std::string name, std::string overload_name)
: name_(std::move(name)), overload_name_(std::move(overload_name)) {
hash_value_ = std::hash<std::string>()(name_) ^
(std::hash<std::string>()(overload_name_) << 1);
}
: name_(std::move(name)), overload_name_(std::move(overload_name)) {}

KernelName(const std::string& kernel_name) {
ParseNameAndOverloadNameFromString(kernel_name);
@@ -68,24 +66,26 @@ class KernelName final {

const std::string& name() const { return name_; }
const std::string& overload_name() const { return overload_name_; }
size_t hash_value() const { return hash_value_; }

struct Hash {
size_t operator()(const KernelName& kernel_name) const {
return kernel_name.hash_value();
return std::hash<std::string>()(kernel_name.name()) ^
(std::hash<std::string>()(kernel_name.overload_name()) << 1);
}
};

size_t hash_value() const { return Hash()(*this); }

bool operator<(const KernelName& kernel_name) const {
return hash_value_ < kernel_name.hash_value();
return hash_value() < kernel_name.hash_value();
}

bool operator==(const KernelName& kernel_name) const {
return hash_value_ == kernel_name.hash_value();
return hash_value() == kernel_name.hash_value();
}

bool operator!=(const KernelName& kernel_name) const {
return hash_value_ != kernel_name.hash_value();
return hash_value() != kernel_name.hash_value();
}

private:
@@ -98,57 +98,45 @@ class KernelName final {
name_ = kernel_name.substr(0, pos);
overload_name_ = kernel_name.substr(pos + 1, kernel_name.size());
}
hash_value_ = std::hash<std::string>()(name_) ^
(std::hash<std::string>()(overload_name_) << 1);
}

// The members cannot be modified except by constructing,
// because the hash value needs to be recalculated
// TODO(chenweihang): use string_view later?
// TODO(chenweihang): use string_view to improve performance later
std::string name_;
std::string overload_name_;
// Avoid calculating Hash value at runtime
size_t hash_value_;
};

class KernelKey {
public:
KernelKey() = default;

KernelKey(Backend backend, DataLayout layout, DataType dtype)
: backend_(backend), layout_(layout), dtype_(dtype) {
// |----31-20------|---19-12---|---11-8----|---7-0---|
// | For extension | DataType | DataLayout | Backend |

hash_value_ = 0;
hash_value_ |= static_cast<uint8_t>(backend_);
hash_value_ |= (static_cast<uint8_t>(layout_) << kBackendBitLength);
hash_value_ |= (static_cast<uint16_t>(dtype_)
<< (kBackendBitLength + kDataTypeBitLength));
}
: backend_(backend), layout_(layout), dtype_(dtype) {}

Backend backend() const { return backend_; }
DataLayout layout() const { return layout_; }
DataType dtype() const { return dtype_; }

uint32_t hash_value() const { return hash_value_; }
struct Hash {
// Note: Now the number of bits we need does not exceed 32 bits, so there is
// no need to use 64 bits. If needed in the future, it can be expanded,
// but now we don’t over-design.
uint32_t operator()(const KernelKey& key) const;
};

uint32_t hash_value() const { return Hash()(*this); }

bool operator<(const KernelKey& key) const {
return hash_value_ < key.hash_value();
return hash_value() < key.hash_value();
}

bool operator==(const KernelKey& key) const {
return hash_value_ == key.hash_value();
return hash_value() == key.hash_value();
}

bool operator!=(const KernelKey& key) const {
return hash_value_ != key.hash_value();
return hash_value() != key.hash_value();
}

struct Hash {
uint32_t operator()(const KernelKey& key) const { return key.hash_value(); }
};

private:
// In total should be smaller than 32.
constexpr static int kBackendBitLength = 8;
@@ -158,12 +146,6 @@ class KernelKey {
Backend backend_{Backend::UNDEFINED};
DataLayout layout_{DataLayout::UNDEFINED};
DataType dtype_{DataType::UNDEFINED};

// Avoid calculating Hash value at runtime.
// Note: Now the number of bits we need does not exceed 32 bits, so there is
// no need to use 64 bits. If needed in the future, it can be expanded,
// but now we don’t over-design.
uint32_t hash_value_;
};

// TODO(chenweihang): how to deal with vector<Param>?
@@ -282,7 +264,13 @@ class KernelFactory {

KernelMap& kernels() { return kernels_; }

bool ContainsKernel(const char* name) const;
void InsertCompatibleOpType(const std::string& op_type) {
compatible_op_types_.insert(op_type);
}

bool HasCompatiblePtenKernel(const std::string& op_type) const {
return compatible_op_types_.count(op_type) > 0;
}

const Kernel& SelectKernelOrThrowError(const KernelName& kernel_name,
const KernelKey& kernel_key) const;
@@ -299,6 +287,9 @@ class KernelFactory {
KernelFactory() = default;

KernelMap kernels_;
// Used for compatibility with the original execution system, to quickly
// confirm whether the new kernel can be called
std::unordered_set<std::string> compatible_op_types_;
};

/** operator << overload **/
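The common thread in this header is that both KernelName and KernelKey lose their cached hash_value_ members: the constructors now only store their fields, and hash_value(), operator<, ==, and != all delegate to the Hash functor on demand. Below is a condensed sketch of the KernelName half, assuming a trimmed class for illustration (string parsing and the KernelKey/KernelMap plumbing omitted).

// Trimmed KernelName showing on-demand hashing; illustration only.
#include <cassert>
#include <functional>
#include <string>
#include <unordered_map>

class KernelName final {
 public:
  KernelName(std::string name, std::string overload_name)
      : name_(std::move(name)), overload_name_(std::move(overload_name)) {}

  const std::string& name() const { return name_; }
  const std::string& overload_name() const { return overload_name_; }

  struct Hash {
    size_t operator()(const KernelName& kernel_name) const {
      // XOR of the two string hashes; the << 1 keeps ("a", "b") and
      // ("b", "a") from hashing identically.
      return std::hash<std::string>()(kernel_name.name()) ^
             (std::hash<std::string>()(kernel_name.overload_name()) << 1);
    }
  };

  size_t hash_value() const { return Hash()(*this); }

  bool operator==(const KernelName& other) const {
    return hash_value() == other.hash_value();
  }

 private:
  std::string name_;
  std::string overload_name_;
};

int main() {
  // Usable directly as an unordered_map key, as the real KernelMap does.
  std::unordered_map<KernelName, int, KernelName::Hash> kernels;
  kernels[KernelName("scale", "")] = 1;
  kernels[KernelName("scale", "host")] = 2;
  assert(kernels.size() == 2);
  assert(kernels.at(KernelName("scale", "host")) == 2);
  return 0;
}

The trade-off is the usual space-versus-time one: the hash is recomputed on every comparison, but the objects shrink and can no longer carry a stale cached hash.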
1 change: 1 addition & 0 deletions paddle/pten/core/kernel_registry.h
@@ -149,6 +149,7 @@ struct KernelRegistrar {
args_parse_fn(kernel_key, kernel.mutable_args_def());
args_def_fn(&kernel);

KernelFactory::Instance().InsertCompatibleOpType(kernel_name.name());
KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel;
}
};
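Note that the registrar records kernel_name.name(), the base name without the overload suffix, so all overload variants of a kernel collapse onto one fluid op type entry, and the set is populated during static registration, before any operator runs. A tiny sketch of that collapsing behavior (the "scale"/"scale.host" names are illustrative, not from this commit):

// Overload variants share one compatible-op-type entry; illustration only.
#include <cassert>
#include <string>
#include <unordered_set>

int main() {
  std::unordered_set<std::string> compatible_op_types;

  // Registering "scale" and an overload "scale.host" both insert the
  // base name, mirroring InsertCompatibleOpType(kernel_name.name()).
  compatible_op_types.insert("scale");  // from kernel "scale"
  compatible_op_types.insert("scale");  // from kernel "scale.host"

  assert(compatible_op_types.size() == 1);
  assert(compatible_op_types.count("scale") == 1);
  return 0;
}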

1 comment on commit fb224ab


@paddle-bot-old (bot) commented on fb224ab Oct 21, 2021


🕵️ CI failures summary

🔍 PR: #34425 Commit ID: fb224ab contains failed CI.

🔹 Failed: PR-CI-APPROVAL

approve_failed
2021-10-21 22:19:00 Saving to: “bk.txt”
2021-10-21 22:19:00 0K 100% 3.16M=0s
2021-10-21 22:19:00 2021-10-21 22:19:00 (3.16 MB/s) - “bk.txt” saved [5/5])
2021-10-21 22:19:08 ****************
2021-10-21 22:19:08 0. You must have one RD (lanxianghit (Recommend), phlrain or luotao1) approval for changing the FLAGS, which manages the environment variables.
2021-10-21 22:19:08 1. You must have Dianhai approval for change 20+ files or add than 1000+ lines of content.
2021-10-21 22:19:08 2. You must have one RD (XiaoguangHu01,chenwhql,zhiqiu,Xreki,luotao1) approval for paddle/fluid/framework/operator.h, which manages the underlying code for fluid.
2021-10-21 22:19:08 3. You must have one RD (zhiqiu (Recommend) , phlrain) approval for the changes of paddle/fluid/pybind/op_function_generator.cc, which manages the logic of automatic generating op functions for dygraph.
2021-10-21 22:19:08 4. You must have one RD (XiaoguangHu01,chenwhql,zhiqiu,Xreki,luotao1) approval for the usage of const_cast.
2021-10-21 22:19:08 5. You must have one RD (Avin0323(Recommend) or zhouwei25 or wanghuancoder or luotao1) approval for modifying unity_build_rule.cmake which the rules of Unity Build.
2021-10-21 22:19:08 There are 6 approved errors.
2021-10-21 22:19:08 ****************
2021-10-21 22:19:08 + EXCODE=6
2021-10-21 22:19:08 + echo 'EXCODE: 6'
2021-10-21 22:19:08 EXCODE: 6
2021-10-21 22:19:08 + echo 'ipipe_log_param_EXCODE: 6'
2021-10-21 22:19:08 ipipe_log_param_EXCODE: 6
2021-10-21 22:19:08 + exit 6

🔹 Failed: PR-CI-Windows-OPENBLAS

test_failed
2021-10-21 23:00:38 The following tests FAILED:
2021-10-21 23:00:38 55 - operator_test (Failed)
2021-10-21 23:00:38 55 - operator_test (Failed)
2021-10-21 23:00:38 55 - operator_test (Failed)
2021-10-21 23:00:38 C:\home\workspace\Paddle\build>goto:eof
2021-10-21 23:00:38 C:\home\workspace\Paddle\build>for /F %# in ('wmic os get localdatetime|findstr 20') do set end=%#
2021-10-21 23:00:38 C:\home\workspace\Paddle\build>set end=20211021230038.502000+480
2021-10-21 23:00:38 C:\home\workspace\Paddle\build>set end=1021230038
2021-10-21 23:00:38 C:\home\workspace\Paddle\build>call :timestamp "1021224342" "1021230038" "1 card TestCases Total"
2021-10-21 23:00:38 C:\home\workspace\Paddle\build>setlocal enabledelayedexpansion
2021-10-21 23:00:38 1896222
2021-10-21 23:00:38 "Windows 1 card TestCases Total Time: 1016s"
2021-10-21 23:00:38 ipipe_log_param_Windows_1_card_TestCases_Total_Time: 1016s
2021-10-21 23:00:38 1896222
2021-10-21 23:00:38 "Windows TestCases Total Time: 1016s"
2021-10-21 23:00:38 ipipe_log_param_Windows_TestCases_Total_Time: 1016s
2021-10-21 23:00:38 Running unit tests failed, will exit
2021-10-21 23:00:38 EXCODE: 8

🔹 Failed: PR-CI-Mac-Python3

Unknown Failed
2021-10-22 00:39:33 Ran 1 test in 206.308s
2021-10-22 00:39:33 FAILED (errors=1)
2021-10-22 00:39:33 0% tests passed, 5 tests failed out of 5
2021-10-22 00:39:33 Total Test time (real) = 211.16 sec
2021-10-22 00:39:33 The following tests FAILED:
2021-10-22 00:39:33 56 - operator_test (Failed)
2021-10-22 00:39:33 1050 - test_build_strategy (Failed)
2021-10-22 00:39:33 1095 - test_se_resnet (Failed)
2021-10-22 00:39:33 1151 - test_image_classification (Failed)
2021-10-22 00:39:33 1196 - test_weight_quantization_mobilenetv1 (Failed)
2021-10-22 00:39:33 Errors while running CTest
2021-10-22 00:39:33 + EXCODE=1
2021-10-22 00:39:33 + echo 'EXCODE: 1'
2021-10-22 00:39:33 EXCODE: 1
2021-10-22 00:39:33 + echo 'ipipe_log_param_EXCODE: 1'
2021-10-22 00:39:33 ipipe_log_param_EXCODE: 1
2021-10-22 00:39:33 + '[' 1 -eq 0 ']'
2021-10-22 00:39:33 + set +x
2021-10-22 00:39:33 + exit 1

🔹 Failed: PR-CI-OP-benchmark

Unknown Failed
2021-10-22 06:12:38 + echo '[tools/test_ci_op_benchmark.sh:271] [ERROR] Missing test script of "mean"(paddle/fluid/operators/mean_op.cu) in benchmark.'
2021-10-22 06:12:38 [tools/test_ci_op_benchmark.sh:271] [ERROR] Missing test script of "mean"(paddle/fluid/operators/mean_op.cu) in benchmark.
2021-10-22 06:12:38 + for op_name in '${!CHANGE_OP_MAP[@]}'
2021-10-22 06:12:38 + '[' -z '' ']'
2021-10-22 06:12:38 + exit_code=8
2021-10-22 06:12:38 + LOG '[ERROR] Missing test script of "fill_any_like"(paddle/fluid/operators/fill_any_like_op.cu) in benchmark.'
2021-10-22 06:12:38 + echo '[tools/test_ci_op_benchmark.sh:271] [ERROR] Missing test script of "fill_any_like"(paddle/fluid/operators/fill_any_like_op.cu) in benchmark.'
2021-10-22 06:12:38 [tools/test_ci_op_benchmark.sh:271] [ERROR] Missing test script of "fill_any_like"(paddle/fluid/operators/fill_any_like_op.cu) in benchmark.
2021-10-22 06:12:38 + for op_name in '${!CHANGE_OP_MAP[@]}'
2021-10-22 06:12:38 + '[' -z matmul,matmul,matmul.json,True ']'
2021-10-22 06:12:38 + '[' 8 -ne 0 ']'
2021-10-22 06:12:38 + LOG '[INFO] See https://github.com/PaddlePaddle/Paddle/wiki/PR-CI-OP-benchmark-Manual for details.'
2021-10-22 06:12:38 + echo '[tools/test_ci_op_benchmark.sh:275] [INFO] See https://github.com/PaddlePaddle/Paddle/wiki/PR-CI-OP-benchmark-Manual for details.'
2021-10-22 06:12:38 [tools/test_ci_op_benchmark.sh:275] [INFO] See https://github.com/PaddlePaddle/Paddle/wiki/PR-CI-OP-benchmark-Manual for details.
2021-10-22 06:12:38 + LOG '[INFO] Or you can apply for one RD (Avin0323(Recommend), Xreki, luotao1) approval to pass this PR.'
2021-10-22 06:12:38 + echo '[tools/test_ci_op_benchmark.sh:276] [INFO] Or you can apply for one RD (Avin0323(Recommend), Xreki, luotao1) approval to pass this PR.'
2021-10-22 06:12:38 [tools/test_ci_op_benchmark.sh:276] [INFO] Or you can apply for one RD (Avin0323(Recommend), Xreki, luotao1) approval to pass this PR.
2021-10-22 06:12:38 + exit 8
2021-10-22 06:12:38 {build code state=8}

🔹 Failed: PR-CI-Py3

test_failed
2021-10-22 07:16:53 The following tests FAILED:
2021-10-22 07:16:53 81 - operator_test (Failed)
2021-10-22 07:16:53 + EXCODE=8
2021-10-22 07:16:53 + echo 'EXCODE: 8'
2021-10-22 07:16:53 EXCODE: 8
2021-10-22 07:16:53 + echo 'ipipe_log_param_EXCODE: 8'
2021-10-22 07:16:53 ipipe_log_param_EXCODE: 8
2021-10-22 07:16:53 + [[ 8 -eq 0 ]]
2021-10-22 07:16:53 + set +x
2021-10-22 07:16:53 Sorry, some tests failed.
2021-10-22 07:16:53 + exit 8
2021-10-22 07:16:53 {build code state=8}
2021-10-22 07:17:03 kill agent BUILD_CODE_FAIL

🔹 Failed: PR-CI-Coverage

test_failed
2021-10-22 07:43:25 The following tests FAILED:
2021-10-22 07:43:25 81 - operator_test (Failed)
2021-10-22 07:43:25 + EXCODE=8
2021-10-22 07:43:25 + echo 8
2021-10-22 07:43:25 8
2021-10-22 07:43:25 + echo 'ipipe_log_param_EXCODE: 8'
2021-10-22 07:43:25 ipipe_log_param_EXCODE: 8
2021-10-22 07:43:25 + '[' 8 -ne 0 ']'
2021-10-22 07:43:25 + '[' 8 -ne 9 ']'
2021-10-22 07:43:25 + exit 8
2021-10-22 07:43:25 {build code state=8}
